TensorFlow: Why is tf.case giving me the wrong result?

Posted by 一曲冷凌霜 on 2019-12-23 20:28:39

Question


I'm trying to use tf.case (https://www.tensorflow.org/api_docs/python/tf/case) to conditionally update a Tensor. As shown, I'm trying to update learning_rate to 0.01 when global_step == 2, and to 0.001 when global_step == 4.

However, when global_step == 2, I already get learning_rate = 0.001. Upon further inspection, it looks like tf.case is giving me the wrong result when global_step == 2 (I get 0.001 instead of 0.01). This is happening even though the predicate for 0.01 is evaluating to True, and the predicate for 0.001 is evaluating to False.

Am I doing something wrong, or is this a bug?

TF Version: 1.0.0

Code:

import tensorflow as tf

global_step = tf.Variable(0, dtype=tf.int64)
train_op = tf.assign(global_step, global_step + 1)
learning_rate = tf.Variable(0.1, dtype=tf.float32, name='learning_rate')

# Update the learning_rate tensor conditionally
# When global_step == 2, update to 0.01
# When global_step == 4, update to 0.001
cases = []
case_tensors = []
for step, new_rate in [(2, 0.01), (4, 0.001)]:
    pred = tf.equal(global_step, step)
    fn_tensor = tf.constant(new_rate, dtype=tf.float32)
    cases.append((pred, lambda: fn_tensor))
    case_tensors.append((pred, fn_tensor))
update = tf.case(cases, default=lambda: learning_rate)
updated_learning_rate = tf.assign(learning_rate, update)

print tf.__version__
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in xrange(6):
        print sess.run([global_step, case_tensors, update, updated_learning_rate])
        sess.run(train_op)

Results:

1.0.0
[0, [(False, 0.0099999998), (False, 0.001)], 0.1, 0.1]
[1, [(False, 0.0099999998), (False, 0.001)], 0.1, 0.1]
[2, [(True, 0.0099999998), (False, 0.001)], 0.001, 0.001]
[3, [(False, 0.0099999998), (False, 0.001)], 0.001, 0.001]
[4, [(False, 0.0099999998), (True, 0.001)], 0.001, 0.001]
[5, [(False, 0.0099999998), (False, 0.001)], 0.001, 0.001]

Answer 1:


This was answered in https://github.com/tensorflow/tensorflow/issues/8776

It turns out that tf.case behavior is undefined if the lambdas passed in pred_fn_pairs return a tensor that was created outside of the lambda. The correct usage is to define each lambda so that it creates the tensor it returns.

According to the linked GitHub issue, this usage is required because tf.case must create the tensor itself in order to hook up the tensor's inputs to the correct branch of the predicate.
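As a rough illustration of that advice, here is a minimal sketch of the question's loop rewritten so that each branch creates its tensor inside the callable. It is not taken from the GitHub issue. The helper name make_branch is hypothetical, and the sketch assumes a TensorFlow install that ships the tf.compat.v1 module (on 1.0.0 itself, a plain "import tensorflow as tf" without the disable_v2_behavior call would be the equivalent). Going through a helper also gives each branch its own rate: in the question's loop, every "lambda: fn_tensor" looks up fn_tensor only when it is called, which happens inside tf.case after the loop has already finished.

```python
import tensorflow.compat.v1 as tf

# Assumption: a TensorFlow build that provides compat.v1 graph-mode behavior.
tf.disable_v2_behavior()

global_step = tf.Variable(0, dtype=tf.int64)
train_op = tf.assign(global_step, global_step + 1)
learning_rate = tf.Variable(0.1, dtype=tf.float32, name='learning_rate')


def make_branch(new_rate):
    """Hypothetical helper: returns a callable that builds its own constant.

    The constant is created only when tf.case invokes the callable, so it is
    built inside the branch. Passing new_rate as an argument binds the value
    for this branch instead of relying on the loop variable.
    """
    return lambda: tf.constant(new_rate, dtype=tf.float32)


cases = []
for step, new_rate in [(2, 0.01), (4, 0.001)]:
    pred = tf.equal(global_step, step)
    cases.append((pred, make_branch(new_rate)))

# The default branch also creates a new tensor (an identity of the variable).
update = tf.case(cases, default=lambda: tf.identity(learning_rate))
updated_learning_rate = tf.assign(learning_rate, update)

results = []
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(6):
        step_value, rate_value = sess.run([global_step, updated_learning_rate])
        results.append((int(step_value), float(rate_value)))
        sess.run(train_op)

for step_value, rate_value in results:
    print(step_value, rate_value)
```

Under these assumptions, the learning rate should stay at 0.1 for steps 0 and 1, change to 0.01 at step 2, and change to 0.001 at step 4, up to float32 rounding.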



Source: https://stackoverflow.com/questions/42728235/tensorflow-why-is-tf-case-giving-me-the-wrong-result
