TensorFlow 1.10 TFRecordDataset - recovering TFRecords

栀梦 2021-01-05 05:54

Notes:

  1. This question extends a previous question of mine, in which I asked about the best way to store some dummy data as Example an…

1 Answer
  •  春和景丽
    2021-01-05 06:00

    Solved by updating the sequence features to include shape information (each step holds `(3,)` values) and remembering that SequenceExamples are unnamed FeatureLists.

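    import os
    import tensorflow as tf  # TensorFlow 1.x (graph mode, tf.Session)

    # Context features: one scalar per record.
    context_features = {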
        'Name' : tf.FixedLenFeature([], dtype=tf.string),
        'Val_1': tf.FixedLenFeature([], dtype=tf.float32),
        'Val_2': tf.FixedLenFeature([], dtype=tf.float32)
    }
    
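    # Sequence features: every time step holds exactly 3 values,
    # so each parsed tensor has shape [num_steps, 3].
    sequence_features = {
        'sequence': tf.FixedLenSequenceFeature((3,), dtype=tf.int64),
        'pclasses': tf.FixedLenSequenceFeature((3,), dtype=tf.float32),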
    }
    
    def parse(record):
      # Returns a (context, sequence) tuple of dicts mapping feature keys to tensors.
      parsed = tf.parse_single_sequence_example(
            record,
            context_features=context_features,
            sequence_features=sequence_features
      )
      return parsed
    
    
    filenames = [os.path.join(os.getcwd(), f"dummy_sequences_{i}.tfrecords") for i in range(3)]
    dataset = tf.data.TFRecordDataset(filenames).map(parse)
    
    iterator = tf.data.Iterator.from_structure(dataset.output_types,
                                               dataset.output_shapes)
    next_element = iterator.get_next()
    
    training_init_op = iterator.make_initializer(dataset)
    
    with tf.Session() as sess:
      for _ in range(2):
        # Initialize an iterator over the training dataset.
        sess.run(training_init_op)
        for _ in range(3):
          ne = sess.run(next_element)
          print(ne)
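
The parsing specs above only work if the records were written with a matching layout. Below is a minimal sketch of the writer side, assuming the same dummy fields (`Name`, `Val_1`, `Val_2`, `sequence`, `pclasses`) and made-up values. Each context key holds one scalar, matching `FixedLenFeature([])`. Each feature list holds one `Feature` per time step with exactly 3 values, which is what `FixedLenSequenceFeature((3,), ...)` expects. The helper names `make_sequence_example` and `_add_steps` are hypothetical. Only the `tf.train.SequenceExample` protobuf API is used, and it behaves the same in TF 1.x and 2.x.

```python
import tensorflow as tf

WIDTH = 3  # must match the shape (3,) used in FixedLenSequenceFeature


def _add_steps(feature_list, steps, kind):
    """Append one Feature per time step to a FeatureList (hypothetical helper).

    `kind` is 'int64' or 'float'; every step must hold exactly WIDTH values.
    """
    for i, step in enumerate(steps):
        if len(step) != WIDTH:
            raise ValueError(f"step {i} has {len(step)} values, expected {WIDTH}")
        feature = feature_list.feature.add()
        if kind == 'int64':
            feature.int64_list.value.extend(step)
        else:
            feature.float_list.value.extend(step)


def make_sequence_example(name, val_1, val_2, sequence, pclasses):
    """Build a SequenceExample laid out like context_features / sequence_features."""
    ex = tf.train.SequenceExample()
    # Context: one scalar per key, matching FixedLenFeature([]).
    ex.context.feature['Name'].bytes_list.value.append(name.encode('utf-8'))
    ex.context.feature['Val_1'].float_list.value.append(val_1)
    ex.context.feature['Val_2'].float_list.value.append(val_2)
    # Feature lists: keyed by name, one Feature per time step.
    _add_steps(ex.feature_lists.feature_list['sequence'], sequence, 'int64')
    _add_steps(ex.feature_lists.feature_list['pclasses'], pclasses, 'float')
    return ex


if __name__ == '__main__':
    example = make_sequence_example(
        'dummy_0', 0.5, 1.5,
        sequence=[[1, 2, 3], [4, 5, 6]],
        pclasses=[[0.25, 0.25, 0.5], [0.5, 0.25, 0.25]],
    )
    serialized = example.SerializeToString()  # bytes to store as one TFRecord
    restored = tf.train.SequenceExample.FromString(serialized)
    steps = restored.feature_lists.feature_list['sequence'].feature
    print(len(steps), list(steps[1].int64_list.value))  # 2 [4, 5, 6]
```

In TF 1.10 the serialized strings would then be written to each `dummy_sequences_{i}.tfrecords` file with `tf.python_io.TFRecordWriter`. Checking the step width while writing is a design choice: a step with the wrong number of values would otherwise typically only surface as a parse error when `sess.run(next_element)` is called.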
    
