본문 바로가기

Deep Learning

[tensorflow] how to freeze a trained network (into a single .pb file)

반응형

** 학습한 CNN graph 저장하기 - C++에서 사용하기 위해 **

** See '[tensorflow] how to save trained network' first! 

** Download the attached file 'freeze_graph.py' (the code below imports the copy bundled with TensorFlow, `tensorflow.python.tools.freeze_graph`).

** The following code creates graph.pb in net_dir. This .pb file is the one loaded from C++. 

** To see how to use graph.pb in C++, see '[tensorflow] how to load and use CNN in c++' 



/// CODE /////////////////////////////

from __future__ import absolute_import

from __future__ import division

from __future__ import print_function


import os

from TORCS_params import parse_args

from tensorflow.python.tools import freeze_graph



# Directory holding the network saved by the training script; every input
# and output path below is resolved relative to it.
net_dir = 'path_to_saved_trained_network'

# Checkpoint prefix used when the network was saved. The concrete checkpoint
# file name is this prefix plus a "-<step>" suffix (see input_checkpoint_path).
checkpoint_prefix = os.path.join(net_dir, "saved_checkpoint")

# Not used by this script; kept for reference.
checkpoint_state_name = "checkpoint_state"

input_graph_name = "input_graph.pb"  # GraphDef written during training
output_graph_name = "graph.pb"       # frozen graph loaded from C++


input_graph_path = os.path.join(net_dir, input_graph_name)

# Empty string: no separate SaverDef file is supplied.
input_saver_def_path = ""

# input_graph.pb is a binary GraphDef, not a text (pbtxt) one.
input_binary = True

# Derive the checkpoint path from checkpoint_prefix instead of repeating the
# "saved_checkpoint" literal, so the two cannot drift apart. The "-0" suffix
# selects the checkpoint saved at step 0; change it if training saved the
# network at a different step.
input_checkpoint_path = checkpoint_prefix + "-0"


# Comma-separated names of the nodes to keep. Everything not needed to
# compute them is pruned from the frozen graph, so normally this is only the
# network's final output node.
output_node_names = "output_node"

# Names of the Saver's restore op and filename tensor. Newer TensorFlow 1.x
# releases ignore these two arguments; they are kept for older versions.
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"

output_graph_path = os.path.join(net_dir, output_graph_name)

# Set to True to strip device placements (e.g. "/gpu:0") from the graph,
# which can help when the C++ side runs on a machine with other devices.
clear_devices = False


# The trailing "" is initializer_nodes: no extra init ops are run.
result = freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                                   input_binary, input_checkpoint_path,
                                   output_node_names, restore_op_name,
                                   filename_tensor_name, output_graph_path,
                                   clear_devices, "")

# Some TensorFlow 1.x versions report a missing input graph or checkpoint
# by printing a message and returning -1 instead of raising. Turn that into
# a non-zero exit so a failed freeze is not mistaken for success.
if result == -1:
    raise SystemExit("freeze_graph failed; no frozen graph was written to "
                     + output_graph_path)







freeze_graph.py