TensorFlow - Read all examples from a TFRecords at once?

野的像风 2020-12-13 02:29

How do you read all examples from a TFRecords file at once?

I've been using tf.parse_single_example to read out individual examples using code similar to th

6 answers
  •  旧巷少年郎
    2020-12-13 03:11

    You can also use tf.python_io.tf_record_iterator to manually iterate all examples in a TFRecord.

    I tested this with the illustrative code below:

    import tensorflow as tf
    
    X = [[1, 2],
         [3, 4],
         [5, 6]]
    
    
    def _int_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
    
    
    def dump_tfrecord(data, out_file):
        writer = tf.python_io.TFRecordWriter(out_file)
        for x in data:
            example = tf.train.Example(
                features=tf.train.Features(feature={
                    'x': _int_feature(x)
                })
            )
            writer.write(example.SerializeToString())
        writer.close()
    
    
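    def load_tfrecord(file_name):
        # The shape [2] must match the number of values stored per example.
        features = {'x': tf.FixedLenFeature([2], tf.int64)}
        data = []
        # tf_record_iterator yields each serialized Example as a bytes string.
        for s_example in tf.python_io.tf_record_iterator(file_name):
            example = tf.parse_single_example(s_example, features=features)
            data.append(tf.expand_dims(example['x'], 0))
        # TF 0.12 takes (concat_dim, values); from TF 1.0 on it is tf.concat(data, 0).
        return tf.concat(0, data)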
    
    
    if __name__ == "__main__":
        dump_tfrecord(X, 'test_tfrecord')
        print('dump ok')
        data = load_tfrecord('test_tfrecord')
    
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            Y = sess.run([data])
            print(Y)
    

    Of course you have to use your own feature specification.

    The disadvantage is that I don't know how to use multiple threads with this approach. However, the most common reason to read all examples at once is to evaluate on a validation set, which is usually not very big, so I don't think efficiency will be a bottleneck.
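
    For intuition about what iterating over every record amounts to, here is a minimal standard-library sketch of reading the TFRecord container format. It assumes the documented on-disk framing (a little-endian uint64 length, a 4-byte masked CRC of the length, the payload, and a 4-byte masked CRC of the payload) and does not verify the CRCs. The names iter_tfrecord_payloads and write_unchecked_record are hypothetical helpers for illustration, not TensorFlow APIs.

```python
import struct


def iter_tfrecord_payloads(path):
    """Yield the raw payload bytes of every record in a TFRecord file.

    Assumption: CRC fields are skipped, not checked.
    """
    with open(path, 'rb') as f:
        while True:
            header = f.read(12)  # 8-byte length + 4-byte CRC of the length
            if not header:
                return  # clean end of file
            if len(header) < 12:
                raise ValueError('truncated record header')
            (length,) = struct.unpack('<Q', header[:8])
            payload = f.read(length)
            footer = f.read(4)  # 4-byte CRC of the payload (not verified)
            if len(payload) < length or len(footer) < 4:
                raise ValueError('truncated record body')
            yield payload


def write_unchecked_record(f, payload):
    """Illustration only: frame a payload with zeroed CRC fields.

    TensorFlow's own readers check the CRCs, so use
    TFRecordWriter for real files.
    """
    f.write(struct.pack('<Q', len(payload)))
    f.write(b'\x00' * 4)
    f.write(payload)
    f.write(b'\x00' * 4)
```

    Each yielded payload is a serialized tf.train.Example, so with TensorFlow installed it could be decoded with tf.train.Example.FromString(payload). Because this is a plain Python generator, the whole file can be collected with list(iter_tfrecord_payloads(path)) when it fits in memory.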

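    I ran into another issue while testing this: I have to specify the feature length. Instead of tf.FixedLenFeature([], tf.int64), I have to write tf.FixedLenFeature([2], tf.int64); otherwise an InvalidArgumentError occurs. I haven't found a way to avoid this. Presumably it is because the shape describes the values stored in each example, and each example here holds two integers; tf.VarLenFeature may be the alternative when the length is not fixed.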

    Python: 3.4
    Tensorflow: 0.12.0
