In TensorFlow, get the names of all the Tensors in a graph

独厮守ぢ 2020-11-27 10:29

I am creating neural nets with TensorFlow and skflow; for some reason I want to get the values of some inner tensors for a given input, so I am usi

10 answers
  •  攒了一身酷
    2020-11-27 10:51

    I'll try to summarize the answers:

    To get all nodes: (type tensorflow.core.framework.node_def_pb2.NodeDef)

    all_nodes = [n for n in tf.get_default_graph().as_graph_def().node]
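    A minimal sketch of the node listing, assuming TensorFlow 2.x is installed (so the TF1 call is reached through tf.compat.v1) and using a small throwaway graph whose node names a, b and c are made up for illustration. NodeDef names are op names, without the :0 output suffix that tensor names carry.

```python
import tensorflow as tf

# Build a small graph explicitly; hypothetical names a, b, c for illustration.
graph = tf.Graph()
with graph.as_default():
    a = tf.constant(1.0, name="a")
    b = tf.constant(2.0, name="b")
    c = tf.add(a, b, name="c")

    # Same idea as the answer: serialize the default graph and read its NodeDef protos.
    all_nodes = [n for n in tf.compat.v1.get_default_graph().as_graph_def().node]

for node in all_nodes:
    # node.op is the op type string, e.g. "Const" for the two constants.
    print(node.name, node.op)
```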
    

    To get all ops: (type tensorflow.python.framework.ops.Operation)

    all_ops = tf.get_default_graph().get_operations()
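    A sketch of working with the Operation objects, assuming TensorFlow 2.x with graph mode entered through tf.Graph().as_default(). The input x and weight w are hypothetical; the filter on op.type shows one common use, finding the Placeholder ops that a feed_dict targets.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Hypothetical input and weight, only for illustration.
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 2], name="x")
    w = tf.constant([[1.0], [2.0]], name="w")
    y = tf.matmul(x, w, name="y")

    # Same call as in the answer, via the TF1 compatibility module.
    all_ops = tf.compat.v1.get_default_graph().get_operations()

op_names = [op.name for op in all_ops]
# Placeholder ops are the graph inputs you usually feed.
placeholders = [op.name for op in all_ops if op.type == "Placeholder"]
print(op_names)
print(placeholders)
```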
    

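    To get all variables: (type tensorflow.python.ops.resource_variable_ops.ResourceVariable when resource variables are enabled, the default in TensorFlow 2; classic TF1 graphs create tensorflow.python.ops.variables.RefVariable instead)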

    all_vars = tf.global_variables()
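    A sketch of the variable listing, assuming TensorFlow 2.x and two hypothetical variables w and b created through the TF1-style tf.compat.v1.get_variable. global_variables() reads the GLOBAL_VARIABLES collection of the current default graph, so it has to be called while that graph is active; the printed class name depends on whether resource variables are enabled.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Hypothetical variables, only for illustration.
    w = tf.compat.v1.get_variable("w", shape=[2, 3])
    b = tf.compat.v1.get_variable(
        "b", shape=[3], initializer=tf.compat.v1.zeros_initializer()
    )
    # Same call as in the answer; reads the default graph's GLOBAL_VARIABLES collection.
    all_vars = tf.compat.v1.global_variables()

for v in all_vars:
    # Variable names carry the ":0" suffix of the tensor that holds them.
    print(v.name, v.shape, type(v).__name__)
```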
    

    To get all tensors: (type tensorflow.python.framework.ops.Tensor)

    all_tensors = [tensor for op in tf.get_default_graph().get_operations() for tensor in op.values()]
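    Since the question is about reading inner tensors for a given input, here is a sketch that lists tensor names and then evaluates one of them, assuming TensorFlow 2.x with NumPy and a hypothetical two-step computation (x, hidden, out) standing in for a real network. graph.get_tensor_by_name takes the full "op_name:output_index" string.

```python
import numpy as np
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Hypothetical computation standing in for a real network.
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 2], name="x")
    hidden = tf.nn.relu(x, name="hidden")
    out = tf.reduce_sum(hidden, axis=1, name="out")

    # Flatten the outputs of every op into one list of tensors, as in the answer.
    all_tensors = [
        t for op in tf.compat.v1.get_default_graph().get_operations() for t in op.values()
    ]

tensor_names = [t.name for t in all_tensors]
print(tensor_names)

# Look up an inner tensor by name and evaluate it for a given input.
hidden_t = graph.get_tensor_by_name("hidden:0")
with tf.compat.v1.Session(graph=graph) as sess:
    feed = {x: np.array([[1.0, -2.0]], dtype=np.float32)}
    hidden_value = sess.run(hidden_t, feed_dict=feed)
print(hidden_value)
```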
    

    To get the graph in TensorFlow 2, which executes eagerly and has no tf.get_default_graph() (only tf.compat.v1.get_default_graph()), trace a tf.function into a concrete function and read its graph attribute, for example:

    graph = func.get_concrete_function().graph
    

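    where func is a function decorated with @tf.function. If it takes arguments, pass example tensors or tf.TensorSpec objects to get_concrete_function so the function can be traced.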
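    A TF2 sketch of the same listing, assuming TensorFlow 2.x and a hypothetical func with one input x. Tracing with a tf.TensorSpec builds the FuncGraph without real data, and since FuncGraph is a tf.Graph, the get_operations() approach from above applies unchanged.

```python
import tensorflow as tf

# Hypothetical model function; any tf.function can be inspected the same way.
@tf.function
def func(x):
    hidden = tf.nn.relu(x, name="hidden")
    return tf.reduce_sum(hidden, name="out")

# Trace for a given input signature and grab the resulting graph.
concrete = func.get_concrete_function(tf.TensorSpec(shape=[None, 2], dtype=tf.float32))
graph = concrete.graph

all_tensors = [t for op in graph.get_operations() for t in op.values()]
tensor_names = [t.name for t in all_tensors]
print(tensor_names)
```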
