본문 바로가기

Deep Learning

[tensorflow] How to make a TFRecord file for training

반응형

** training data image들로부터 tfrecord 만들기 **


** Assume you have saved 100 images under 'file_path', and the corresponding labels are stored in 'label.csv': one integer label per RGB image, where row i of the CSV holds the label of the i-th image.


// PSEUDO CODE //////////////////////////////////////

import tensorflow as tf

import something



def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature``.

    Args:
        value: One bytes object (e.g. a raw image buffer).

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds exactly ``[value]``.
    """
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)


def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature``.

    Args:
        value: One integer (e.g. an image dimension or a class label).

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds exactly ``[value]``.
    """
    payload = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=payload)


def _image_to_example(img, label):
    """Pack one decoded image and its integer label into a tf.train.Example.

    Args:
        img: Image array (e.g. from ``io.imread``) exposing ``.shape`` and
            ``.tobytes()``; assumed to be (height, width) or
            (height, width, channels) -- TODO confirm for your loader.
        label: Integer class label for the image.

    Returns:
        A ``tf.train.Example`` with the features 'height', 'width',
        'depth', 'label' and 'image_raw' (the raw pixel bytes).
    """
    height = img.shape[0]
    width = img.shape[1]
    # Take the channel count from the array instead of hard-coding 3, so
    # 'depth' always agrees with the size of 'image_raw' (grayscale or
    # RGBA inputs would otherwise be decoded with the wrong shape).
    depth = img.shape[2] if len(img.shape) == 3 else 1
    # ndarray.tostring() is deprecated (removed in NumPy 2.0); tobytes()
    # returns the same raw buffer.
    img_raw = img.tobytes()
    return tf.train.Example(features=tf.train.Features(feature={
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'depth': _int64_feature(depth),
        'label': _int64_feature(label),
        'image_raw': _bytes_feature(img_raw),
    }))


# define paths (placeholders -- point these at your data)
file_path = 'path_to_training_image'
label_path = 'path_to_label_file/label.csv'
tfrecord_path = 'path_to_tfrecord_save/test.tfrecords'

# read csv file: one row per image, integer label in the first column
reader = get_csv_reader(label_path, ",")

# Write one Example per CSV row. The `with` block closes (and flushes) the
# writer even if an image fails to load part-way through.
# (TensorFlow 2.x: use tf.io.TFRecordWriter instead of tf.python_io.)
with tf.python_io.TFRecordWriter(tfrecord_path) as writer_train:
    # enumerate() advances the image index with each row, so row i is
    # paired with image i -- assumes parse_file_name(file_path, i) returns
    # the path of the i-th image; TODO confirm.
    for counter, row in enumerate(reader):
        label = int(row[0])
        file_name = parse_file_name(file_path, counter)
        img = io.imread(file_name)
        example = _image_to_example(img, label)
        writer_train.write(example.SerializeToString())