TensorFlow: print all placeholder variable names from a meta graph

面向向阳花 2020-12-31 15:56

I have a tensorflow model for which I have the .meta and the checkpoint files. I am trying to print all the placeholders that the model requires, without looking at the co

2 Answers
  •  无人及你
    2020-12-31 16:31

    mrry's answer is great, and the second solution in it is especially helpful. However, the op name recorded for a placeholder (the value of the op field, exposed in Python as op.type) can differ between TensorFlow versions. Here is how I find the correct placeholder op name in the GraphDef part of the .meta file:

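    import tensorflow as tf  # TensorFlow 1.x API

    # Load the graph structure from the .meta file into the default graph.
    saver = tf.train.import_meta_graph('some_path/model.ckpt.meta')
    imported_graph = tf.get_default_graph()
    graph_ops = imported_graph.get_operations()
    # Write each op's NodeDef (name, op type, attrs) to a text file.
    with open('output.txt', 'w') as f:
        for op in graph_ops:
            f.write(str(op))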
    

    In output.txt, each placeholder appears with its op name and its other attributes (dtype, shape). Here is part of my output file:

    name: "input/input_image"
    op: "Placeholder"
    attr {
      key: "dtype"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "shape"
      value {
        shape {
          dim {
            size: -1
          }
          dim {
            size: 112
          }
          dim {
            size: 112
          }
          dim {
            size: 3
          }
        }
      }
    }
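
    If the dump is large, scanning it by eye gets tedious. The following is a minimal sketch of pulling node names out of such a text dump by op type. It assumes each node starts with an unindented name: line immediately followed by an op: line, as in the excerpt above. SAMPLE_DUMP, find_ops_by_type, and the "layer19/logits" input are made up for illustration and are not from the original model.

```python
import re

# Stand-in for the contents of output.txt (assumed layout: an unindented
# `name:` line followed directly by an `op:` line for every node).
SAMPLE_DUMP = '''name: "input/input_image"
op: "Placeholder"
attr {
  key: "dtype"
  value {
    type: DT_FLOAT
  }
}
name: "layer19/softmax"
op: "Softmax"
input: "layer19/logits"
'''

# Matches the first two lines of each top-level node: its name and its op.
NODE_HEADER = re.compile(r'^name: "([^"]+)"\nop: "([^"]+)"', re.MULTILINE)


def find_ops_by_type(dump_text, op_type="Placeholder"):
    """Return the names of all nodes in the dump whose op field equals op_type."""
    return [name for name, op in NODE_HEADER.findall(dump_text) if op == op_type]


placeholder_names = find_ops_by_type(SAMPLE_DUMP)
print(placeholder_names)
```

    To use it on a real dump, read the file with open('output.txt').read() and pass the text in place of SAMPLE_DUMP. If your TensorFlow version records placeholders under a different op name, pass that name as op_type.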
    

    So in my TensorFlow version (1.6), the placeholder op name is Placeholder. Going back to mrry's solution, use [x for x in tf.get_default_graph().get_operations() if x.type == "Placeholder"] to get a list of all the placeholder ops.
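
    As a rough sketch, the same filter can also report each placeholder's tensor name, dtype, and static shape, which is what you need to build a feed_dict. To stay self-contained, this version builds a tiny stand-in graph instead of calling import_meta_graph, and it goes through the tf.compat.v1 alias, which I am assuming is available (TensorFlow 1.13+ or 2.x, so not the 1.6 used above). describe_placeholders and the "output" op are hypothetical names.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TensorFlow 1.13+ or 2.x

# Stand-in graph; with a real model you would call
# tf1.train.import_meta_graph(...) inside graph.as_default() instead.
graph = tf.Graph()
with graph.as_default():
    with tf1.name_scope("input"):
        x = tf1.placeholder(tf.float32, shape=[None, 112, 112, 3],
                            name="input_image")
    tf.identity(x, name="output")  # a non-placeholder op, to show the filter


def describe_placeholders(g):
    """Return (tensor name, dtype name, static shape) per Placeholder op in g."""
    rows = []
    for op in g.get_operations():
        if op.type == "Placeholder":
            tensor = op.outputs[0]  # a placeholder op has a single output
            rows.append((tensor.name, tensor.dtype.name, tensor.shape.as_list()))
    return rows


placeholder_info = describe_placeholders(graph)
for row in placeholder_info:
    print(row)
```

    The tensor name in each row (op name plus ":0") is the string that get_tensor_by_name expects.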

    Once the placeholder names are known, you can run inference from the .meta and checkpoint files alone, without rebuilding the model in code. For example:

    import tensorflow as tf  # TensorFlow 1.x API

    input_x = ... # prepare the model input

    # Rebuild the graph from the .meta file, then look up tensors by name.
    saver = tf.train.import_meta_graph('some_path/model.ckpt.meta')
    graph = tf.get_default_graph()
    graph_x = graph.get_tensor_by_name('input/input_image:0')
    graph_y = graph.get_tensor_by_name('layer19/softmax:0')

    # Restore the trained weights and evaluate the output tensor.
    with tf.Session() as sess:
        saver.restore(sess, 'some_path/model.ckpt')
        output_y = sess.run(graph_y, feed_dict={graph_x: input_x})
    
