How to improve the performance of this data pipeline for my TensorFlow model

野性不改 2020-12-25 15:16

I have a TensorFlow model which I am training on Google Colab. The actual model is more complex, but I condensed it into a reproducible example (removed saving/restoring, le…

2 Answers
  •  长情又很酷
    2020-12-25 15:59

    The suggestion from hampi to profile your training job is a good one, and may be necessary to understand the actual bottlenecks in your pipeline. The other suggestions in the Input Pipeline performance guide should be useful as well.
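
    If you want to confirm where the time is actually going before changing the pipeline, one option in TF 1.x is `tf.train.ProfilerHook`, which periodically writes Chrome-trace files that you can open in `chrome://tracing`. A minimal sketch (the `train_op` name and the output directory are placeholders for your own training setup):

    import tensorflow as tf

    # `train_op` is assumed to be your existing training op.
    hooks = [tf.train.ProfilerHook(save_steps=100, output_dir='/tmp/profile')]

    with tf.train.MonitoredTrainingSession(hooks=hooks) as sess:
      while not sess.should_stop():
        sess.run(train_op)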

    However, there is another possible "quick fix" that might be useful. In some cases, the amount of work in a Dataset.map() transformation can be very small, and dominated by the cost of invoking the function for each element. In those cases, we often try to vectorize the map function, and move it after the Dataset.batch() transformation, in order to invoke the function fewer times (1/512 as many times, in this case), and perform larger—and potentially easier-to-parallelize—operations on each batch. Fortunately, your pipeline can be vectorized as follows:

    import glob
    import random
    import tensorflow as tf

    # NOTE: `DIR_TFRECORDS`, `BATCH_SIZE`, and `_keys_to_map` are assumed to be
    # defined as in the question's original code.

    def _batch_parser(record_batch):
      # NOTE: Use `tf.parse_example()` to operate on batches of records.
      parsed = tf.parse_example(record_batch, _keys_to_map)
      return parsed['d'], parsed['s']

    def init_tfrecord_dataset():
      files_train = glob.glob(DIR_TFRECORDS + '*.tfrecord')
      random.shuffle(files_train)

      with tf.name_scope('tfr_iterator'):
        ds = tf.data.TFRecordDataset(files_train)      # read data from the randomly ordered files
        ds = ds.shuffle(buffer_size=10000)             # select elements randomly from a 10000-element buffer
        # NOTE: Change begins here.
        ds = ds.batch(BATCH_SIZE, drop_remainder=True) # group elements into batches, dropping any final partial batch
        ds = ds.map(_batch_parser)                     # parse each *batch* of serialized records
        # NOTE: Change ends here.
        ds = ds.repeat()                               # iterate indefinitely

        return ds.make_initializable_iterator()        # return an initializable iterator

    Currently, vectorization is a change that you have to make manually, but the tf.data team are working on an optimization pass that provides automatic vectorization.
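
    In some releases around TF 2.0, this pass was exposed as an experimental opt-in through `tf.data.Options`; the attribute has changed across versions and has since been removed, so treat the following only as a sketch:

    # Experimental API in some TF releases (later removed); names may differ in your version.
    options = tf.data.Options()
    options.experimental_optimization.map_vectorization.enabled = True
    ds = ds.with_options(options)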
