In TensorFlow, how do you use a restored meta graph if it was fed with TFRecord input (without placeholders)?

迷失自我 2020-12-17 05:37

I trained a network with a TFRecord input pipeline; in other words, there were no placeholders. A simple example would be:

input, truth = _get_next_batch()  # TFR         


        
3 Answers
  • 2020-12-17 06:17

    The recommended way is to save two meta graphs: one for training/validation/testing, and another one for inference.

    See the "Building a SavedModel" section of the TensorFlow documentation:

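    import tensorflow as tf
    tag_constants = tf.saved_model.tag_constants
    
    export_dir = ...
    ...
    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
    with tf.Session(graph=tf.Graph()) as sess:
      ...
      # First MetaGraphDef: the training graph, saved together with the variables.
      # foo_signatures and foo_assets stand for your own signature map and assets.
      builder.add_meta_graph_and_variables(sess,
                                           [tag_constants.TRAINING],
                                           signature_def_map=foo_signatures,
                                           assets_collection=foo_assets)
    ...
    # Add a second MetaGraphDef for inference. It stores only the graph and
    # reuses the variables saved above, so the variable names must match.
    with tf.Session(graph=tf.Graph()) as sess:
      ...
      builder.add_meta_graph([tag_constants.SERVING])
    ...
    builder.save()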
    
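A rough, self-contained sketch of the two-meta-graph idea, assuming tf.compat.v1 (so it also runs on TF 2.x with eager execution disabled). The training graph reads from a tf.data iterator, while the serving graph replaces it with a placeholder named "input" and reuses the same variables. The one-layer build_model helper, the tensor names "input"/"logits", and the random data are hypothetical stand-ins, not part of the TensorFlow guide.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

TAGS = tf.saved_model.tag_constants


def build_model(x):
    # Hypothetical one-layer model; the variable name "w" must be identical
    # in both graphs so the serving graph can reuse the saved values.
    w = tf.get_variable("w", shape=[3, 1], initializer=tf.ones_initializer())
    return tf.matmul(x, w, name="logits")


export_dir = os.path.join(tempfile.mkdtemp(), "saved_model")  # must not exist yet
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

# MetaGraphDef 1: training graph fed by a tf.data pipeline (no placeholders).
with tf.Session(graph=tf.Graph()) as sess:
    data = np.random.rand(8, 3).astype(np.float32)
    dataset = tf.data.Dataset.from_tensor_slices(data).batch(4).repeat()
    batch = tf.data.make_one_shot_iterator(dataset).get_next()
    loss = tf.reduce_mean(tf.square(build_model(batch)))
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)
    builder.add_meta_graph_and_variables(sess, [TAGS.TRAINING])

# MetaGraphDef 2: inference graph with a placeholder input. add_meta_graph()
# stores only the graph; it points at the variables saved above.
with tf.Session(graph=tf.Graph()):
    inp = tf.placeholder(tf.float32, shape=[None, 3], name="input")
    build_model(inp)
    builder.add_meta_graph([TAGS.SERVING])

builder.save()

# Load only the serving graph and feed it directly by tensor name.
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, [TAGS.SERVING], export_dir)
    preds = sess.run("logits:0", feed_dict={"input:0": np.ones((2, 3), np.float32)})
```

Only the SERVING graph is loaded at inference time, so the training iterator never needs to be fed or replaced.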

    The NMT tutorial also gives a detailed example of building multiple graphs that share variables: Neural Machine Translation (seq2seq) Tutorial, "Building Training, Eval, and Inference Graphs".

    import itertools
    import tensorflow as tf
    
    # BuildTrainModel, BuildEvalModel, BuildInferenceModel, EVAL_STEPS, INFER_STEPS,
    # infer_input_data, data_to_eval and data_to_infer are placeholders from the tutorial.
    train_graph = tf.Graph()
    eval_graph = tf.Graph()
    infer_graph = tf.Graph()
    
    # Each graph builds its own copy of the model variables.
    with train_graph.as_default():
      train_iterator = ...
      train_model = BuildTrainModel(train_iterator)
      initializer = tf.global_variables_initializer()
    
    with eval_graph.as_default():
      eval_iterator = ...
      eval_model = BuildEvalModel(eval_iterator)
    
    with infer_graph.as_default():
      # The inference iterator is initialized from infer_inputs, which is fed at run time.
      infer_iterator, infer_inputs = ...
      infer_model = BuildInferenceModel(infer_iterator)
    
    checkpoints_path = "/tmp/model/checkpoints"
    
    # One session per graph; the variables are shared only through checkpoints.
    train_sess = tf.Session(graph=train_graph)
    eval_sess = tf.Session(graph=eval_graph)
    infer_sess = tf.Session(graph=infer_graph)
    
    train_sess.run(initializer)
    train_sess.run(train_iterator.initializer)
    
    for i in itertools.count():
    
      train_model.train(train_sess)
    
      if i % EVAL_STEPS == 0:
        # Save the current training weights and load them into the eval graph.
        checkpoint_path = train_model.saver.save(train_sess, checkpoints_path, global_step=i)
        eval_model.saver.restore(eval_sess, checkpoint_path)
        eval_sess.run(eval_iterator.initializer)
        while data_to_eval:
          eval_model.eval(eval_sess)
    
      if i % INFER_STEPS == 0:
        # Same hand-off for inference; its iterator is initialized with new input data.
        checkpoint_path = train_model.saver.save(train_sess, checkpoints_path, global_step=i)
        infer_model.saver.restore(infer_sess, checkpoint_path)
        infer_sess.run(infer_iterator.initializer, feed_dict={infer_inputs: infer_input_data})
        while data_to_infer:
          infer_model.infer(infer_sess)
    
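A minimal runnable version of this pattern, assuming tf.compat.v1; the linear model, the toy data, and the step counts are hypothetical, and the eval graph is left out for brevity. It shows that each graph owns its own copy of the variables and that the only link between them is the checkpoint written by the training session.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()


def build_model(x):
    # Hypothetical linear model. Each graph calls this, so every graph gets
    # its own variable "w"; the copies are synced through checkpoints.
    w = tf.get_variable("w", shape=[2, 1], initializer=tf.zeros_initializer())
    return tf.matmul(x, w), w


train_graph = tf.Graph()
infer_graph = tf.Graph()

with train_graph.as_default():
    # Training graph: input comes from a tf.data pipeline, no placeholders.
    xs = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
    ys = np.array([[-1.0], [-1.0]], np.float32)
    dataset = tf.data.Dataset.from_tensor_slices((xs, ys)).batch(2).repeat()
    x, y = tf.data.make_one_shot_iterator(dataset).get_next()
    pred, train_w = build_model(x)
    loss = tf.reduce_mean(tf.square(pred - y))
    train_op = tf.train.GradientDescentOptimizer(0.05).minimize(loss)
    initializer = tf.global_variables_initializer()
    train_saver = tf.train.Saver()

with infer_graph.as_default():
    # Inference graph: same variable names, but its own input (a placeholder here).
    infer_x = tf.placeholder(tf.float32, shape=[None, 2])
    infer_pred, infer_w = build_model(infer_x)
    infer_saver = tf.train.Saver()

train_sess = tf.Session(graph=train_graph)
infer_sess = tf.Session(graph=infer_graph)
train_sess.run(initializer)

checkpoint_prefix = os.path.join(tempfile.mkdtemp(), "model")
INFER_STEPS = 100  # hypothetical hand-off interval
for step in range(1, 201):
    train_sess.run(train_op)
    if step % INFER_STEPS == 0:
        # Hand the current weights to the inference graph via a checkpoint.
        path = train_saver.save(train_sess, checkpoint_prefix, global_step=step)
        infer_saver.restore(infer_sess, path)
        result = infer_sess.run(infer_pred, feed_dict={infer_x: [[1.0, 2.0]]})
```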
  • 2020-12-17 06:24

    Is there a way to do this without placeholders at test time, though? It should be possible to reuse the graph with a new input pipeline without resorting to slow placeholders (e.g. when the test dataset is very large). placeholder_with_default is a suboptimal solution in that case.
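One possible approach, sketched under assumptions: tf.train.import_meta_graph forwards an input_map argument to the graph importer, which rewires every consumer of a named tensor in the saved graph to a tensor you supply, such as the output of a new tf.data pipeline. No feed_dict is involved. The tensor names "input" and "logits", the toy model, and the in-memory test data are hypothetical; wrapping the training batch in tf.identity(..., name="input") just gives the remapping a stable name, and the new pipeline is built under its own name scope so its op names cannot clash with the imported ones.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

ckpt = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Training: the model reads straight from a tf.data iterator.
with tf.Graph().as_default():
    train_data = np.random.rand(16, 3).astype(np.float32)
    dataset = tf.data.Dataset.from_tensor_slices(train_data).batch(4).repeat()
    batch = tf.data.make_one_shot_iterator(dataset).get_next()
    x = tf.identity(batch, name="input")  # stable name to remap later
    w = tf.get_variable("w", shape=[3, 1], initializer=tf.ones_initializer())
    logits = tf.matmul(x, w, name="logits")
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
        tf.reduce_mean(tf.square(logits)))
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(train_op)
        saver.save(sess, ckpt)  # also writes model.ckpt.meta

# Testing: a new pipeline spliced in with input_map instead of a placeholder.
with tf.Graph().as_default() as graph:
    with tf.name_scope("test_input"):  # keep new op names distinct
        test_data = np.ones((6, 3), np.float32)
        test_ds = tf.data.Dataset.from_tensor_slices(test_data).batch(3)
        test_batch = tf.data.make_one_shot_iterator(test_ds).get_next()
    restorer = tf.train.import_meta_graph(
        ckpt + ".meta", input_map={"input:0": test_batch})
    test_logits = graph.get_tensor_by_name("logits:0")
    outputs = []
    with tf.Session() as sess:
        restorer.restore(sess, ckpt)
        while True:
            try:
                outputs.append(sess.run(test_logits))
            except tf.errors.OutOfRangeError:
                break  # test set exhausted
    preds = np.concatenate(outputs)
```

The old training iterator is still imported but never runs, because nothing fetched depends on it any more.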

  • 2020-12-17 06:27

    You can build the graph with placeholder_with_default() for the inputs, so that it can use both the TFRecord input pipeline and feed_dict.

    An example:

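    import tensorflow as tf
    
    input, truth = _get_next_batch()  # tensors from the TFRecord pipeline
    # Each placeholder falls back to the pipeline tensor when nothing is fed.
    _x = tf.placeholder_with_default(input, shape=[...], name='input')
    _y = tf.placeholder_with_default(truth, shape=[...], name='label')
    
    net = Model(_x)
    net.set_loss(_y)
    optimizer = tf...(net.loss)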
    
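A small self-contained check of how placeholder_with_default behaves, assuming tf.compat.v1. The in-memory dataset is a hypothetical stand-in for _get_next_batch(). Declaring the shape as [None, 3] instead of the pipeline's exact batch shape is what allows a batch of a different size to be fed later.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Stand-in for _get_next_batch(): a tf.data pipeline yielding batches of 4 rows.
data = np.arange(24, dtype=np.float32).reshape(8, 3)
dataset = tf.data.Dataset.from_tensor_slices(data).batch(4)
batch = tf.data.make_one_shot_iterator(dataset).get_next()

# Falls back to `batch` unless a value is fed for this tensor.
x = tf.placeholder_with_default(batch, shape=[None, 3], name="input")
row_sums = tf.reduce_sum(x, axis=1)

with tf.Session() as sess:
    from_pipeline = sess.run(row_sums)  # consumes the first pipeline batch
    from_feed = sess.run(row_sums, feed_dict={x: [[1.0, 1.0, 1.0]]})
    # from_pipeline → [3., 12., 21., 30.], from_feed → [3.]
```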

    Then, during inference:

    import tensorflow as tf
    
    loaded_graph = tf.Graph()
    with tf.Session(graph=loaded_graph) as sess:
      # Rebuild the graph from the .meta file, then load the trained weights
      new_saver = tf.train.import_meta_graph('ckpt-20000.meta')
      new_saver.restore(sess, 'ckpt-20000')
    
      # Get the tensors by their names ('input' was given to placeholder_with_default)
      input = loaded_graph.get_tensor_by_name('input:0')
      logits = loaded_graph.get_tensor_by_name(...)
    
      # Now you can feed your own data (img) to the input tensor
      lgt = sess.run(logits, feed_dict={input: img})
    

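    In the example above, if you don't feed input, the values are read from the TFRecord input pipeline instead (provided the pipeline ops were saved in the meta graph; an initializable iterator must be initialized first).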
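The behavior described in the last paragraph can be checked end to end. This sketch assumes tf.compat.v1 and uses a hypothetical one-layer model and random data in place of Model and the TFRecord files. It trains with placeholder_with_default, saves a checkpoint, restores it with import_meta_graph, and then evaluates "logits:0" once without a feed (a batch comes from the restored one-shot iterator) and once with a fed array.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

ckpt = os.path.join(tempfile.mkdtemp(), "ckpt-1")

with tf.Graph().as_default():
    # Stand-in for the TFRecord pipeline: (features, labels) batches of 4.
    feats = np.random.rand(8, 3).astype(np.float32)
    labels = feats.sum(axis=1, keepdims=True)
    ds = tf.data.Dataset.from_tensor_slices((feats, labels)).batch(4).repeat()
    batch_x, batch_y = tf.data.make_one_shot_iterator(ds).get_next()
    _x = tf.placeholder_with_default(batch_x, shape=[None, 3], name="input")
    _y = tf.placeholder_with_default(batch_y, shape=[None, 1], name="label")
    w = tf.get_variable("w", shape=[3, 1], initializer=tf.zeros_initializer())
    logits = tf.matmul(_x, w, name="logits")  # hypothetical model
    loss = tf.reduce_mean(tf.square(logits - _y))
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(10):
            sess.run(train_op)  # no feed_dict: batches come from the pipeline
        saver.save(sess, ckpt)  # writes ckpt-1.meta alongside the weights

loaded_graph = tf.Graph()
with tf.Session(graph=loaded_graph) as sess:
    new_saver = tf.train.import_meta_graph(ckpt + ".meta")
    new_saver.restore(sess, ckpt)
    input_t = loaded_graph.get_tensor_by_name("input:0")
    logits_t = loaded_graph.get_tensor_by_name("logits:0")
    pipeline_out = sess.run(logits_t)  # not fed: reads one pipeline batch
    fed_out = sess.run(logits_t, feed_dict={input_t: np.ones((2, 3), np.float32)})
```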
