parallelising tf.data.Dataset.from_generator

孤独总比滥情好 · 2020-12-01 01:24

I have a non-trivial input pipeline that from_generator is perfect for...

dataset = tf.data.Dataset.from_generator(...)
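
For concreteness, a minimal sketch of the kind of pipeline in question (the generator name, shapes, and dtypes are placeholder assumptions):

    import numpy as np
    import tensorflow as tf

    # Hypothetical generator standing in for the non-trivial Python-side work;
    # it yields one (image, label) pair at a time.
    def complex_img_label_generator():
        for _ in range(1000):
            img = np.random.randint(0, 128, size=(64, 64, 1)).astype(np.int8)
            yield img, b"some-label"

    dataset = tf.data.Dataset.from_generator(
        complex_img_label_generator, (tf.int8, tf.string))
    dataset = dataset.batch(64)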


        
3 Answers

  •  盖世英雄少女心 · 2020-12-01 01:56

    I am working on a from_indexable for tf.data.Dataset (https://github.com/tensorflow/tensorflow/issues/14448).

    The advantage of from_indexable is that it can be parallelized, while a Python generator cannot be.

    The function from_indexable builds a tf.data.Dataset.range over the indices, wraps the indexable in a generalized tf.py_func, and maps that function over the range.

    For those who want a from_indexable now, here is the library code:

    import tensorflow as tf
    import numpy as np
    
    from tensorflow.python.framework import tensor_shape
    from tensorflow.python.util import nest
    
    def py_func_decorator(output_types=None, output_shapes=None, stateful=True, name=None):
        def decorator(func):
            def call(*args):
                nonlocal output_shapes
    
                flat_output_types = nest.flatten(output_types)
                flat_values = tf.py_func(
                    func, 
                    inp=args, 
                    Tout=flat_output_types,
                    stateful=stateful, name=name
                )
                if output_shapes is not None:
                    # I am not sure if this is necessary
                    output_shapes = nest.map_structure_up_to(
                        output_types, tensor_shape.as_shape, output_shapes)
                    flattened_shapes = nest.flatten_up_to(output_types, output_shapes)
                    for ret_t, shape in zip(flat_values, flattened_shapes):
                        ret_t.set_shape(shape)
                return nest.pack_sequence_as(output_types, flat_values)
            return call
        return decorator
    
    def from_indexable(iterator, output_types, output_shapes=None, num_parallel_calls=None, stateful=True, name=None):
        # Build a dataset of the indices 0..len(iterator)-1, then map the
        # py_func-wrapped lookup over it; num_parallel_calls allows several
        # lookups to run in parallel.
        ds = tf.data.Dataset.range(len(iterator))
        @py_func_decorator(output_types, output_shapes, stateful=stateful, name=name)
        def index_to_entry(index):
            return iterator[index]
        return ds.map(index_to_entry, num_parallel_calls=num_parallel_calls)
    

    And here is an example (note: from_indexable has a num_parallel_calls argument):

    class PyDataSet:
        def __len__(self):
            return 20
    
        def __getitem__(self, item):
            return np.random.normal(size=(item+1, 10))
    
    ds = from_indexable(PyDataSet(), output_types=tf.float64, output_shapes=[None, 10])
    it = ds.make_one_shot_iterator()
    entry = it.get_next()
    with tf.Session() as sess:
        print(sess.run(entry).shape)
        print(sess.run(entry).shape)
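
    Since the element lookup happens inside map, passing num_parallel_calls lets several __getitem__ calls run concurrently. A sketch of the same dataset built with parallel lookups (the value 4 is an arbitrary choice for illustration):

    ds = from_indexable(
        PyDataSet(), output_types=tf.float64, output_shapes=[None, 10],
        num_parallel_calls=4)  # up to 4 index lookups may be in flight at once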
    

    Update June 10, 2018: Since https://github.com/tensorflow/tensorflow/pull/15121 has been merged, the code for from_indexable simplifies to:

    import tensorflow as tf
    
    def py_func_decorator(output_types=None, output_shapes=None, stateful=True, name=None):
        def decorator(func):
            def call(*args, **kwargs):
                # tf.contrib.framework.py_func handles the flattening, shape
                # setting, and re-packing that the original decorator did by hand
                return tf.contrib.framework.py_func(
                    func=func,
                    args=args, kwargs=kwargs,
                    output_types=output_types, output_shapes=output_shapes,
                    stateful=stateful, name=name
                )
            return call
        return decorator
    
    def from_indexable(iterator, output_types, output_shapes=None, num_parallel_calls=None, stateful=True, name=None):
        ds = tf.data.Dataset.range(len(iterator))
        @py_func_decorator(output_types, output_shapes, stateful=stateful, name=name)
        def index_to_entry(index):
            return iterator[index]    
        return ds.map(index_to_entry, num_parallel_calls=num_parallel_calls)
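
    Usage is unchanged; for example, with the same PyDataSet class from above (the num_parallel_calls value is again an arbitrary choice):

    ds = from_indexable(PyDataSet(), output_types=tf.float64,
                        output_shapes=[None, 10], num_parallel_calls=2)
    entry = ds.make_one_shot_iterator().get_next()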
    
