** Saving float-type variables to tfrecord **
Tensorflow.org recommends saving training data in the tfrecord format, and many examples can be found online. (See my previous post, 2017/07/18 - [Deep Learning] - [tensorflow] How to make tfrecord file for training.) Unfortunately, most of those examples only cover saving images and labels of type uint8 or int64. In this post, I show how to save training data of type float in the tfrecord format. There are two possible ways: one is the standard way, the other is a trick.
Standard
To save float type training data in the format of tfrecord, you need to use tf.train.FloatList().
import tensorflow as tf
import numpy as np

def _floats_feature(value):
    # value must be a flat sequence of floats
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

# data you would like to save, dtype=float32
data = np.empty(shape=(5, 1), dtype=np.float32)
for k in range(5):
    data[k] = 0.0  # placeholder: replace with your real data

# open tfrecord file (train_data_path is the output path you choose)
writer = tf.python_io.TFRecordWriter(train_data_path)
# make train example; flatten so FloatList receives a 1-D sequence of floats
example = tf.train.Example(features=tf.train.Features(
    feature={'data': _floats_feature(data.flatten())}))
# write on the file
writer.write(example.SerializeToString())
writer.close()
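The important thing to note here is that you must NOT wrap the value in brackets [] when passing it to tf.train.FloatList(). The value is already an array, so [value] would be a list holding one array rather than a sequence of floats.
# wrong usage !! value (=data) is already an nparray
def _floats_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))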
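The difference between the two forms can be seen without TensorFlow at all. The following is a minimal NumPy-only sketch; the array contents are made-up illustrative values, not data from this post.

```python
import numpy as np

# Illustrative values only; shape (5, 1) as in the post.
data = np.arange(5, dtype=np.float32).reshape(5, 1)

# What value=[value] would pass: a list with ONE element, the whole array.
wrapped = [data]

# What FloatList expects: a flat sequence of individual floats.
flat = data.flatten().tolist()

print(len(wrapped), type(wrapped[0]).__name__)  # 1 ndarray
print(len(flat), type(flat[0]).__name__)        # 5 float
```

Brackets are right only when the value is a single scalar, because then [value] is a one-element list of floats.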
To read data from the saved tfrecord,
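# the function name is arbitrary; filename_queue comes from tf.train.string_input_producer()
def read_float_data(filename_queue, batch_size, num_threads, capacity):
    # open tfrecord reader
    reader = tf.TFRecordReader()
    # read file
    _, serialized_example = reader.read(filename_queue)
    # read data
    features = tf.parse_single_example(serialized_example,
        features={'data': tf.VarLenFeature(tf.float32)})
    # make it dense tensor
    data = tf.sparse_tensor_to_dense(features['data'], default_value=0)
    # reshape
    data = tf.reshape(data, [5])
    # pass the tensor inside a list so it is batched as one tensor of shape [5]
    return tf.train.batch([data], batch_size=batch_size,
        num_threads=num_threads, capacity=capacity)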
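Honestly, I was not sure at first why tf.sparse_tensor_to_dense() is needed. As far as I can tell, the reason is that tf.VarLenFeature parses the feature into a SparseTensor, which has to be converted to a dense tensor before it can be reshaped. If the length is fixed and known in advance, tf.FixedLenFeature([5], tf.float32) should give a dense tensor directly.
Please let me know if there is more to it!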
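The role of tf.sparse_tensor_to_dense() can be checked by parsing the same record two ways. This is a minimal sketch, not the code from this post: it assumes TensorFlow 2.x with eager execution, so it uses the tf.io.* names and tf.sparse.to_dense instead of the 1.x names above, and the five values are made up for illustration.

```python
import numpy as np
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution

# Serialize one Example holding five floats under the key 'data'.
values = np.arange(5, dtype=np.float32)  # illustrative values
example = tf.train.Example(features=tf.train.Features(
    feature={'data': tf.train.Feature(
        float_list=tf.train.FloatList(value=values.tolist()))}))
serialized = example.SerializeToString()

# VarLenFeature: the parser is not told the length, so it returns a SparseTensor.
sparse = tf.io.parse_single_example(
    serialized, {'data': tf.io.VarLenFeature(tf.float32)})['data']
dense_from_sparse = tf.sparse.to_dense(sparse, default_value=0)

# FixedLenFeature: the length is declared up front, so the result is already dense.
dense = tf.io.parse_single_example(
    serialized, {'data': tf.io.FixedLenFeature([5], tf.float32)})['data']

is_sparse = isinstance(sparse, tf.SparseTensor)
```

Under these assumptions, the VarLenFeature path needs the extra densify step, while the FixedLenFeature path does not.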
Trick
This method is a trick: the float values are scaled, converted to integers, and stored as a byte string, so you may lose precision. To save,
import tensorflow as tf
import numpy as np

# you must use bracket [] here, since the data to be stored is a single byte string
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# data you would like to save, dtype=int64 (explicit, to match tf.int64 when reading)
data = np.zeros(shape=(5, 1), dtype=np.int64)
for k in range(5):
    # Real trick is here: scale the float value and keep the integer part.
    data[k] = int(some_value_of_type_float32 * 10000.0)

# open tfrecord file
writer = tf.python_io.TFRecordWriter(train_data_path)
# make train example. The data is saved as bytes after the array is converted to a byte string.
# (tobytes() is the current NumPy name for tostring())
example = tf.train.Example(features=tf.train.Features(
    feature={'data': _bytes_feature(data.tobytes())}))
# write on the file
writer.write(example.SerializeToString())
writer.close()
To read,
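# the function name is arbitrary; filename_queue comes from tf.train.string_input_producer()
def read_scaled_data(filename_queue, batch_size, num_threads, capacity):
    # open tfrecord reader
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # parse the example; FixedLenFeature takes (shape, dtype), and [] means one scalar string
    features = tf.parse_single_example(serialized_example,
        features={'data': tf.FixedLenFeature([], tf.string)})
    # decode the raw bytes back into int64 values
    data = tf.decode_raw(features['data'], tf.int64)
    # reshape and normalize it
    data = tf.cast(tf.reshape(data, [5]), tf.float32) / 10000.0
    # pass the tensor inside a list so it is batched as one tensor of shape [5]
    return tf.train.batch([data], batch_size=batch_size,
        num_threads=num_threads, capacity=capacity)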
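Because tf.decode_raw() only reinterprets raw bytes, the dtype used when reading has to match the dtype used when writing. Here is a minimal NumPy-only sketch of what goes wrong otherwise; np.frombuffer stands in for tf.decode_raw purely as an illustration, and the integer values are made up.

```python
import numpy as np

# Illustrative scaled values, stored explicitly as int64 (8 bytes each).
scaled = np.array([1234, -15000, 31415], dtype=np.int64)
raw = scaled.tobytes()

# Reading with the matching dtype recovers the original values.
as_int64 = np.frombuffer(raw, dtype=np.int64)

# Reading with a 4-byte dtype splits every value in two and yields garbage.
as_int32 = np.frombuffer(raw, dtype=np.int32)

print(len(raw), len(as_int64), len(as_int32))  # 24 3 6
```

This is why it seems safer to request np.int64 explicitly when building the array, rather than relying on the platform's default integer size.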
Here is the key to the trick. First, multiply the float32 variable by 10.0^N and convert the result to int64.
data[k] = int(some_value_of_type_float32 * 10000.0)
N is a positive integer (N = 4 in the line above). A larger N loses less precision, but the scaled value must still fit in int64, and float32 itself carries only about 7 significant decimal digits, so there is a limit to what a larger N buys. Next, save the variable by calling _bytes_feature() on the array's byte string (data.tostring(), called data.tobytes() in current NumPy). When you read the variable, decode it with tf.decode_raw() using data type int64, convert it from int64 to float32, and divide by 10.0^N.
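The precision loss described above can be measured with a NumPy-only round trip. This is a sketch under stated assumptions: the five sample values are made up, np.frombuffer stands in for tf.decode_raw, and the tfrecord file itself is left out.

```python
import numpy as np

N = 4                # number of decimal digits kept
scale = 10.0 ** N    # the 10.0^N factor from the text

# Illustrative float32 values.
original = np.array([0.123456, -1.5, 3.14159, 100.0, 2.71828], dtype=np.float32)

# Save side: scale, truncate to int64 (same effect as int(...)), serialize to bytes.
scaled = (original.astype(np.float64) * scale).astype(np.int64)
raw = scaled.tobytes()

# Read side: reinterpret the bytes as int64, cast to float32, undo the scaling.
restored = np.frombuffer(raw, dtype=np.int64).astype(np.float32) / np.float32(scale)

# Largest absolute difference between restored and original values.
max_error = float(np.max(np.abs(
    restored.astype(np.float64) - original.astype(np.float64))))
```

With truncation, the error per value stays below roughly 10.0^-N, so max_error here should come out just under 1e-4. Values that are exact multiples of 10.0^-N, such as -1.5 and 100.0, survive unchanged.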