How can you re-use a variable scope in tensorflow without a new scope being created by default?

Submitted by 半世苍凉 on 2019-12-01 04:13:29

One straightforward way is to bind the scope to a name with "as somename" in the with statement. You can then pass the somename.original_name_scope property to a new tf.variable_scope to re-enter that same scope and add more variables to it. Here is an illustration:

In [6]: with tf.variable_scope('myscope') as ms1:
   ...:   tf.Variable(1.0, name='var1')
   ...: 
   ...: with tf.variable_scope(ms1.original_name_scope) as ms2:
   ...:   tf.Variable(2.0, name='var2')
   ...: 
   ...: print([n.name for n in tf.get_default_graph().as_graph_def().node])
   ...: 
['myscope/var1/initial_value', 
 'myscope/var1', 
 'myscope/var1/Assign', 
 'myscope/var1/read', 
 'myscope/var2/initial_value', 
 'myscope/var2', 
 'myscope/var2/Assign', 
 'myscope/var2/read']
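Why does this avoid a new scope? Here ms1.original_name_scope is the string 'myscope/', and in TF 1.x a name scope whose name ends in "/" is re-entered as-is, while a bare name that is already taken gets uniquified (myscope_1, myscope_2, ...). The sketch below is a hypothetical pure-Python model of that rule only (NameScopeModel and open_scope are made-up names, not TensorFlow code), assuming top-level scopes:

```python
class NameScopeModel:
    """Hypothetical toy model of TF 1.x top-level name-scope naming (not TensorFlow code)."""

    def __init__(self):
        # How many times each bare scope name has been opened so far.
        self._used = {}

    def open_scope(self, name):
        """Return the prefix ops would get inside a scope opened with `name`."""
        if name.endswith("/"):
            # Trailing slash: re-enter this exact scope, no uniquifying.
            return name[:-1]
        count = self._used.get(name, 0)
        self._used[name] = count + 1
        # First use keeps the name; later uses become name_1, name_2, ...
        return name if count == 0 else f"{name}_{count}"


graph = NameScopeModel()
first = graph.open_scope("myscope")
original_name_scope = first + "/"  # plays the role of ms1.original_name_scope
again_plain = graph.open_scope("myscope")
again_slash = graph.open_scope(original_name_scope)

print(f"{first}/var1")        # myscope/var1
print(f"{again_plain}/var2")  # myscope_1/var2
print(f"{again_slash}/var2")  # myscope/var2
```

Re-opening with the bare string 'myscope' is what produces the unwanted myscope_1 prefix; the trailing slash is what makes the scope get reused.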

Remark:
Setting reuse=True is optional here; you get the same result with or without it, because reuse only affects variables obtained through tf.get_variable, not ones created with tf.Variable.


Another way (thanks to the OP himself!) is simply to append a / to the scope name when re-entering it, as in the following example:

In [13]: with tf.variable_scope('myscope'):
    ...:   tf.Variable(1.0, name='var1')
    ...: 
    ...: # reuse variable scope by appending `/` to the target variable scope
    ...: with tf.variable_scope('myscope/', reuse=True):
    ...:   tf.Variable(2.0, name='var2')
    ...: 
    ...: print([n.name for n in tf.get_default_graph().as_graph_def().node])
    ...: 
['myscope/var1/initial_value', 
 'myscope/var1', 
 'myscope/var1/Assign', 
 'myscope/var1/read', 
 'myscope/var2/initial_value', 
 'myscope/var2', 
 'myscope/var2/Assign', 
 'myscope/var2/read']

Remark:
Again, reuse=True is optional; you get the same result with or without it.

The answer by kmario23 is correct, but there is a tricky case with variables created by tf.get_variable:

import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope, tf.get_variable)

with tf.variable_scope('myscope'):
    print(tf.get_variable('var1', shape=[3]))

with tf.variable_scope('myscope/'):
    print(tf.get_variable('var2', shape=[3]))

This snippet will output:

<tf.Variable 'myscope/var1:0' shape=(3,) dtype=float32_ref>
<tf.Variable 'myscope//var2:0' shape=(3,) dtype=float32_ref>
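The double slash appears because tf.get_variable names the variable after the variable scope's name (roughly scope_name + "/" + var_name), and the variable scope keeps the string you passed, trailing slash included, whereas the ops created by tf.Variable follow the name scope, which drops the trailing slash. The functions below are a hypothetical pure-Python model of those two rules (all names are made up for illustration; this is not TensorFlow code):

```python
def variable_scope_name(parent, name):
    """Hypothetical model: a variable scope's name is parent + "/" + name,
    kept verbatim, so a trailing "/" survives."""
    return f"{parent}/{name}" if parent else name


def get_variable_name(scope_name, var_name):
    """Hypothetical model of how get_variable prefixes the variable's name."""
    return f"{scope_name}/{var_name}" if scope_name else var_name


def op_name_scope(name):
    """Hypothetical model: a name scope drops one trailing "/" when re-entered
    (uniquifying of bare names is ignored here)."""
    return name[:-1] if name.endswith("/") else name


scope_arg = "myscope/"
# get_variable follows the variable scope name: double slash.
print(get_variable_name(variable_scope_name("", scope_arg), "var2"))  # myscope//var2
# tf.Variable ops follow the name scope: single slash.
print(f"{op_name_scope(scope_arg)}/var2")  # myscope/var2
```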

TensorFlow does not seem to provide an official way to handle this case yet. The only workaround I found is to overwrite the scope's name manually (warning: this relies on the private _name attribute, so correctness is not guaranteed):

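import tensorflow as tf  # TensorFlow 1.x API

with tf.variable_scope('myscope'):
    print(tf.get_variable('var1', shape=[3]))

with tf.variable_scope('myscope/') as scope:
    # Hack: overwrite the private scope name to drop the trailing '/'
    scope._name = 'myscope'
    print(tf.get_variable('var2', shape=[3]))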

With this change, both variables get the expected names:

<tf.Variable 'myscope/var1:0' shape=(3,) dtype=float32_ref>
<tf.Variable 'myscope/var2:0' shape=(3,) dtype=float32_ref>
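In a simplified view, the workaround works because the variable prefix is read from the scope's name each time a variable is requested, so overwriting the private _name changes every variable created afterwards inside that block. ToyVariableScope below is a hypothetical stand-in, not TensorFlow's VariableScope; it also shows why the trick is fragile, since it depends on a private attribute that TensorFlow may change between versions:

```python
class ToyVariableScope:
    """Hypothetical stand-in for a TF 1.x VariableScope; models naming only."""

    def __init__(self, name):
        self._name = name  # private attribute, like the one the workaround overwrites

    @property
    def name(self):
        return self._name

    def variable_name(self, var_name):
        # get_variable-style prefixing: scope name + "/" + variable name
        return f"{self.name}/{var_name}" if self.name else var_name


scope = ToyVariableScope("myscope/")
before = scope.variable_name("var2")  # 'myscope//var2'
scope._name = "myscope"               # same override as in the workaround
after = scope.variable_name("var2")   # 'myscope/var2'
print(before, after)
```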