Question:
I know how to load a saved TensorFlow model, but how do I find out its input and output tensor names?
I can import a protobuf (.pb) graph with tf.import_graph_def and then fetch tensors with get_tensor_by_name, but how do I know the tensor names of an arbitrary pre-trained model? Do I need to check its documentation, or is there another way?
Answer 1:
Assuming the input tensors are placeholders (output tensors usually are not), something like this should help. First, build and save a small example model:
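import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode API
# Sample input batch matching the placeholder shape (not needed for saving)
X = np.ones((1, 3), dtype=np.float32)
tf.reset_default_graph()
# defer_build=True lets us create the Saver before the variables exist
model_saver = tf.train.Saver(defer_build=True)
input_pl = tf.placeholder(tf.float32, shape=[1, 3], name="Input")
w = tf.Variable(tf.random_normal([3, 3], stddev=0.01), name="Weight")
b = tf.Variable(tf.zeros([3]), name="Bias")
# Naming the output op makes it easy to find later
output = tf.add(tf.matmul(input_pl, w), b, name="Output")
model_saver.build()  # collect the variables created above
sess = tf.Session()
sess.run(tf.global_variables_initializer())
model_saver.save(sess, "./model.ckpt")  # also writes ./model.ckpt.meta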
Now that the graph is built and saved, we can list the placeholder names like this:
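# In a fresh process, rebuild the graph from the .meta file before restoring
tf.reset_default_graph()
model_loader = tf.train.import_meta_graph("./model.ckpt.meta")
sess = tf.Session()
model_loader.restore(sess, "./model.ckpt")
placeholders = [x for x in tf.get_default_graph().get_operations() if x.type == "Placeholder"]
print(placeholders)
# [<tf.Operation 'Input' type=Placeholder>]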
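get_tensor_by_name expects a tensor name, which TensorFlow forms as `<op_name>:<output_index>`, so an op name printed above still needs an output index ('Input' becomes 'Input:0'). The two helpers below are a minimal, hypothetical sketch of that conversion in plain Python; `to_tensor_name` and `split_tensor_name` are illustrative names, not TensorFlow APIs.

```python
def to_tensor_name(op_name, output_index=0):
    """Build a tensor name such as 'Input:0' from an op name (hypothetical helper)."""
    if ":" in op_name:
        raise ValueError(f"{op_name!r} already looks like a tensor name")
    return f"{op_name}:{output_index}"


def split_tensor_name(tensor_name):
    """Split 'scope/op:1' into ('scope/op', 1); a missing index defaults to 0."""
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep:  # no ':' present, so the whole string is the op name
        return tensor_name, 0
    return op_name, int(index)


print(to_tensor_name("Input"))              # Input:0
print(split_tensor_name("import/Input:0"))  # ('import/Input', 0)
```

For the model saved in this answer, `sess.graph.get_tensor_by_name(to_tensor_name("Input"))` should return the input placeholder tensor.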
Answer 2:
A solution for inputs only: list the Placeholder ops in a frozen .pb graph.
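import tensorflow as tf  # TensorFlow 1.x graph-mode API
input_model_filepath = "model.pb"  # path to your frozen graph (example value)
# Read the .pb file into a GraphDef
with tf.gfile.GFile(input_model_filepath, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
# Import the GraphDef; name="" keeps the original op names (the default prefix is "import/")
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="")
# Print the name of every Placeholder op (the model's inputs)
for op in graph.get_operations():
    if op.type == "Placeholder":
        print(op.name)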
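If you only need the input names, you can also skip the import step and read them straight from the parsed GraphDef: each entry in `graph_def.node` is a NodeDef with `name`, `op` and `input` fields. The sketch below assumes that layout and works on any objects exposing those attributes; `NodeStub` only stands in for real NodeDef messages so the example runs without TensorFlow, and `placeholder_tensor_names` is a hypothetical helper.

```python
from collections import namedtuple

# Stand-in for a NodeDef: only the fields used here (name, op, input)
NodeStub = namedtuple("NodeStub", ["name", "op", "input"])


def placeholder_tensor_names(nodes):
    """Return 'name:0' for every Placeholder node (its single output tensor)."""
    return [f"{node.name}:0" for node in nodes if node.op == "Placeholder"]


# A toy frozen graph; after freezing, variables become Const nodes
nodes = [
    NodeStub("Input", "Placeholder", []),
    NodeStub("Weight", "Const", []),
    NodeStub("MatMul", "MatMul", ["Input", "Weight"]),
    NodeStub("Bias", "Const", []),
    NodeStub("Output", "Add", ["MatMul", "Bias"]),
]
print(placeholder_tensor_names(nodes))  # ['Input:0']
```

With a real model you would call `placeholder_tensor_names(graph_def.node)` on the parsed GraphDef. These names carry no "import/" prefix, so if you import with the default `name`, prepend that prefix before calling get_tensor_by_name.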
Answer 3:
You can print each operation's name along with its input tensors; following these connections shows you the tensor names anywhere in the graph, including its inputs and outputs.
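import tensorflow as tf  # TensorFlow 1.x graph-mode API
input_model_filepath = "model.pb"  # path to your frozen graph (example value)
with tf.gfile.GFile(input_model_filepath, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="")  # keep original op names
# Each line: operation name, then the names of the tensors it consumes
for op in graph.get_operations():
    print(op.name, [inp.name for inp in op.inputs])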
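The same input lists can be turned around to suggest output names: a node whose result no other node consumes is a likely output. This is a heuristic, not a TensorFlow feature, and it can also flag leftover nodes (for example save/restore ops in an unfrozen graph). The sketch below assumes GraphDef-style input strings, where 'x:1' means output 1 of node x and '^x' is a control dependency that consumes no value; `candidate_output_names` and `NodeStub` are hypothetical names.

```python
from collections import namedtuple

# Stand-in for a NodeDef; with a real graph pass graph_def.node instead
NodeStub = namedtuple("NodeStub", ["name", "op", "input"])


def candidate_output_names(nodes, skip_ops=("NoOp",)):
    """Return names of nodes whose outputs no other node uses (likely outputs)."""
    consumed = set()
    for node in nodes:
        for inp in node.input:
            if inp.startswith("^"):
                continue  # control dependency: orders execution, consumes no value
            consumed.add(inp.split(":")[0])  # 'x:1' -> 'x'
    return [node.name for node in nodes
            if node.name not in consumed and node.op not in skip_ops]


nodes = [
    NodeStub("Input", "Placeholder", []),
    NodeStub("Weight", "Const", []),
    NodeStub("MatMul", "MatMul", ["Input", "Weight"]),
    NodeStub("Bias", "Const", []),
    NodeStub("Output", "Add", ["MatMul", "Bias"]),
    NodeStub("Probs", "Softmax", ["Output:0"]),
    NodeStub("init", "NoOp", ["^Probs"]),
]
print(candidate_output_names(nodes))  # ['Probs']
```

The returned names are op names, so they still need an output index (usually ':0') before get_tensor_by_name will accept them.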
Source: https://stackoverflow.com/questions/55313980/how-can-i-know-the-output-and-input-tensor-names-in-a-saved-model