Restoring a Tensorflow model that uses Iterators

2020-12-14 12:21

I have a model that trains my network using an Iterator, following the new Dataset API pipeline model that's now recommended by Google.

I read tfrecord files, fe…

4 Answers
  • 2020-12-14 13:04

    When restoring a saved meta-graph, you can look up the initialization operation by name and then use it again to initialize the input pipeline for inference.

    That is, when creating the graph, you can do

        dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
    

    And then restore this operation by doing:

        dataset_init_op = graph.get_operation_by_name('dataset_init')
    

    Here is a self-contained code snippet that compares the results of a randomly initialized model before and after restoring.

    Saving an Iterator

    import numpy as np
    import tensorflow as tf
    
    np.random.seed(42)
    data = np.random.random([4, 4])
    X = tf.placeholder(dtype=tf.float32, shape=[4, 4], name='X')
    dataset = tf.data.Dataset.from_tensor_slices(X)
    iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
    dataset_next_op = iterator.get_next()
    
    # name the operation so it can be retrieved by name after restoring
    dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
    
    w = np.random.random([1, 4])
    W = tf.Variable(w, name='W', dtype=tf.float32)
    output = tf.multiply(W, dataset_next_op, name='output')
    
    sess = tf.Session()
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    sess.run(dataset_init_op, feed_dict={X: data})
    while True:
        try:
            print(sess.run(output))
        except tf.errors.OutOfRangeError:
            # save once the dataset has been exhausted
            saver.save(sess, 'tmp/', global_step=1002)
            break
    
    

    And then you can restore the same model for inference as follows:

    Restoring a saved Iterator

    import os
    
    import numpy as np
    import tensorflow as tf
    
    np.random.seed(42)
    data = np.random.random([4, 4])
    tf.reset_default_graph()
    sess = tf.Session()
    saver = tf.train.import_meta_graph('tmp/-1002.meta')
    ckpt = tf.train.get_checkpoint_state(os.path.dirname('tmp/checkpoint'))
    saver.restore(sess, ckpt.model_checkpoint_path)
    graph = tf.get_default_graph()
    
    # Restore the init operation by the name given at save time
    dataset_init_op = graph.get_operation_by_name('dataset_init')
    
    X = graph.get_tensor_by_name('X:0')
    output = graph.get_tensor_by_name('output:0')
    sess.run(dataset_init_op, feed_dict={X: data})
    while True:
        try:
            print(sess.run(output))
        except tf.errors.OutOfRangeError:
            break
    
    
  • I would suggest having a look at CheckpointInputPipelineHook, which implements saving iterator state for further training with tf.Estimator; a sketch of how it is typically used follows below.

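    A minimal sketch of how this hook might be wired into an Estimator training loop, assuming the TF 1.x tf.data.experimental.CheckpointInputPipelineHook API; the model_fn and input_fn here are hypothetical stand-ins:

        import tensorflow as tf
        
        # Hypothetical model_fn and input_fn, for illustration only.
        def model_fn(features, labels, mode):
            loss = tf.reduce_mean(tf.square(tf.to_float(features)))
            train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
                loss, global_step=tf.train.get_global_step())
            return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
        
        def input_fn():
            return tf.data.Dataset.range(100).shuffle(10).batch(4)
        
        est = tf.estimator.Estimator(model_fn, model_dir='/tmp/est')
        # The hook checkpoints iterator state alongside the model variables,
        # so interrupted training can resume mid-dataset.
        est.train(
            input_fn,
            hooks=[tf.data.experimental.CheckpointInputPipelineHook(est)],
            steps=100)
    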
  • 2020-12-14 13:13

    I would suggest using tf.contrib.data.make_saveable_from_iterator, which was designed precisely for this purpose. It is much less verbose and does not require you to change existing code, in particular how you define your iterator.

    A working example, in which we save everything after step 5 has completed. Note that I don't even bother to find out which seed is used.

    import tensorflow as tf
    
    iterator = (
      tf.data.Dataset.range(100)
      .shuffle(10)
      .make_one_shot_iterator())
    batch = iterator.get_next(name='batch')
    
    # register the iterator state so tf.train.Saver also saves/restores it
    saveable_obj = tf.contrib.data.make_saveable_from_iterator(iterator)
    tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
    saver = tf.train.Saver()
    
    with tf.Session() as sess:
      tf.global_variables_initializer().run()
      for step in range(10):
        print('{}: {}'.format(step, sess.run(batch)))
        if step == 5:
          saver.save(sess, './foo', global_step=step)
    
    # 0: 1
    # 1: 6
    # 2: 7
    # 3: 3
    # 4: 8
    # 5: 10
    # 6: 12
    # 7: 14
    # 8: 5
    # 9: 17
    

    Then later, if we resume from step 6, we get the same output.

    import tensorflow as tf
    
    saver = tf.train.import_meta_graph('./foo-5.meta')
    with tf.Session() as sess:
      saver.restore(sess, './foo-5')
      for step in range(6, 10):
        print('{}: {}'.format(step, sess.run('batch:0')))
    # 6: 12
    # 7: 14
    # 8: 5
    # 9: 17
    
  • 2020-12-14 13:18

    I couldn't solve the problem of restoring the iterator's initialization. However, I pre-process my dataset with the map method, applying transformations defined by Python operations wrapped in py_func, which cannot be serialized for storing/restoring, so I would have to initialize my dataset when restoring it anyway.

    So, the problem that remains is how to feed data to my graph when I restore it. I placed a tf.identity node between the iterator output and my network input; upon restoring, I feed my data to the identity node. A better solution that I discovered later is to use placeholder_with_default(), as described in this answer and sketched below.

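    For illustration, here is a minimal sketch of the placeholder_with_default() approach on a toy dataset (the names net_input and output are made up): the node passes the iterator output through by default, but after restoring you can feed it directly.

        import tensorflow as tf
        
        iterator = tf.data.Dataset.range(10).make_one_shot_iterator()
        next_element = iterator.get_next()
        
        # Evaluates to the iterator output unless a value is fed, which is
        # exactly what you want when feeding data into a restored graph.
        net_input = tf.placeholder_with_default(next_element, shape=[], name='net_input')
        output = tf.square(net_input, name='output')
        
        with tf.Session() as sess:
            print(sess.run(output))                    # pulled from the iterator
            print(sess.run(output, {net_input: 7}))    # fed value overrides it
    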