Question
I'm trying to use tf.case (https://www.tensorflow.org/api_docs/python/tf/case) to conditionally update a Tensor. As shown, I'm trying to update learning_rate to 0.01 when global_step == 2, and to 0.001 when global_step == 4.
However, when global_step == 2, I already get learning_rate = 0.001. Upon further inspection, it looks like tf.case is giving me the wrong result when global_step == 2 (I get 0.001 instead of 0.01). This is happening even though the predicate for 0.01 is evaluating to True, and the predicate for 0.001 is evaluating to False.
Am I doing something wrong, or is this a bug?
TF Version: 1.0.0
Code:
import tensorflow as tf

global_step = tf.Variable(0, dtype=tf.int64)
train_op = tf.assign(global_step, global_step + 1)
learning_rate = tf.Variable(0.1, dtype=tf.float32, name='learning_rate')

# Update the learning_rate tensor conditionally
# When global_step == 2, update to 0.01
# When global_step == 4, update to 0.001
cases = []
case_tensors = []
for step, new_rate in [(2, 0.01), (4, 0.001)]:
    pred = tf.equal(global_step, step)
    fn_tensor = tf.constant(new_rate, dtype=tf.float32)
    cases.append((pred, lambda: fn_tensor))
    case_tensors.append((pred, fn_tensor))
update = tf.case(cases, default=lambda: learning_rate)
updated_learning_rate = tf.assign(learning_rate, update)

print tf.__version__
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in xrange(6):
        print sess.run([global_step, case_tensors, update, updated_learning_rate])
        sess.run(train_op)
Results:
1.0.0
[0, [(False, 0.0099999998), (False, 0.001)], 0.1, 0.1]
[1, [(False, 0.0099999998), (False, 0.001)], 0.1, 0.1]
[2, [(True, 0.0099999998), (False, 0.001)], 0.001, 0.001]
[3, [(False, 0.0099999998), (False, 0.001)], 0.001, 0.001]
[4, [(False, 0.0099999998), (True, 0.001)], 0.001, 0.001]
[5, [(False, 0.0099999998), (False, 0.001)], 0.001, 0.001]
Answer 1:
This was answered in https://github.com/tensorflow/tensorflow/issues/8776
It turns out that tf.case's behavior is undefined if the lambdas in the (pred, fn) pairs return a tensor that was created outside of the lambda. The correct usage is to define each lambda so that it returns a newly created tensor.
According to the linked GitHub issue, this usage is required because tf.case must create the tensor itself in order to hook the tensor's inputs up to the correct branch of the conditional.
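For example, a corrected construction might look like the following sketch (it reuses the names from the question; the rate=new_rate default argument is not in the original code but is added here so each lambda captures its own value instead of the loop's final one, since Python closures are late-binding):

import tensorflow as tf

global_step = tf.Variable(0, dtype=tf.int64)
learning_rate = tf.Variable(0.1, dtype=tf.float32, name='learning_rate')

cases = []
for step, new_rate in [(2, 0.01), (4, 0.001)]:
    pred = tf.equal(global_step, step)
    # Create the constant inside the lambda so tf.case can wire its inputs
    # into the matching branch; rate=new_rate binds the value at definition time.
    cases.append((pred, lambda rate=new_rate: tf.constant(rate, dtype=tf.float32)))

update = tf.case(cases, default=lambda: learning_rate)
updated_learning_rate = tf.assign(learning_rate, update)

With this construction, the update tensor should evaluate to 0.01 only when global_step == 2 and to 0.001 only when global_step == 4.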
Source: https://stackoverflow.com/questions/42728235/tensorflow-why-is-tf-case-giving-me-the-wrong-result