I want to use TensorFlow's Dataset API to read a TFRecords file containing lists of variable length. Here is my code.
def _int64_feature(value):
# value must be
After hours of searching and experimenting, I believe I have found the answer. Below is my code.
def _int64_feature(value):
    """Wrap integer data in a ``tf.train.Feature`` holding an ``Int64List``.

    Args:
        value: integer data to store. A NumPy array is the usual input, but
            plain lists, tuples and scalars are accepted too: they are
            converted with ``np.asarray`` first. Arrays of any rank are
            flattened to 1-D because ``Int64List`` holds a flat sequence.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` contains the flattened
        values.
    """
    # np.asarray is a no-op for arrays, and it lets callers pass lists or
    # scalars without wrapping them in np.array themselves (the original
    # called .flatten() directly and failed on those inputs).
    flat_values = np.asarray(value).flatten()
    return tf.train.Feature(int64_list=tf.train.Int64List(value=flat_values))
# Write an array to a TFRecord file.
# `a` holds rows of different lengths, so it has to be an object array.
# NumPy >= 1.24 raises ValueError for ragged nested lists unless
# dtype=object is passed; older versions only warned and built the same
# object array implicitly.
a = np.array([[0, 54, 91, 153, 177],
              [0, 50, 89, 147, 196],
              [0, 38, 79, 157],
              [0, 49, 89, 147, 177],
              [0, 32, 73, 145]], dtype=object)
# The context manager closes (and flushes) the writer even if serializing
# a record raises part-way through.
with tf.python_io.TFRecordWriter('file') as writer:
    for i, row in enumerate(a):  # i = 0 ~ 4
        # Turn one variable-length row (a Python list) into an int array.
        x_train = np.array(row)
        feature = {'i': _int64_feature(np.array([i])),
                   'data': _int64_feature(x_train)}
        # Create an example protocol buffer
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        # Serialize to string and write on the file
        writer.write(example.SerializeToString())
# Check the TFRecord file: decode every raw record and print its two
# int64 features ('i' and 'data').
record_iterator = tf.python_io.tf_record_iterator(path='file')
for string_record in record_iterator:
    # FromString builds a new Example and parses the serialized bytes into it.
    example = tf.train.Example.FromString(string_record)
    feature_map = example.features.feature
    i, data = (feature_map[key].int64_list.value for key in ('i', 'data'))
    print(i, data)
# Read the TFRecord file back through the tf.data (Dataset) API.
filenames = ["file"]
dataset = tf.data.TFRecordDataset(filenames=filenames)
def _parse_function(example_proto):
    """Decode one serialized ``tf.train.Example`` into dense tensors.

    Both features were written as variable-length int64 lists, so each is
    parsed as a ``VarLenFeature``. That yields a ``SparseTensor``, which is
    then densified.

    Note: this runs inside ``Dataset.map`` before any batching, so
    ``example_proto`` is a single serialized record. ``tf.parse_example``
    expects a 1-D batch of serialized records instead, which is why
    ``tf.parse_single_example`` is the call that fits here.

    Returns:
        A 2-tuple ``(i, data)`` of dense int64 tensors.
    """
    feature_names = ('i', 'data')
    spec = {name: tf.VarLenFeature(tf.int64) for name in feature_names}
    parsed = tf.parse_single_example(example_proto, spec)
    return tuple(tf.sparse_tensor_to_dense(parsed[name])
                 for name in feature_names)
# Parse each serialized record into an (i, data) pair of dense tensors.
dataset = dataset.map(_parse_function)
# Shuffle the dataset. The buffer must hold at least as many elements as the
# file contains (5 records) to give a full shuffle. With buffer_size=1 the
# records come out in file order, i.e. no shuffling happens at all.
dataset = dataset.shuffle(buffer_size=10)
# Repeat the input indefinitely
dataset = dataset.repeat()
# Generate batches. Rows have different lengths, so plain batch() only works
# with a batch size of 1. padded_batch() zero-pads each component to the
# longest row in the batch, so it works for any batch size; at batch size 1
# it yields exactly what batch(1) did.
dataset = dataset.padded_batch(1, padded_shapes=([None], [None]))
# Create a one-shot iterator
iterator = dataset.make_one_shot_iterator()
i, data = iterator.get_next()
with tf.Session() as sess:
    for _ in range(3):
        # Fetch i and data in one run() call so both come from the same
        # record; evaluating them separately would advance the iterator twice.
        print(sess.run([i, data]))
There are a few things to note.
1. This Stack Overflow question helped a lot.
2. tf.VarLenFeature
returns a SparseTensor, so tf.sparse_tensor_to_dense
is needed to convert it to a dense tensor.
3. In my code, parse_single_example()
cannot be replaced with parse_example(),
and it bugged me for a day. I don't know why parse_example()
doesn't work. If anyone knows the reason, please enlighten me.
The error is simple: your data
is not a FixedLenFeature;
it is a VarLenFeature.
Replace your line:
'data':tf.FixedLenFeature([], tf.int64)}
with
'data':tf.VarLenFeature(tf.int64)}
Also, when you call print(i.eval())
and print(data.eval()),
you are advancing the iterator twice. The first print
will print 0,
but the second one will print the data of the second row, [0, 50, 89, 147, 196].
You can call print(sess.run([i, data]))
to get i
and data
from the same row.