TensorFlow variable scope: reuse if variable exists

醉酒成梦 2020-12-02 15:17

I want a piece of code that creates a variable within a scope if it doesn't exist, and accesses the variable if it already exists. I need it to be the same c

4 answers
  •  伪装坚强ぢ
    2020-12-02 16:12

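    tf.get_variable() raises a ValueError when it has to create a new variable and no shape is declared, or when the call violates the scope's reuse setting (for example, asking to create a variable that already exists while reuse is off). You can therefore catch the error and fall back to reuse: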

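    import tensorflow as tf  # TensorFlow 1.x API

    def get_scope_variable(scope_name, var, shape=None):
        with tf.variable_scope(scope_name) as scope:
            try:
                # First call: create the variable (this needs a shape).
                v = tf.get_variable(var, shape)
            except ValueError:
                # The variable already exists: switch the scope to reuse and fetch it.
                scope.reuse_variables()
                v = tf.get_variable(var)
        return v

    v1 = get_scope_variable('foo', 'v', [1])
    v2 = get_scope_variable('foo', 'v')
    assert v1 is v2  # both names refer to the same variable object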
    

    Note that passing the shape again on the second call also works, because the fallback fetches the existing variable by name:

    v1 = get_scope_variable('foo', 'v', [1])
    v2 = get_scope_variable('foo', 'v', [1])
    assert v1 is v2
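
    The try/except fallback above depends on when get_variable() raises. As a rough illustration, here is a pure-Python toy model of those rules. ToyVariableStore and its dict-backed storage are hypothetical names made up for this sketch, not TensorFlow's implementation, and the model ignores dtypes, initializers, and shape-compatibility checks.

```python
class ToyVariableStore:
    """Hypothetical dict-backed model of the lookup rules; not TensorFlow code."""

    def __init__(self):
        self._vars = {}

    def get_variable(self, name, shape=None, reuse=False):
        if name in self._vars:
            if reuse is False:
                # Creating a variable that already exists violates reuse=False.
                raise ValueError("Variable %s already exists" % name)
            return self._vars[name]
        if reuse is True:
            # Reusing a variable that was never created violates reuse=True.
            raise ValueError("Variable %s does not exist" % name)
        if shape is None:
            # A new variable cannot be created without a declared shape.
            raise ValueError("Shape of new variable %s must be declared" % name)
        self._vars[name] = {"name": name, "shape": list(shape)}
        return self._vars[name]


def get_scope_variable(store, name, shape=None):
    # Same pattern as the answer: try to create, fall back to reuse on ValueError.
    try:
        return store.get_variable(name, shape)
    except ValueError:
        return store.get_variable(name, reuse=True)


store = ToyVariableStore()
v1 = get_scope_variable(store, "foo/v", [1])  # created on the first call
v2 = get_scope_variable(store, "foo/v")       # fetched on the second call
assert v1 is v2
```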
    

    UPDATE: newer versions of the API support automatic reuse through tf.AUTO_REUSE, so the try/except is no longer needed:

    def get_scope_variable(scope, var, shape=None):
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            v = tf.get_variable(var, shape)
        return v
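
    For anyone on TensorFlow 2.x, the AUTO_REUSE helper can likely be exercised through the tf.compat.v1 module. This is a minimal sketch, assuming a TensorFlow release that provides tensorflow.compat.v1. The explicit tf.Graph() and the second scope name 'bar' are additions for illustration, so that get_variable runs in graph mode rather than eagerly.

```python
import tensorflow.compat.v1 as tf  # assumption: a release that ships compat.v1


def get_scope_variable(scope, var, shape=None):
    # AUTO_REUSE creates the variable on first use and returns it afterwards.
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        return tf.get_variable(var, shape)


# Build inside an explicit graph so the variable store is fresh and graph-mode.
graph = tf.Graph()
with graph.as_default():
    v1 = get_scope_variable('foo', 'v', [1])
    v2 = get_scope_variable('foo', 'v')      # shape can be omitted once 'foo/v' exists
    w = get_scope_variable('bar', 'v', [1])  # a different scope yields a different variable

assert v1 is v2
assert w is not v1
print(v1.name, w.name)
```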
    
