** Saving a network trained in TensorFlow **
** The following pseudo code creates 5 files in the directory "net_dir": **
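1) checkpoint_state: text file that records the path of the latest checkpoint (TensorFlow's default name for it is "checkpoint")
2) input_graph.pb: the graph definition (GraphDef) in binary format
3) saved_checkpoint-0.data-00000-of-00001: the values of the trained variables
4) saved_checkpoint-0.index: the index that maps each variable to its data in the .data file
5) saved_checkpoint-0.meta: the MetaGraphDef (graph structure plus saver information)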
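To confirm that a training run actually wrote all five files, a small standard-library check can help. This is only a sketch: the function name `missing_saved_files` is hypothetical, and the expected names assume `global_step=0` and a single-shard checkpoint ("-00000-of-00001"), as in the code below.

```python
import os

# File names expected in net_dir, assuming global_step=0 and a
# single-shard checkpoint (hypothetical helper, not part of TensorFlow).
EXPECTED_FILES = [
    "checkpoint_state",
    "input_graph.pb",
    "saved_checkpoint-0.data-00000-of-00001",
    "saved_checkpoint-0.index",
    "saved_checkpoint-0.meta",
]


def missing_saved_files(net_dir):
    """Return the expected file names that do not exist as files in net_dir."""
    return [name for name in EXPECTED_FILES
            if not os.path.isfile(os.path.join(net_dir, name))]
```

For example, `missing_saved_files('path_to_save')` returns an empty list once everything has been written.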
/// PSEUDO CODE /////////////////////////////////////////////////////////////////////////
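import os
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.train.Saver)

# The graph is assumed to be built before this point (not shown): it defines the
# placeholder `input`, the batch tensor `image_batch`, `cost`, `train_op`,
# and the loop sizes `num_epochs` and `num_batches_train`.

# path where the trained network will be saved
net_dir = 'path_to_save'
# define file names (!!do not change!!)
checkpoint_prefix = os.path.join(net_dir, "saved_checkpoint")
checkpoint_state_name = "checkpoint_state"
input_graph_name = "input_graph.pb"
output_graph_name = "output_graph.pb"  # not used in this snippet

# create the saver after all variables have been defined
saver = tf.train.Saver(max_to_keep=1)

min_cost = np.inf  # lowest mean epoch cost seen so far
cost_buffer = []   # batch costs of the current epoch

with tf.Session() as sess:
    # initialize variables
    sess.run(tf.global_variables_initializer())
    # start input queue threads (needed for queue-based pipelines, harmless otherwise)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # write the graph definition first (binary .pb)
    tf.train.write_graph(sess.graph_def, net_dir, input_graph_name, as_text=False)

    for epoch in range(num_epochs):
        # TRAINING
        for step in range(num_batches_train):
            X = sess.run(image_batch)  # fetch one batch as a numpy array
            cost_value, _ = sess.run([cost, train_op], feed_dict={input: X})
            cost_buffer.append(cost_value)

        # SAVE THE TRAINED NETWORK WHEN A NEW BEST PERFORMER APPEARS
        mean_cost = np.mean(cost_buffer)
        if mean_cost < min_cost:
            min_cost = mean_cost
            # global_step=0 fixes the names (saved_checkpoint-0.*), so each new best
            # overwrites the previous one; latest_filename names the state file
            checkpoint_path = saver.save(sess, checkpoint_prefix, global_step=0,
                                         latest_filename=checkpoint_state_name)

        # CLEAR THE COST BUFFER FOR THE NEXT EPOCH
        cost_buffer.clear()

    coord.request_stop()
    coord.join(threads)
# leaving the with-block closes the session, so no explicit sess.close() is needed
/////////////////////////////////////////////////////////////////////////////////////////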
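Because every save uses `global_step=0` with `max_to_keep=1`, only the best epoch survives on disk: a save happens whenever an epoch's mean cost is strictly lower than the best seen so far. The TensorFlow-free sketch below isolates that rule so it is easy to see which epochs trigger `saver.save`. The function name `epochs_that_save` and the cost values are hypothetical, for illustration only.

```python
def epochs_that_save(per_epoch_batch_costs):
    """Return the epoch indices at which the training loop would call saver.save.

    per_epoch_batch_costs: one list of batch costs per epoch (hypothetical input).
    A save happens when the epoch's mean cost is strictly below the best so far.
    """
    min_cost = float("inf")
    saved = []
    for epoch, batch_costs in enumerate(per_epoch_batch_costs):
        mean_cost = sum(batch_costs) / len(batch_costs)
        if mean_cost < min_cost:
            min_cost = mean_cost
            saved.append(epoch)
    return saved


# Illustrative costs: epoch 2 (mean 0.8) is worse than epoch 1 (mean 0.55), so it is skipped.
print(epochs_that_save([[1.0, 0.8], [0.6, 0.5], [0.7, 0.9], [0.4, 0.3]]))  # prints [0, 1, 3]
```

Since each save overwrites saved_checkpoint-0.*, the files left in net_dir after training correspond to the last epoch in this list.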