Split .tfrecords file into many .tfrecords files

Asked 2020-12-09 20:37

Is there any way to split a .tfrecords file into many .tfrecords files directly, without writing back each Dataset example?

4 Answers
  • 2020-12-09 21:03

    In TensorFlow 2.0.0, this works:

    import tensorflow as tf
    
    raw_dataset = tf.data.TFRecordDataset("input_file.tfrecord")
    
    shards = 10
    
    # shard(shards, i) keeps every `shards`-th record, starting at offset i
    for i in range(shards):
        writer = tf.data.experimental.TFRecordWriter(f"output_file-part-{i}.tfrecord")
        writer.write(raw_dataset.shard(shards, i))
    
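    Note that shard(shards, i) makes a full pass over the input file for every output shard, so this is simple but not the fastest option for large inputs (see the batching answers below). To sanity-check the split, here is a small sketch, not part of the original answer, that counts the records per shard:

    for i in range(shards):
        part = tf.data.TFRecordDataset(f"output_file-part-{i}.tfrecord")
        # iterate eagerly just to count the serialized records in this shard
        print(f"part {i}:", sum(1 for _ in part), "records")
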
  • 2020-12-09 21:09

    You can use a function like this (TensorFlow 1.x API):

    import tensorflow as tf
    
    def split_tfrecord(tfrecord_path, split_size):
        with tf.Graph().as_default(), tf.Session() as sess:
            # Read the input file in batches of `split_size` serialized records
            ds = tf.data.TFRecordDataset(tfrecord_path).batch(split_size)
            batch = ds.make_one_shot_iterator().get_next()
            part_num = 0
            while True:
                try:
                    records = sess.run(batch)
                    # Write each batch to its own numbered part file
                    part_path = tfrecord_path + '.{:03d}'.format(part_num)
                    with tf.python_io.TFRecordWriter(part_path) as writer:
                        for record in records:
                            writer.write(record)
                    part_num += 1
                except tf.errors.OutOfRangeError:
                    break
    

    For example, to split the file my_records.tfrecord into parts of 100 records each, you would do:

    split_tfrecord('my_records.tfrecord', 100)
    

    This would create multiple smaller record files my_records.tfrecord.000, my_records.tfrecord.001, etc.

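    The resulting parts can be read back together by passing the whole list of files to TFRecordDataset. A minimal sketch, assuming the naming scheme above:

    import glob
    
    part_files = sorted(glob.glob('my_records.tfrecord.*'))
    combined = tf.data.TFRecordDataset(part_files)  # reads the parts in order
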
  • 2020-12-09 21:23

    Very efficient way for TensorFlow 2.x

    As mentioned by @yongjieyongjie, you should use .batch() instead of .shard() to avoid iterating over the dataset more often than needed. But if you have a very large dataset, too big for memory, that approach will fail silently (without raising an error), leaving you with only a few files and a fraction of your original dataset.

    First, batch your dataset, using the number of records you want per file as the batch size (this assumes your dataset already contains serialized records).
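
    For reference, here is a minimal setup sketch for the snippets that follow; the input file name and the values of ITEMS_PER_FILE, save_to, and name are placeholders, not part of the original answer:

    import tensorflow as tf
    
    ITEMS_PER_FILE = 4096        # records per output shard (placeholder)
    save_to = "output_shards"    # target directory (placeholder)
    name = "my_records"          # base name for the shards (placeholder)
    
    # A dataset of already-serialized records read from the input file
    dataset = tf.data.TFRecordDataset("input.tfrecord")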

    dataset = dataset.batch(ITEMS_PER_FILE)
    

    Next, use a generator to avoid running out of memory.

    def write_generator():
        i = 0
        iterator = iter(dataset)
        optional = iterator.get_next_as_optional()
        while optional.has_value().numpy():
            ds = optional.get_value()
            optional = iterator.get_next_as_optional()
            # Turn the batch back into a dataset of individual serialized records
            batch_ds = tf.data.Dataset.from_tensor_slices(ds)
            writer = tf.data.experimental.TFRecordWriter(
                save_to + "\\" + name + "-" + str(i) + ".tfrecord",
                compression_type='GZIP')
            i += 1
            yield batch_ds, writer, i
        return
    

    Now simply use the generator in a normal for-loop:

    import time
    
    for data, wri, i in write_generator():
        start_time = time.time()
        wri.write(data)
        print("Time needed: ", time.time() - start_time, "s", "\t", name + "-" + str(i) + ".tfrecord")
    

    As long as a single file's worth of records fits in memory, this should work fine.

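    One follow-up worth noting: since the shards above are written with compression_type='GZIP', they must be read back with the same setting. A short sketch, with the path pattern assumed from the writer above:

    import glob
    
    files = sorted(glob.glob(save_to + "\\" + name + "-*.tfrecord"))
    restored = tf.data.TFRecordDataset(files, compression_type="GZIP")
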
  • 2020-12-09 21:27

    Using .batch() instead of .shard() to avoid iterating over the dataset multiple times

    A more performant approach (compared to using tf.data.Dataset.shard()) would be to use batching:

    import tensorflow as tf
    
    ITEMS_PER_FILE = 100 # Assuming we are saving 100 items per .tfrecord file
    
    
    raw_dataset = tf.data.TFRecordDataset('in.tfrecord')
    
    batch_idx = 0
    for batch in raw_dataset.batch(ITEMS_PER_FILE):
    
        # `batch` is a 1-D tensor of serialized records; convert it back into a
        # `Dataset` of individual records so the writer can consume it
        batch_ds = tf.data.Dataset.from_tensor_slices(batch)
        filename = f'out.tfrecord.{batch_idx:03d}'
    
        writer = tf.data.experimental.TFRecordWriter(filename)
        writer.write(batch_ds)
    
        batch_idx += 1
    
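    If you prefer to stay off the experimental writer API, the same split can be done with the stable tf.io.TFRecordWriter by iterating each batch eagerly. A sketch under the same assumptions (in.tfrecord holds serialized records), not part of the original answer:

    import tensorflow as tf
    
    ITEMS_PER_FILE = 100
    raw_dataset = tf.data.TFRecordDataset('in.tfrecord')
    
    for batch_idx, batch in enumerate(raw_dataset.batch(ITEMS_PER_FILE)):
        filename = f'out.tfrecord.{batch_idx:03d}'
        with tf.io.TFRecordWriter(filename) as writer:
            # `batch` is a 1-D string tensor; write each serialized record as bytes
            for record in batch.numpy():
                writer.write(record)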