Question
I'm trying to use tf.case (https://www.tensorflow.org/api_docs/python/tf/case) to conditionally update a Tensor. As shown below, I'm trying to update learning_rate to 0.01 when global_step == 2, and to 0.001 when global_step == 4.
However, when global_step == 2, I already get learning_rate = 0.001. Upon further inspection, it looks like tf.case is giving me the wrong result when global_step == 2 (I get 0.001 instead of 0.01), even though the predicate for 0.01 evaluates to True and the predicate for 0.001 evaluates to False.
Am I doing something wrong, or is this a bug?
TF Version: 1.0.0
Code:
import tensorflow as tf

global_step = tf.Variable(0, dtype=tf.int64)
train_op = tf.assign(global_step, global_step + 1)

learning_rate = tf.Variable(0.1, dtype=tf.float32, name='learning_rate')

# Update the learning_rate tensor conditionally:
# when global_step == 2, update to 0.01;
# when global_step == 4, update to 0.001.
cases = []
case_tensors = []
for step, new_rate in [(2, 0.01), (4, 0.001)]:
    pred = tf.equal(global_step, step)
    fn_tensor = tf.constant(new_rate, dtype=tf.float32)
    cases.append((pred, lambda: fn_tensor))
    case_tensors.append((pred, fn_tensor))
update = tf.case(cases, default=lambda: learning_rate)
updated_learning_rate = tf.assign(learning_rate, update)

print tf.__version__

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in xrange(6):
        print sess.run([global_step, case_tensors, update, updated_learning_rate])
        sess.run(train_op)
Results:
1.0.0
[0, [(False, 0.0099999998), (False, 0.001)], 0.1, 0.1]
[1, [(False, 0.0099999998), (False, 0.001)], 0.1, 0.1]
[2, [(True, 0.0099999998), (False, 0.001)], 0.001, 0.001]
[3, [(False, 0.0099999998), (False, 0.001)], 0.001, 0.001]
[4, [(False, 0.0099999998), (True, 0.001)], 0.001, 0.001]
[5, [(False, 0.0099999998), (False, 0.001)], 0.001, 0.001]
Answer 1:
This was answered in https://github.com/tensorflow/tensorflow/issues/8776.
It turns out that tf.case's behavior is undefined if the callables in fn_tensors return a tensor that was created outside of the lambda. The correct usage is to define each lambda so that it returns a newly created tensor.
According to the linked GitHub issue, this usage is required because tf.case must create the tensor itself in order to hook the tensor's inputs up to the correct branch of the predicate.
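In other words, the fix is to construct the tensor inside each lambda body, e.g. something like lambda r=new_rate: tf.constant(r, dtype=tf.float32) instead of returning the pre-built fn_tensor. The default argument matters for a second reason: Python lambdas created in a loop look up the loop variable at call time, so all of the lambdas in the posted loop actually return the last fn_tensor (the 0.001 one). A minimal plain-Python sketch of that pitfall and the fix (TensorFlow omitted so the snippet stands alone; the names here are illustrative, not from the original post):

```python
# Late-binding pitfall: lambdas created in a loop look up the loop
# variable when they are *called*, so every lambda sees the final value.
late_bound = [(lambda: rate) for rate in (0.01, 0.001)]
results_late = [f() for f in late_bound]      # [0.001, 0.001]

# Binding the per-iteration value as a default argument freezes it, so
# each lambda builds its result from its own rate. This mirrors the
# tf.case fix of creating the tensor inside the lambda body, e.g.
# lambda r=new_rate: tf.constant(r, dtype=tf.float32).
fresh_bound = [(lambda r=rate: r) for rate in (0.01, 0.001)]
results_fresh = [f() for f in fresh_bound]    # [0.01, 0.001]

print(results_late, results_fresh)
```

With both fixes applied, each branch of tf.case receives a callable that builds its own tensor from its own rate, which is the usage the GitHub issue describes.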
Source: https://stackoverflow.com/questions/42728235/tensorflow-why-is-tf-case-giving-me-the-wrong-result