본문 바로가기

Deep Learning

[tensorflow] How to load a mini-batch from tfrecord and feed it to CNN

반응형

** tfrecord 파일로부터 random mini-batch를 만들어 CNN에 feed_dict 형식으로 입력하기 **


step1) load a mini-batch from tfrecord by tf.train.shuffle_batch()

step2) wrap the batch tensors with tf.placeholder_with_default() so they stay feedable, then evaluate them with sess.run() to get each mini-batch (of size mini_batch_size) as a numpy array


// PSEUDO CODE ///////////////////////////////////

import tensorflow as tf

import os

import numpy as np

import network as net



def _get_image_and_label(param):
    """Build the graph ops that read one (image, label) example from tfrecord files.

    Args:
        param: mapping that provides
            'data_dir': directory holding 'train_1.tfrecords' ... 'train_12.tfrecords',
            'height', 'width', 'depth': dimensions used to reshape the decoded image.
            Any other keys are ignored by this function.

    Returns:
        A tuple ``(image, label)`` of tensors: ``image`` is the decoded raw image
        reshaped to ``[height, width, depth]`` and cast to float32; ``label`` is
        the scalar int64 label cast to float32.
    """
    # Directory where the binary (tfrecord) files are stored, and the image geometry.
    data_dir = param['data_dir']
    height = param['height']
    width = param['width']
    depth = param['depth']

    # Queue of input file names; the shard count (12 files, numbered from 1) is fixed here.
    filenames = [os.path.join(data_dir, 'train_%d.tfrecords' % i) for i in range(1, 13)]
    filename_queue = tf.train.string_input_producer(filenames)

    # Read one serialized example from the file-name queue.
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # Parse the serialized example into its scalar features.
    features = tf.parse_single_example(
        serialized_example,
        features={
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'depth': tf.FixedLenFeature([], tf.int64),
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
        })

    # Label as float32. NOTE(review): the original comment suggests labels lie in
    # [-1.0, 1.0]; nothing here enforces that - confirm against the tfrecord writer.
    label = tf.cast(features['label'], tf.float32)

    # Image: decode the raw bytes as uint8, then reshape with the dimensions given
    # in `param` (the per-record 'height'/'width'/'depth' features are parsed but
    # not used for the reshape).
    uint8image = tf.decode_raw(features['image_raw'], tf.uint8)
    image = tf.cast(tf.reshape(uint8image, tf.stack([height, width, depth])), tf.float32)

    return image, label



# ---- Training input pipeline ----

# Ops that read a single image and its corresponding label from the tfrecord files.
image, label = _get_image_and_label(learning_params)

# Collect single examples into shuffled mini-batches.
image_batch, label_batch = tf.train.shuffle_batch(
    [image, label],
    batch_size=mini_batch_size,
    num_threads=8,
    capacity=min_queue_examples_train + 3 * mini_batch_size,
    min_after_dequeue=min_queue_examples_train,
)

# Give the labels an explicit column shape: [mini_batch_size, 1].
label_batch = tf.reshape(label_batch, shape=[mini_batch_size, 1])

# Wrap both batches in placeholder_with_default: unless a value is fed, they pass
# the queued batch through, and sess.run() on them returns the batch as numpy arrays.
image_batch = tf.placeholder_with_default(
    image_batch, shape=[mini_batch_size, height, width, depth])
label_batch = tf.placeholder_with_default(
    label_batch, shape=[mini_batch_size, 1])

# Build the CNN graph; it returns the cost tensor and the training op.
cost, train_op = net.CNN()


# Training session: initialise variables, start the input queue threads, then for
# every step pull one mini-batch out of the graph and feed it to the CNN.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # The queue runners fill the file-name and shuffle queues in background threads.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        for epoch in range(num_epochs):
            # TRAINING
            for step in range(num_batches_train):
                # Evaluate the batch tensors to get the mini-batch as numpy arrays ...
                X, Y = sess.run([image_batch, label_batch])

                # ... and feed them to the network's own input placeholders.
                cost_value, _ = sess.run([cost, train_op], feed_dict={net.X: X, net.Y: Y})
    finally:
        # Stop and join the queue threads while the session is still open, even if
        # training raised. The `with` block then closes the session itself.
        coord.request_stop()
        coord.join(threads)