TensorFlow: getting variable by name

后悔当初 2020-12-08 09:36

When using the TensorFlow Python API, I created a variable (without specifying its name in the constructor), and its name property had the value

4 Answers
  •  一向 (OP)
    2020-12-08 10:12

    The get_variable() function creates a new variable or returns one created earlier by get_variable(). It won't return a variable created using tf.Variable(). Here's a quick example:

    >>> import tensorflow as tf  # TF 1.x graph-mode API
    >>> with tf.variable_scope("foo"):
    ...   bar1 = tf.get_variable("bar", (2,3)) # create
    ... 
    >>> with tf.variable_scope("foo", reuse=True):
    ...   bar2 = tf.get_variable("bar")  # reuse
    ... 
    
    >>> with tf.variable_scope("", reuse=True): # root variable scope
    ...   bar3 = tf.get_variable("foo/bar") # reuse (equivalent to the above)
    ... 
    >>> (bar1 is bar2) and (bar2 is bar3)
    True
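To make the create-or-reuse rule concrete, here is a toy, pure-Python model of it. `ToyVariableStore` is a hypothetical class for illustration only, not TensorFlow's implementation. It assumes the TF 1.x behaviour where calling get_variable() for an existing name without reuse, or for a missing name with reuse=True, raises ValueError.

```python
# Toy model (NOT TensorFlow's implementation) of the create-or-reuse
# rule that tf.get_variable() follows inside tf.variable_scope().
class ToyVariableStore:
    def __init__(self):
        self._vars = {}  # fully qualified name -> stored object

    def get_variable(self, scope, name, reuse=False, shape=None):
        # Scope and name are joined with "/", so scope "foo" + name "bar"
        # and the root scope "" + name "foo/bar" produce the same key.
        full_name = f"{scope}/{name}" if scope else name
        if reuse:
            if full_name not in self._vars:
                raise ValueError(f"Variable {full_name} does not exist")
            return self._vars[full_name]
        if full_name in self._vars:
            raise ValueError(f"Variable {full_name} already exists")
        var = {"name": full_name, "shape": shape}
        self._vars[full_name] = var
        return var


store = ToyVariableStore()
bar1 = store.get_variable("foo", "bar", shape=(2, 3))  # create
bar2 = store.get_variable("foo", "bar", reuse=True)    # reuse
bar3 = store.get_variable("", "foo/bar", reuse=True)   # reuse via root scope
print(bar1 is bar2 and bar2 is bar3)  # True
```

Because only get_variable() writes to the store in this model, an object created some other way (the analogue of tf.Variable()) is never found by a reuse lookup, which is why the alternatives that follow are needed.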
    

    If you did not create the variable using tf.get_variable(), you have a couple of options. First, you can use tf.global_variables() (as @mrry suggests):

    >>> bar1 = tf.Variable(0.0, name="bar")
    >>> bar2 = [var for var in tf.global_variables() if var.op.name=="bar"][0]
    >>> bar1 is bar2
    True
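The list comprehension above raises a bare IndexError when no variable matches. Below is a minimal sketch of a lookup helper that reports a missing name more clearly and accepts either "bar" or "bar:0". `find_variable` and `FakeVar` are hypothetical names used only for illustration; with TF 1.x you would presumably pass tf.global_variables() as the first argument, since a variable's `name` has the form "bar:0" while `var.op.name` is "bar".

```python
from collections import namedtuple

# Hypothetical stand-in for tf.Variable: only the .name attribute is used.
FakeVar = namedtuple("FakeVar", ["name"])


def find_variable(variables, name):
    """Return the variable whose .name equals `name` exactly.

    `name` may be an op name ("bar") or a tensor-style name ("bar:0").
    Raises KeyError, rather than the IndexError that `[...][0]` gives,
    when nothing matches.
    """
    target = name if ":" in name else name + ":0"
    for var in variables:
        if var.name == target:
            return var
    raise KeyError(f"no variable named {target!r}")


variables = [FakeVar("bar:0"), FakeVar("bar_1:0"), FakeVar("foo/bar:0")]
print(find_variable(variables, "bar").name)      # bar:0
print(find_variable(variables, "foo/bar").name)  # foo/bar:0
```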
    

    Or you can use tf.get_collection() like so:

    >>> bar1 = tf.Variable(0.0, name="bar")
    >>> bar2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="bar")[0]
    >>> bar1 is bar2
    True
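One caveat: tf.get_collection() documents that its `scope` argument is applied with `re.match` against each item's `name`, so `scope="bar"` is a prefix match that would also pick up names such as "bar_1:0", and `[0]` then returns whichever match comes first. The pure-Python sketch below reproduces that filtering rule on plain strings (the names are made-up examples; `filter_like_get_collection` is a hypothetical helper) and shows an anchored pattern that keeps only the intended variable.

```python
import re

# Made-up variable names, in the form tf.Variable.name reports them.
names = ["bar:0", "bar_1:0", "barrel:0", "foo/bar:0"]


def filter_like_get_collection(names, scope):
    # Keep names where re.match(scope, name) succeeds; re.match only
    # anchors at the start, so a plain string acts as a prefix filter.
    return [n for n in names if re.match(scope, n)]


print(filter_like_get_collection(names, "bar"))      # ['bar:0', 'bar_1:0', 'barrel:0']
print(filter_like_get_collection(names, r"bar:0$"))  # ['bar:0']
print(filter_like_get_collection(names, "foo/"))     # ['foo/bar:0']
```

Passing an anchored pattern such as r"bar:0$" as `scope` should therefore make the collection lookup select exactly one variable.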
    

    Edit

    You can also use get_tensor_by_name():

    >>> bar1 = tf.Variable(0.0, name="bar")
    >>> graph = tf.get_default_graph()
    >>> bar2 = graph.get_tensor_by_name("bar:0")
    >>> bar1 is bar2
    False

    bar2 is a Tensor (output 0 of the operation named "bar"), not the tf.Variable object bar1, so the identity check fails. The two do, however, hold the same value.
    

    Recall that a tensor is the output of an operation: its name is the operation's name followed by :0. If the operation has multiple outputs, they are named with the suffixes :0, :1, :2, and so on.
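The naming convention above can be sketched as a small parser that splits a tensor name into its operation name and output index. `split_tensor_name` is a hypothetical helper written for illustration, not a TensorFlow function.

```python
def split_tensor_name(tensor_name):
    """Split a tensor name such as "foo/bar:1" into ("foo/bar", 1).

    Follows the "<op name>:<output index>" convention; the index is
    the text after the last colon and must be a non-negative integer.
    """
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not op_name or not index.isdigit():
        raise ValueError(f"not a tensor name: {tensor_name!r}")
    return op_name, int(index)


print(split_tensor_name("bar:0"))      # ('bar', 0)
print(split_tensor_name("foo/bar:2"))  # ('foo/bar', 2)
```

In TF 1.x, graph.get_operation_by_name(op_name).outputs[index] should then yield the same tensor as graph.get_tensor_by_name(tensor_name).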
