How to use the Dataset API to read a TFRecords file of lists of variable length?

Asked by 离开以前 on 2020-12-28 11:46

I want to use Tensorflow's Dataset API to read a TFRecords file of lists of variable length. Here is my code.

def _int64_feature(value):
    # value must be a numpy array.


        
2 Answers
  • 2020-12-28 12:15

    After hours of searching and trying, I believe I have found the answer. Below is my code.

    import numpy as np
    import tensorflow as tf
    
    def _int64_feature(value):
        # value must be a numpy array.
        return tf.train.Feature(int64_list=tf.train.Int64List(value=value.flatten()))
    
    # Write an array to a TFRecord file.
    # a is an array which contains lists of variable length.
    a = np.array([[0, 54, 91, 153, 177],
                  [0, 50, 89, 147, 196],
                  [0, 38, 79, 157],
                  [0, 49, 89, 147, 177],
                  [0, 32, 73, 145]], dtype=object)  # dtype=object keeps the ragged rows as lists
    
    writer = tf.python_io.TFRecordWriter('file')
    
    for i in range(a.shape[0]): # i = 0 ~ 4
        x_train = np.array(a[i])
        feature = {'i'   : _int64_feature(np.array([i])), 
                   'data': _int64_feature(x_train)}
    
        # Create an example protocol buffer
        example = tf.train.Example(features=tf.train.Features(feature=feature))
    
        # Serialize to string and write to the file
        writer.write(example.SerializeToString())
    
    writer.close()
    
    # Check the TFRecord file.
    record_iterator = tf.python_io.tf_record_iterator(path='file')
    for string_record in record_iterator:
        example = tf.train.Example()
        example.ParseFromString(string_record)
    
        i = (example.features.feature['i'].int64_list.value)
        data = (example.features.feature['data'].int64_list.value)
        print(i, data)
    
    # Use Dataset API to read the TFRecord file.
    filenames = ["file"]
    dataset = tf.data.TFRecordDataset(filenames)
    def _parse_function(example_proto):
        keys_to_features = {'i':tf.VarLenFeature(tf.int64),
                            'data':tf.VarLenFeature(tf.int64)}
        parsed_features = tf.parse_single_example(example_proto, keys_to_features)
        return tf.sparse_tensor_to_dense(parsed_features['i']), \
               tf.sparse_tensor_to_dense(parsed_features['data'])
    # Parse the record into tensors.
    dataset = dataset.map(_parse_function)
    # Shuffle the dataset
    dataset = dataset.shuffle(buffer_size=1)
    # Repeat the input indefinitely
    dataset = dataset.repeat()  
    # Generate batches
    dataset = dataset.batch(1)
    # Create a one-shot iterator
    iterator = dataset.make_one_shot_iterator()
    i, data = iterator.get_next()
    with tf.Session() as sess:
        print(sess.run([i, data]))
        print(sess.run([i, data]))
        print(sess.run([i, data]))
    

    There are a few things to note.
    1. This SO question helps a lot.
    2. tf.VarLenFeature returns a SparseTensor, so converting it to a dense tensor with tf.sparse_tensor_to_dense is necessary.
    3. In my code, parse_single_example() can't be replaced with parse_example(), and it bugged me for a day. I don't know why parse_example() doesn't work out. If anyone knows the reason, please enlighten me; a possible explanation and workaround are sketched below.
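
    For what it's worth, tf.parse_example() expects a batch (a vector) of serialized Example protos, whereas the map function above hands it a single serialized string per element, which is presumably why only parse_single_example() works there. Below is a minimal, untested sketch of the batch-then-parse alternative; it reuses the feature spec from the code above, and the helper name _parse_batch is just illustrative.

    def _parse_batch(example_protos):
        # example_protos is a vector of serialized strings, so parse_example applies.
        keys_to_features = {'i':    tf.VarLenFeature(tf.int64),
                            'data': tf.VarLenFeature(tf.int64)}
        parsed = tf.parse_example(example_protos, keys_to_features)
        return tf.sparse_tensor_to_dense(parsed['i']), \
               tf.sparse_tensor_to_dense(parsed['data'])
    
    dataset = tf.data.TFRecordDataset(["file"])
    dataset = dataset.batch(2)           # batch the raw serialized strings first
    dataset = dataset.map(_parse_batch)  # then parse the whole batch at once

    With VarLenFeature, shorter rows in a batch come back zero-padded after the dense conversion.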

  • 2020-12-28 12:35

    The error is very simple: your data is not a FixedLenFeature, it is a VarLenFeature. Replace the line:

     'data':tf.FixedLenFeature([], tf.int64)}
    

    with

     'data':tf.VarLenFeature(tf.int64)}
    

    Also, when you call print(i.eval()) and print(data.eval()), each eval() is a separate session run, so the iterator advances twice. The first print will print 0, but the second will print the data of the second row, [0, 50, 89, 147, 196]. You can do print(sess.run([i, data])) to get i and data from the same row, as sketched below.
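
    A minimal sketch of the difference, assuming the i and data tensors from the question's iterator (illustrative only):

    with tf.Session() as sess:
        # Two separate runs: get_next() executes twice, so the values come from different rows.
        print(i.eval())       # i from the first row
        print(data.eval())    # data from the second row
        # A single run: both tensors come from the same get_next() call.
        print(sess.run([i, data]))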
