Tensorflow print all placeholder variable names from meta graph

I have a TensorFlow model for which I have the .meta and the checkpoint files. I am trying to print all the placeholders that the model requires, without looking at the code that built the model.

2 Answers
  • 2020-12-31 16:28

    The tensors v1:0 and v2:0 were created from tf.placeholder() ops, whereas only tf.Variable objects are added to the "variables" (or "trainable_variables") collections. There is no general collection to which tf.placeholder() ops are added, so your options are:

    1. Add the tf.placeholder() ops to a collection (using tf.add_to_collection()) when constructing the original graph. You might need to add more metadata in order to suggest how each placeholder should be used.

    2. Use [x for x in tf.get_default_graph().get_operations() if x.type == "PlaceholderV2"] to get a list of placeholder ops after you import the metagraph (see the sketch after this list).
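
    A minimal sketch of both options, assuming TensorFlow 1.x. It is only illustrative: the placeholder shape, the collection name model_inputs, and the path some_path/model.ckpt.meta are stand-ins for your own values, and the op type to filter on may be "Placeholder" or "PlaceholderV2" depending on the TensorFlow version (see the answer below):

    import tensorflow as tf

    # Option 1: tag the placeholders yourself while building the original graph.
    # Collections added here are serialized into the .meta file, so they can be
    # read back with tf.get_collection() after import_meta_graph().
    x = tf.placeholder(tf.float32, shape=[None, 112, 112, 3], name='input_image')
    tf.add_to_collection('model_inputs', x)

    # Option 2: import the meta graph and filter its operations by op type.
    tf.reset_default_graph()
    saver = tf.train.import_meta_graph('some_path/model.ckpt.meta')
    placeholders = [op for op in tf.get_default_graph().get_operations()
                    if op.type in ('Placeholder', 'PlaceholderV2')]
    for op in placeholders:
        print(op.name, op.outputs[0].dtype, op.outputs[0].shape)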

  • 2020-12-31 16:31

    mrry's answer is great, and the second solution really helps. However, the placeholder op type differs across TensorFlow versions. Here is my way to find the correct placeholder op type in the GraphDef part of the .meta file:

    import tensorflow as tf

    # Import the graph structure from the .meta file and dump every op
    # (name, op type, attrs) to a text file for inspection.
    saver = tf.train.import_meta_graph('some_path/model.ckpt.meta')
    imported_graph = tf.get_default_graph()
    graph_op = imported_graph.get_operations()
    with open('output.txt', 'w') as f:
        for i in graph_op:
            f.write(str(i))
    

    In the output.txt file, we can easily find the placeholders' correct op type and other attributes. Here is part of my output file:

    name: "input/input_image"
    op: "Placeholder"
    attr {
      key: "dtype"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "shape"
      value {
        shape {
          dim {
            size: -1
          }
          dim {
            size: 112
          }
          dim {
            size: 112
          }
          dim {
            size: 3
          }
        }
      }
    }
    

    Obviously, in my TensorFlow version (1.6), the correct placeholder op type is Placeholder. Now, going back to mrry's solution, use [x for x in tf.get_default_graph().get_operations() if x.type == "Placeholder"] to get a list of all the placeholder ops.
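
    For instance, a small sketch (assuming the meta graph has already been imported as in the snippet above) that lists every placeholder together with the dtype and shape it expects to be fed:

    import tensorflow as tf

    # op.outputs[0] is the tensor you would feed, e.g. 'input/input_image:0'.
    placeholder_ops = [x for x in tf.get_default_graph().get_operations()
                       if x.type == 'Placeholder']
    for op in placeholder_ops:
        tensor = op.outputs[0]
        print(tensor.name, tensor.dtype, tensor.shape)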

    This makes it easy and convenient to run inference with only the checkpoint files, without needing to reconstruct the model in code. For example:

    import tensorflow as tf

    input_x = ... # prepare the model input

    # Rebuild the graph structure from the .meta file and look up the
    # input and output tensors by the names found in output.txt.
    saver = tf.train.import_meta_graph('some_path/model.ckpt.meta')
    graph_x = tf.get_default_graph().get_tensor_by_name('input/input_image:0')
    graph_y = tf.get_default_graph().get_tensor_by_name('layer19/softmax:0')

    # Restore the trained weights into a session and run the forward pass.
    sess = tf.Session()
    saver.restore(sess, 'some_path/model.ckpt')

    output_y = sess.run(graph_y, feed_dict={graph_x: input_x})
    