How can I find the input and output tensor names in a saved model?

Submitted by 我的梦境 on 2021-01-28 13:52:53

Question


I know how to load a saved TensorFlow model, but how can I find out its input and output tensor names?

I can load a protobuf file with tf.import_graph_def and then get the tensors with get_tensor_by_name, but how do I know the tensor names of an arbitrary pre-trained model? Do I need to check its documentation, or is there another way?


Answer 1:


Assuming the model's inputs are placeholders, something like this should help:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

X = np.ones((1, 3), dtype=np.float32)  # example input that could be fed to "Input:0"
tf.reset_default_graph()
model_saver = tf.train.Saver(defer_build=True)  # build later, once the variables exist
input_pl = tf.placeholder(tf.float32, shape=[1, 3], name="Input")
w = tf.Variable(tf.random_normal([3, 3], stddev=0.01), name="Weight")
b = tf.Variable(tf.zeros([3]), name="Bias")
output = tf.add(tf.matmul(input_pl, w), b, name="Output")
model_saver.build()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
model_saver.save(sess, "./model.ckpt")  # also writes ./model.ckpt.meta (the graph)

Now that the graph is built and saved, we can list the placeholder names like this:

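tf.reset_default_graph()
# Recreate the graph from the .meta file, so this also works in a fresh process
model_loader = tf.train.import_meta_graph("./model.ckpt.meta")
sess = tf.Session()
model_loader.restore(sess, "./model.ckpt")
placeholders = [x for x in tf.get_default_graph().get_operations() if x.type == "Placeholder"]
print(placeholders)
# [<tf.Operation 'Input' type=Placeholder>]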
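If the model was exported in the SavedModel format rather than as a bare checkpoint, its signature usually records the input and output tensor names, so no guessing is needed. The following is a minimal sketch, not part of the original answer, assuming the TF 1.x-style API available as `tensorflow.compat.v1` (TF 1.15 or 2.x) and a model exported with `simple_save`; the names `Input`, `Output`, `x`, `y` and the temporary `export_dir` are illustrative only.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # TF 1.x-style API, available in TF 1.15 and 2.x

# Hypothetical export location; simple_save requires that it does not exist yet.
export_dir = os.path.join(tempfile.mkdtemp(), "model")

# Export: simple_save writes a "serving_default" signature that records
# which tensors are the model's inputs and outputs.
with tf.Graph().as_default(), tf.Session() as sess:
    x = tf.placeholder(tf.float32, shape=[None, 3], name="Input")
    w = tf.Variable(tf.ones([3, 2]), name="Weight")
    y = tf.matmul(x, w, name="Output")
    sess.run(tf.global_variables_initializer())
    tf.saved_model.simple_save(sess, export_dir, inputs={"x": x}, outputs={"y": y})

# Load into a fresh graph and read the tensor names from the signature.
with tf.Graph().as_default(), tf.Session() as sess:
    meta_graph_def = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], export_dir)
    signature = meta_graph_def.signature_def["serving_default"]
    input_names = {key: info.name for key, info in signature.inputs.items()}
    output_names = {key: info.name for key, info in signature.outputs.items()}

print(input_names, output_names)  # {'x': 'Input:0'} {'y': 'Output:0'}
```

For a SavedModel already on disk, the `saved_model_cli show --dir <export_dir> --all` command that ships with TensorFlow prints the same signature information without writing any code.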



Answer 2:


This finds the inputs only (the placeholder operations in a .pb file):

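import tensorflow as tf  # TensorFlow 1.x API

# read pb into graph_def (input_model_filepath is the path to your .pb file)
with tf.gfile.GFile(input_model_filepath, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# import graph_def; name="" avoids the default "import/" prefix on every name
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="")

# print the names of the input (placeholder) operations
for op in graph.get_operations():
    if op.type == "Placeholder":
        print(op.name)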
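Answer 2 only finds the inputs. One rough way to also guess the outputs, sketched here as an assumption rather than a reliable rule, is to report the nodes whose results no other node consumes. The helper `find_io_nodes` and the toy graph are hypothetical; the sketch uses `tensorflow.compat.v1` so it runs on TF 1.15 and 2.x.

```python
import tensorflow.compat.v1 as tf  # TF 1.x-style API, available in TF 1.15 and 2.x


def find_io_nodes(graph_def):
    """Guess input and output node names of a GraphDef (heuristic).

    Inputs: nodes whose op type is "Placeholder".
    Outputs: nodes that no other node consumes, ignoring constants and
    no-ops. Unused or debugging nodes will also be reported as outputs.
    """
    consumed = set()
    for node in graph_def.node:
        for inp in node.input:
            # Entries look like "name", "name:1" or "^name" (control input).
            consumed.add(inp.lstrip("^").split(":")[0])
    inputs = [n.name for n in graph_def.node if n.op == "Placeholder"]
    outputs = [n.name for n in graph_def.node
               if n.name not in consumed and n.op not in ("Const", "NoOp")]
    return inputs, outputs


# Usage: build a tiny graph in memory instead of reading a .pb file.
graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, shape=[1, 3], name="Input")
    w = tf.constant([[1.0], [2.0], [3.0]], name="Weight")
    tf.identity(tf.matmul(x, w), name="Output")

inputs, outputs = find_io_nodes(graph.as_graph_def())
print(inputs, outputs)  # ['Input'] ['Output']
```

With a real model, pass the `graph_def` parsed from the .pb file instead. Training-only or debugging nodes that nothing consumes will also appear, so treat the result as a list of candidates to check.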



Answer 3:


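You can print the name and the inputs of every operation in the graph to find the tensor names. A tensor's name is the name of the operation that produces it followed by the output index, e.g. "Input:0", which is the form get_tensor_by_name expects.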

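import tensorflow as tf  # TensorFlow 1.x API

with tf.gfile.GFile(input_model_filepath, "rb") as f:
  graph_def = tf.GraphDef()
  graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
  tf.import_graph_def(graph_def, name="")  # keep the original names (no "import/" prefix)

# print each operation followed by the names of its input tensors
for op in graph.get_operations():
  print(op.name, [inp.name for inp in op.inputs])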
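To go from the printed names to usable tensors, append the output index to the operation name and pass it to `get_tensor_by_name`. The sketch below is illustrative: it serializes a toy graph in memory in place of a real .pb file, and the names `Input`, `Weight` and `Output` are assumptions. It uses `tensorflow.compat.v1` so it runs on TF 1.15 and 2.x.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # TF 1.x-style API, available in TF 1.15 and 2.x

# Build a tiny model and serialize it; these bytes stand in for a .pb file.
src = tf.Graph()
with src.as_default():
    x = tf.placeholder(tf.float32, shape=[None, 3], name="Input")
    w = tf.constant(np.ones((3, 2), dtype=np.float32), name="Weight")
    tf.add(tf.matmul(x, w), 1.0, name="Output")
pb_bytes = src.as_graph_def().SerializeToString()

# Parse and import as in the answers; name="" keeps the original names.
graph_def = tf.GraphDef()
graph_def.ParseFromString(pb_bytes)
graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name="")

for op in graph.get_operations():
    print(op.name, [inp.name for inp in op.inputs])

# A tensor name is "<operation name>:<output index>".
input_t = graph.get_tensor_by_name("Input:0")
output_t = graph.get_tensor_by_name("Output:0")
with tf.Session(graph=graph) as sess:
    result = sess.run(output_t, feed_dict={input_t: np.ones((1, 3), np.float32)})
print(result)  # [[4. 4.]]
```

Without `name=""`, the same tensors would have to be looked up as "import/Input:0" and "import/Output:0", because `import_graph_def` prefixes names with "import" by default.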


Source: https://stackoverflow.com/questions/55313980/how-can-i-know-the-output-and-input-tensor-names-in-a-saved-model
