How to use tf.data's initializable and reinitializable iterators and feed data to the Estimator API?

Submitted by 馋奶兔 on 2019-12-04 08:22:20

To use either initializable or reinitializable iterators, you must create a class that inherits from tf.train.SessionRunHook. This class then has access to the session used by the tf.estimator functions, so it can run the iterator's initializer in that session.

Here is a quick example that you can adapt to your needs:

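import tensorflow as tf  # TensorFlow 1.x API


class IteratorInitializerHook(tf.train.SessionRunHook):
    """Hook that initializes the dataset iterator once the session exists."""

    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None  # Will be set in the input_fn

    def after_create_session(self, session, coord):
        # Called once the session is created; run the iterator initializer in it.
        self.iterator_initializer_func(session)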


def get_inputs(X, y):
    iterator_initializer_hook = IteratorInitializerHook()

    def input_fn():
        X_pl = tf.placeholder(X.dtype, X.shape)
        y_pl = tf.placeholder(y.dtype, y.shape)

        dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl))
        dataset = ...
        ...

        iterator = dataset.make_initializable_iterator()
        next_example, next_label = iterator.get_next()


        # Defer initialization: the hook calls this once the session exists.
        iterator_initializer_hook.iterator_initializer_func = (
            lambda sess: sess.run(iterator.initializer,
                                  feed_dict={X_pl: X, y_pl: y}))

        return next_example, next_label

    return input_fn, iterator_initializer_hook

...

train_input_fn, train_iterator_initializer_hook = get_inputs(X_train, y_train)
test_input_fn, test_iterator_initializer_hook = get_inputs(X_test, y_test)

...

estimator.train(input_fn=train_input_fn,
                hooks=[train_iterator_initializer_hook])
estimator.evaluate(input_fn=test_input_fn,
                   hooks=[test_iterator_initializer_hook])

This is a modified version of code I found in a blog post by Sebastian Pölsterl. Have a look at the section "Feeding data to an Estimator via the Dataset API".
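
The hook pattern in the example above depends on call order: the hook object is created first, the estimator later calls input_fn while building the graph (which fills in iterator_initializer_func), and only after the session exists does after_create_session run. Here is a TensorFlow-free sketch of that order. FakeSession, FakeEstimator and InitializerHook are hypothetical stand-ins written for illustration, not TensorFlow classes, and the strings stand in for the real placeholders and ops.

```python
class FakeSession:
    """Hypothetical stand-in for a session; it only records what was run."""

    def __init__(self):
        self.ran = []

    def run(self, op, feed_dict=None):
        self.ran.append((op, feed_dict))


class InitializerHook:
    """Mirrors IteratorInitializerHook, minus the TensorFlow base class."""

    def __init__(self):
        self.iterator_initializer_func = None  # set later, inside input_fn

    def after_create_session(self, session, coord):
        self.iterator_initializer_func(session)


def get_inputs(X, y):
    hook = InitializerHook()

    def input_fn():
        # In the real code these strings would be placeholders and
        # iterator.initializer; here they are plain labels.
        hook.iterator_initializer_func = lambda sess: sess.run(
            "iterator.initializer", feed_dict={"X_pl": X, "y_pl": y})
        return "next_example", "next_label"

    return input_fn, hook


class FakeEstimator:
    """Hypothetical driver that imitates only the order of the calls."""

    def train(self, input_fn, hooks):
        features, labels = input_fn()   # 1. graph construction
        session = FakeSession()         # 2. session creation
        for hook in hooks:              # 3. hooks are handed the new session
            hook.after_create_session(session, coord=None)
        return session


input_fn, hook = get_inputs([1, 2, 3], [0, 1, 0])
assert hook.iterator_initializer_func is None  # not set until input_fn runs
session = FakeEstimator().train(input_fn, hooks=[hook])
print(session.ran)
```

The point of the sketch is that the initializer cannot be run inside input_fn itself, because no session exists yet at that stage; the closure stored on the hook postpones the call until one does.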

Alternatively, you can simply use tf.estimator.train_and_evaluate (https://www.tensorflow.org/api_docs/python/tf/estimator/train_and_evaluate). It lets you run validation during training without having to manage iterators yourself.
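
A minimal sketch of a train_and_evaluate call, assuming a TensorFlow 1.x install where tf.estimator is available. The helper name train_and_evaluate_with_hooks and the max_steps default are assumptions made for illustration. The hooks arguments are only needed if your input functions still rely on initializer hooks like the one above.

```python
def train_and_evaluate_with_hooks(estimator, train_input_fn, train_hook,
                                  eval_input_fn, eval_hook, max_steps=1000):
    """Hypothetical helper: wire input functions and hooks into the specs."""
    # Imported here so that defining this helper does not itself load TensorFlow.
    import tensorflow as tf

    # TrainSpec and EvalSpec each accept their own list of session hooks.
    train_spec = tf.estimator.TrainSpec(
        input_fn=train_input_fn, max_steps=max_steps, hooks=[train_hook])
    eval_spec = tf.estimator.EvalSpec(
        input_fn=eval_input_fn, hooks=[eval_hook])

    return tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
```

With the names from the earlier example, the call might look like train_and_evaluate_with_hooks(estimator, train_input_fn, train_iterator_initializer_hook, test_input_fn, test_iterator_initializer_hook).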
