parallelising tf.data.Dataset.from_generator

孤独总比滥情好 2020-12-01 01:24

I have a non-trivial input pipeline that from_generator is perfect for...

dataset = tf.data.Dataset.from         


        
3 Answers
  •  感情败类
    2020-12-01 02:09

    Turns out I can use Dataset.map if I make the generator super lightweight (only generating metadata) and move the actual heavy lifting into a stateless function. That way I can parallelise just the heavy-lifting part with .map, wrapping the function in a py_func.

    It works, but it feels a tad clumsy... It would be great to be able to just add num_parallel_calls to from_generator :)

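    # TensorFlow 1.x API (tf.py_func, Dataset.make_one_shot_iterator)
    import tensorflow as tf
    
    def pure_numpy_and_pil_complex_calculation(metadata, label):
        # some complex PIL and NumPy work, nothing to do with TF;
        # must return an (H, W, 3) np.uint8 image and the label
        ...
    
    # lightweight_generator (defined elsewhere) yields only (metadata, label) strings
    dataset = tf.data.Dataset.from_generator(lightweight_generator,
                                             output_types=(tf.string,   # metadata
                                                           tf.string))  # label
    
    def wrapped_complex_calculation(metadata, label):
        # py_func runs the Python function as a graph op, so map can
        # apply it to several elements concurrently
        return tf.py_func(func=pure_numpy_and_pil_complex_calculation,
                          inp=(metadata, label),
                          Tout=(tf.uint8,    # (H,W,3) img
                                tf.string))  # label
    
    dataset = dataset.map(wrapped_complex_calculation,
                          num_parallel_calls=8)
    
    dataset = dataset.batch(64)
    iterator = dataset.make_one_shot_iterator()  # renamed from `iter` to avoid shadowing the builtin
    imgs, labels = iterator.get_next()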
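The same shape of pipeline can be sketched without TensorFlow, using only the Python standard library, to show why it parallelises. The parts are a cheap generator that yields metadata, a stateless heavy function, and an order-preserving parallel map followed by batching. `lightweight_generator`, `heavy_calculation`, `parallel_map` and `batch` are hypothetical stand-ins, not TensorFlow APIs. The thread pool plays the role of `num_parallel_calls`. Like `tf.py_func`, these threads share Python's GIL, so they only overlap usefully when the heavy work releases it (many NumPy and PIL operations do).

```python
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from itertools import islice
from typing import Callable, Iterable, Iterator, List, Tuple, TypeVar

In = TypeVar("In")
Out = TypeVar("Out")


def lightweight_generator(n: int) -> Iterator[Tuple[str, str]]:
    """Hypothetical cheap generator: yields only (metadata, label) pairs."""
    for i in range(n):
        yield f"img_{i:03d}.png", f"class_{i % 2}"


def heavy_calculation(item: Tuple[str, str]) -> Tuple[int, str]:
    """Hypothetical stateless stand-in for the expensive NumPy/PIL work.

    It depends only on its argument, so several calls can safely run at once.
    """
    metadata, label = item
    checksum = sum(ord(c) for c in metadata)  # placeholder for decoding an image
    return checksum, label


def parallel_map(fn: Callable[[In], Out], items: Iterable[In],
                 num_parallel_calls: int) -> Iterator[Out]:
    """Apply fn on a thread pool with a bounded number of pending calls,
    yielding results in input order."""
    with ThreadPoolExecutor(max_workers=num_parallel_calls) as pool:
        pending = deque()  # futures, oldest first
        for item in items:
            pending.append(pool.submit(fn, item))
            if len(pending) >= num_parallel_calls:
                yield pending.popleft().result()  # waiting on the oldest keeps the order
        while pending:
            yield pending.popleft().result()


def batch(items: Iterable[Out], size: int) -> Iterator[List[Out]]:
    """Group items into lists of `size`; the last batch may be shorter."""
    it = iter(items)
    while True:
        chunk = list(islice(it, size))
        if not chunk:
            return
        yield chunk
```

For example, `batch(parallel_map(heavy_calculation, lightweight_generator(10), num_parallel_calls=8), size=4)` yields batches of 4, 4 and 2 results in generator order. Only the cheap generator runs serially; the expensive calls overlap. This is the same split the answer makes between `from_generator` and `map`.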
    
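A workaround sometimes used in place of the missing `num_parallel_calls` on `from_generator` is to shard the generator and run one generator per shard. In TensorFlow that is typically written as `tf.data.Dataset.range(n).interleave(...)` over several `from_generator` datasets, each receiving its shard index through the `args` parameter, with `num_parallel_calls` set on `interleave` (this assumes a TensorFlow version whose `interleave` accepts it). The sketch below shows the idea in plain Python threads. It assumes a hypothetical `sharded_generator(i, n, total)` that handles every n-th item. `parallel_interleave` is also hypothetical. Pulling one item from each shard per round reproduces the original order. Each shard still runs Python code, so the GIL caveat applies here too.

```python
import queue
import threading
from typing import Callable, Iterator, List, TypeVar

T = TypeVar("T")
_DONE = object()  # sentinel a producer puts on its queue when its shard is exhausted


def sharded_generator(i: int, n: int, total: int) -> Iterator[int]:
    """Hypothetical shard i of n: handles every n-th item, starting at item i."""
    for item in range(i, total, n):
        yield item * item  # stand-in for the per-item Python work


def parallel_interleave(make_shard: Callable[[int], Iterator[T]], n: int,
                        buffer_size: int = 4) -> Iterator[T]:
    """Run n shard generators on their own threads and merge them round-robin.

    Error propagation from producer threads is omitted for brevity.
    """
    queues: List[queue.Queue] = [queue.Queue(maxsize=buffer_size) for _ in range(n)]

    def produce(i: int) -> None:
        for item in make_shard(i):
            queues[i].put(item)  # blocks when the consumer falls behind
        queues[i].put(_DONE)

    for i in range(n):
        threading.Thread(target=produce, args=(i,), daemon=True).start()

    active = list(range(n))
    while active:
        still_active = []
        for i in active:  # one item from each unfinished shard per round
            item = queues[i].get()
            if item is _DONE:
                continue
            yield item
            still_active.append(i)
        active = still_active
```

With strided shards like these, the round-robin merge yields items in their original order: shard 0 gives items 0, 3, 6, …, shard 1 gives 1, 4, 7, …, and so on. Unlike the map-based approach above, the generator itself is split, which is what a `num_parallel_calls` on `from_generator` would need to do.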
