[tensorflow] How to save a trained network
** tensorflow에서 학습한 network 저장하기 **
** The following pseudo code creates 5 files in the directory given by "net_dir": **
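1) checkpoint_state — checkpoint state file recording the most recent checkpoint (its name is set via `latest_filename`; the default would be "checkpoint")
2) input_graph.pb — binary GraphDef written by `tf.train.write_graph`
3) saved_checkpoint-0.data-00000-of-00001 — the saved variable values
4) saved_checkpoint-0.index — index that locates each variable's values in the .data file
5) saved_checkpoint-0.meta — MetaGraphDef (graph structure and saver info) used when restoring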
/// PSEUDO CODE /////////////////////////////////////////////////////////////////////////
import os

import numpy as np
import tensorflow as tf

import something
# Directory where the trained network is saved.
net_dir = 'path_to_save'

# Output file names (do not change).
checkpoint_prefix = os.path.join(net_dir, "saved_checkpoint")
checkpoint_state_name = "checkpoint_state"
input_graph_name = "input_graph.pb"
# Not used below; presumably consumed by a later freeze/export step -- TODO confirm.
output_graph_name = "output_graph.pb"

# Keep only the most recent checkpoint. The Saver must be created after the
# graph (and its variables) has been built.
saver = tf.train.Saver(max_to_keep=1)

# Lowest mean epoch cost seen so far. Starting at +inf guarantees that the
# first completed epoch is saved.
min_Cost = float("inf")
# Per-step cost values collected during the current epoch.
cost_buffer = []

with tf.Session() as sess:
    # Initialize all global variables.
    sess.run(tf.global_variables_initializer())

    # Write the graph definition (binary GraphDef) once, before training.
    tf.train.write_graph(sess.graph_def, net_dir, input_graph_name, as_text=False)

    for epoch in range(num_epochs):
        # -- TRAINING -----------------------------------------------------
        for step in range(num_batches_train):
            # Fetch the tensor itself, not a one-element list, so the value
            # fed below has the batch's shape without an extra leading axis.
            X = sess.run(image_batch)
            # NOTE(review): `input` is assumed to be the input placeholder
            # defined elsewhere; this name shadows the Python builtin.
            cost_value, _ = sess.run([cost, train_op], feed_dict={input: X})
            cost_buffer.append(cost_value)

        # -- SAVE THE NETWORK WHEN A NEW BEST EPOCH APPEARS ---------------
        epoch_cost = np.mean(cost_buffer)
        if epoch_cost < min_Cost:
            min_Cost = epoch_cost
            # global_step=0 keeps the file names fixed (saved_checkpoint-0.*),
            # so every new best overwrites the previous checkpoint.
            checkpoint_path = saver.save(
                sess,
                checkpoint_prefix,
                global_step=0,
                latest_filename=checkpoint_state_name,
            )

        # -- RESET PER-EPOCH STATE ----------------------------------------
        cost_buffer.clear()
# The session is closed automatically when the `with` block exits.