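A ValueError is raised by get_variable() when you try to create a new variable without declaring its shape, or when you violate the reuse rules during variable creation (for example, creating a variable that already exists in a scope that is not set to reuse). You can therefore catch the error and retry with reuse enabled: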
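import tensorflow as tf  # TensorFlow 1.x API (available as tf.compat.v1 in TF 2.x)

def get_scope_variable(scope_name, var, shape=None):
    with tf.variable_scope(scope_name) as scope:
        try:
            # First attempt: create the variable.
            v = tf.get_variable(var, shape)
        except ValueError:
            # Creation failed (e.g. it already exists): switch the scope to reuse mode.
            scope.reuse_variables()
            v = tf.get_variable(var)
    return v

v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v')
assert v1 is v2  # both calls return the same variable object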
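Note that passing the shape on the second call also works, because creating 'foo/v' a second time raises a ValueError and the function falls back to reuse:
v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v', [1])
assert v1 is v2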
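UPDATE: Newer versions of the API support automatic reuse with reuse=tf.AUTO_REUSE, which creates the variable if it does not exist and returns the existing one otherwise:
def get_scope_variable(scope, var, shape=None):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        v = tf.get_variable(var, shape)
    return v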
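To make the reuse rules concrete without needing TensorFlow installed, here is a toy pure-Python model of the get-or-create logic described above. It is only a sketch under stated assumptions: VariableStore, the AUTO_REUSE sentinel, and the dict "variables" are hypothetical stand-ins, not TensorFlow's implementation. It mirrors the three behaviours: creating an existing variable without reuse fails, reusing a missing variable fails, and creating a new variable without a shape fails.

```python
# Toy model of tf.get_variable()'s reuse rules. NOT TensorFlow code:
# VariableStore and AUTO_REUSE below are hypothetical stand-ins.

AUTO_REUSE = object()  # hypothetical stand-in for tf.AUTO_REUSE


class VariableStore:
    def __init__(self):
        self._vars = {}  # full name ("scope/var") -> variable record

    def get_variable(self, name, shape=None, reuse=False):
        exists = name in self._vars
        if reuse is AUTO_REUSE:
            reuse = exists  # reuse if it exists, otherwise create it
        if reuse:
            if not exists:
                raise ValueError(f"Variable {name} does not exist")
            var = self._vars[name]
            if shape is not None and tuple(shape) != var["shape"]:
                raise ValueError(f"Shape mismatch when sharing {name}")
            return var
        if exists:
            raise ValueError(f"Variable {name} already exists, reuse disallowed")
        if shape is None:
            raise ValueError(f"Shape of new variable {name} must be declared")
        var = {"name": name, "shape": tuple(shape)}
        self._vars[name] = var
        return var


def get_scope_variable_fallback(store, scope, var, shape=None):
    # Mirrors the try/except version: try to create, then fall back to reuse.
    name = f"{scope}/{var}"
    try:
        return store.get_variable(name, shape, reuse=False)
    except ValueError:
        return store.get_variable(name, reuse=True)


def get_scope_variable(store, scope, var, shape=None):
    # Mirrors the AUTO_REUSE version: a single get-or-create call.
    return store.get_variable(f"{scope}/{var}", shape, reuse=AUTO_REUSE)
```

In this model both helpers return the identical record for 'foo/v' on every call, which is why the identity check `v1 is v2` holds in the examples above.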
与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…