In TensorFlow, get the names of all the tensors in a graph

独厮守ぢ 2020-11-27 10:29

I am creating neural nets with TensorFlow and skflow; for some reason I want to get the values of some inner tensors for a given input, so I am usi

10 answers
  •  时光说笑
    2020-11-27 10:43

    The following solution works for me in TensorFlow 2.3:

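    import tensorflow as tf

    def load_pb(path_to_pb):
        # Read the serialized GraphDef (the frozen graph) from disk
        with tf.io.gfile.GFile(path_to_pb, 'rb') as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())
        # Import it into a fresh graph; name='' avoids the default "import/" prefix
        with tf.Graph().as_default() as graph:
            tf.compat.v1.import_graph_def(graph_def, name='')
            return graph
    
    tf_graph = load_pb(MODEL_FILE)
    sess = tf.compat.v1.Session(graph=tf_graph)
    
    # Show tensor names in graph (each op's output tensors, e.g. "op_name:0")
    for op in tf_graph.get_operations():
        for tensor in op.values():
            print(tensor.name)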
    

    where MODEL_FILE is the path to your frozen graph.

    Taken from here.
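    The approach above can be tried end to end without an existing model file. The sketch below is an illustration under stated assumptions, not the answerer's code: it assumes TensorFlow 2.x with the `tf.compat.v1` APIs, writes a hypothetical toy graph to a temporary `.pb` file (the helper `write_toy_graph` and the tensor names `input:0`, `weights:0`, `hidden:0`, `output:0` are made up for illustration), reloads it the same way `load_pb` does, lists every tensor name, and then uses a session to fetch the inner `hidden:0` tensor for a given input, which is what the question is after.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def write_toy_graph(path):
    """Save a tiny hypothetical graph (input -> hidden -> output) as a serialized GraphDef.

    It holds only a placeholder and constants (no variables), so it is already "frozen".
    """
    graph = tf.Graph()
    with graph.as_default():
        x = tf.compat.v1.placeholder(tf.float32, shape=[None, 2], name="input")
        w = tf.constant([[1.0, -1.0], [2.0, 0.5]], name="weights")
        h = tf.matmul(x, w, name="hidden")
        tf.nn.relu(h, name="output")
    with tf.io.gfile.GFile(path, "wb") as f:
        f.write(graph.as_graph_def().SerializeToString())


def load_pb(path_to_pb):
    """Parse a serialized GraphDef from disk and import it into a fresh tf.Graph."""
    with tf.io.gfile.GFile(path_to_pb, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        # name="" keeps the original names instead of prefixing them with "import/".
        tf.compat.v1.import_graph_def(graph_def, name="")
    return graph


with tempfile.TemporaryDirectory() as tmp:
    pb_path = os.path.join(tmp, "toy_graph.pb")  # hypothetical file name
    write_toy_graph(pb_path)
    tf_graph = load_pb(pb_path)

# Every output tensor of every operation; names have the form "<op_name>:<output_index>".
names = [t.name for op in tf_graph.get_operations() for t in op.outputs]
print(names)

# Look tensors up purely by the names printed above.
input_t = tf_graph.get_tensor_by_name("input:0")
hidden_t = tf_graph.get_tensor_by_name("hidden:0")
output_t = tf_graph.get_tensor_by_name("output:0")

# One run can fetch an inner tensor and the final output for the same input row.
with tf.compat.v1.Session(graph=tf_graph) as sess:
    hidden_val, output_val = sess.run(
        [hidden_t, output_t],
        feed_dict={input_t: np.array([[1.0, 1.0]], dtype=np.float32)},
    )
print(hidden_val, output_val)
```

    Iterating over the output tensors and printing `.name` gives plain strings rather than tensor reprs. The `:0` suffix is the output index, and `get_tensor_by_name` expects that full form; passing just `"hidden"` names the operation, not the tensor.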
