본문 바로가기

Deep Learning

[tensorflow] how to save trained network

반응형

** tensorflow에서 학습한 network 저장하기 **

** The following pseudo code will create 5 files at the location "net_dir" **


1) checkpoint_state

2) input_graph.pb                                      

3) saved_checkpoint-0.data-00000-of-00001

4) saved_checkpoint-0.index

5) saved_checkpoint-0.meta                          


/// PSEUDO CODE /////////////////////////////////////////////////////////////////////////

import os

import numpy as np
import tensorflow as tf

import something



# Directory where the trained network is saved.
net_dir = 'path_to_save'

# File names of the saved artifacts (!!do not change!!).
checkpoint_prefix = os.path.join(net_dir, "saved_checkpoint")
checkpoint_state_name = "checkpoint_state"
input_graph_name = "input_graph.pb"
output_graph_name = "output_graph.pb"

# Training-loop state: the batch costs of the current epoch and the lowest
# mean epoch cost seen so far (infinity so the first epoch always saves).
cost_buffer = []
min_Cost = float("inf")

# Create the saver; max_to_keep=1 keeps only the latest checkpoint on disk.
saver = tf.train.Saver(max_to_keep=1)

with tf.Session() as sess:
    # Initialize variables.
    sess.run(tf.global_variables_initializer())

    # Write the graph definition first (as_text=False -> binary protobuf).
    tf.train.write_graph(sess.graph_def, net_dir, input_graph_name, as_text=False)

    for epoch in range(num_epochs):
        # TRAINING
        for step in range(num_batches_train):
            # NOTE(review): sess.run([...]) returns a one-element list, so X
            # is a list wrapping the batch, not the batch itself; confirm the
            # placeholder expects that, otherwise use sess.run(image_batch).
            X = sess.run([image_batch])
            cost_value, _ = sess.run([cost, train_op], feed_dict={input: X})

            # Collect the cost of every batch so the epoch mean can be taken.
            cost_buffer.append(cost_value)

        # SAVE TRAINED NETWORK WHEN THE BEST PERFORMER APPEARS
        mean_cost = np.mean(cost_buffer)
        if min_Cost > mean_cost:
            min_Cost = mean_cost
            # global_step is fixed at 0, so every save reuses the same
            # "saved_checkpoint-0.*" file names; latest_filename names the
            # checkpoint state file.
            checkpoint_path = saver.save(
                sess,
                checkpoint_prefix,
                global_step=0,
                latest_filename=checkpoint_state_name,
            )

        # WRITE AND CLEAR PARAMETERS
        cost_buffer.clear()

# No explicit sess.close() is needed: the `with` block closes the session.