How to speed up batch preparation when using the Estimator API combined with tf.data.Dataset

Submitted by 柔情痞子 on 2020-06-26 06:11:08

Question:


I'd like to speed up my training routine, which uses the Estimator API with an input_fn written using tf.data.Dataset.

My implementation takes 2 seconds to prepare a batch of data, then runs training on the GPU for 1 second, and then starts over preparing the next batch, which is really inefficient.

I'm looking for a way to prepare the batches asynchronously and upload them to the GPU to speed up the training. Alternatively, I'm looking for a way to cache datasets between invocations of input_fn (dataset.cache() doesn't seem to be a good choice, as the dataset has to be recreated on each input_fn invocation).

Here is a simplified version of my code:

def input_fn(filenames, labels, epochs):
  dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
  dataset = dataset.map(_read_wav, num_parallel_calls=num_map_threads)
  if shuffle:
     dataset = dataset.shuffle(buffer_size=len(labels))
  dataset = dataset.map(_post_process,  num_parallel_calls=num_map_threads)
  dataset = dataset.map(lambda wav, label: ({'wav': wav}, label))
  dataset = dataset.batch(128)
  dataset = dataset.repeat(epochs) # to iterate over the training set forever
  iterator = dataset.make_one_shot_iterator()
  features, labels = iterator.get_next()
  return features, labels

train_input_fn = lambda: input_fn(train_files, train_labels, None)
eval_input_fn = lambda: input_fn(eval_files, eval_labels, 1)

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=45000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn) 
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

I've noticed that the Estimator API is under active development, and in the master branch of TensorFlow input_fn can already return a Dataset directly, so maybe I'm asking too early and this feature isn't ready yet. If so, please point me to a ticket where this implementation can be tracked.
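
For reference, a minimal sketch of what that would look like, assuming a TensorFlow version whose Estimator accepts a Dataset returned directly from input_fn (the Estimator then creates the iterator internally):

def dataset_input_fn(filenames, labels, epochs):
  dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
  dataset = dataset.map(lambda wav, label: ({'wav': wav}, label))
  dataset = dataset.batch(128)
  dataset = dataset.repeat(epochs)
  return dataset  # return the Dataset itself instead of iterator.get_next()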


Answer 1:


Using tf.data.Dataset.cache() is indeed not a good choice here, since it caches the whole dataset in memory, which takes time and might exhaust your memory.

The way to go is to use tf.data.Dataset.prefetch() at the end of your pipeline, which ensures that the pipeline always has buffer_size elements ready. It is usually enough to set buffer_size = 1 at the end:

dataset = ...
dataset = dataset.batch(128)
dataset = dataset.prefetch(1)  # prefetch one batch

As explained by @mrry in this answer, you can also try to increase the number of prefetched batches a bit.

Typically it is most useful to add a small prefetch buffer (with perhaps just a single element) at the very end of the pipeline, but more complex pipelines can benefit from additional prefetching, especially when the time to produce a single element can vary.
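
For example, a slightly larger buffer can smooth out bursty batch preparation times (the buffer size here is an illustrative assumption to be tuned empirically, not a value from the quoted answer):

dataset = dataset.prefetch(4)  # keep up to 4 batches ready ahead of the GPU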


If your input pipeline is still slow compared to your GPU computation, try increasing the number of threads working in parallel via the num_parallel_calls argument of tf.data.Dataset.map().
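
For example (the thread count here is an illustrative assumption; a common starting point is the number of available CPU cores):

num_map_threads = 8  # illustrative; tune to your CPU core count
dataset = dataset.map(_read_wav, num_parallel_calls=num_map_threads)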




Answer 2:


A few points to add to Olivier's answer, mostly from this post:

  • repeat before shuffle is slightly faster, at the cost of blurred epoch boundaries. This may matter in rare cases, but I doubt it.
  • shuffle before mapping - this reduces the memory footprint of your shuffle buffer, since it only needs to buffer the filenames rather than the file contents.
  • it makes more sense to me to apply the third map transform to the output of get_next() rather than to the dataset - I'm not sure whether that affects speed much. You could also combine the other two map calls into one to reduce scheduling overhead (see combined_map_fn in the code below).
  • experiment with repeat before batching. It probably won't make a difference, or only a minor one. If you repeat before shuffle, as mentioned above, you'll have to.
  • as mentioned by Olivier, use prefetch.

Code with modifications:

def input_fn(filenames, labels, epochs):
  dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
  dataset = dataset.repeat(epochs)
  if shuffle:
    dataset = dataset.shuffle(buffer_size=len(labels))

  def combined_map_fn(*args):
    # _read_wav returns a (wav, label) pair; unpack it so _post_process
    # receives the same two arguments it got in the original pipeline
    return _post_process(*_read_wav(*args))

  dataset = dataset.map(combined_map_fn, num_parallel_calls=num_map_threads)
  dataset = dataset.batch(128)
  dataset = dataset.prefetch(1)

  iterator = dataset.make_one_shot_iterator()
  wavs, labels = iterator.get_next()
  features = {'wav': wavs}
  return features, labels


Source: https://stackoverflow.com/questions/48068013/how-to-speed-up-batch-preparation-when-using-estimators-api-combined-with-tf-dat
