TensorFlow tf.cond evaluating both branches

Question:

import tensorflow as tf
import numpy as np

isTrain = tf.placeholder(tf.bool)
user_input = tf.placeholder(tf.float32)

# ema = tf.train.ExponentialMovingAverage(decay=.5)

with tf.device('/cpu:0'):
    beta = tf.Variable(tf.ones([1]))

    batch_mean = beta.assign(user_input)
    ema = tf.train.ExponentialMovingAverage(decay=0.5)
    ema_apply_op = ema.apply([batch_mean])
    ema_mean = ema.average(batch_mean)

    def mean_var_with_update():
        with tf.control_dependencies([ema_apply_op]):
            return tf.identity(batch_mean)

    mean = tf.cond(isTrain,
                   mean_var_with_update,
                   lambda: (ema_mean))

# ======= End Here ==========
saver = tf.train.Saver()
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

u_input = [[2], [3], [4]]
for u in u_input:
    aa = sess.run([mean], feed_dict={user_input: u, isTrain: True})
    print("Train", aa)

for u in u_input:
    aa = sess.run([ema_mean], feed_dict={user_input: u, isTrain: False})
    print("Test correct", aa)

for u in u_input:
    aa = sess.run([mean], feed_dict={user_input: u, isTrain: False})
    print("Test", aa)

This code snippet should update the moving average of user_input during the training stage and output that average during the testing stage.

This is the output result :

('Train', [array([ 2.], dtype=float32)])
('Train', [array([ 3.], dtype=float32)])
('Train', [array([ 4.], dtype=float32)])
('Test correct', [array([ 3.], dtype=float32)])
('Test correct', [array([ 3.], dtype=float32)])
('Test correct', [array([ 3.], dtype=float32)])
('Test', [array([ 2.5], dtype=float32)])
('Test', [array([ 2.75], dtype=float32)])
('Test', [array([ 3.375], dtype=float32)])

However, the training branch (and with it the moving-average update) always gets evaluated when calling sess.run([mean]), even if isTrain is False.

Is there any mistake in the code? The TensorFlow version is 0.7.1.

Answer 1:

I think this is the same issue as the one answered here. A tf.control_dependencies block inside a conditional branch adds the dependencies to the tf.cond itself, so the ema_apply_op runs regardless of which branch is taken.

So try creating the ema_apply_op inside the mean_var_with_update function.
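A minimal sketch of that restructuring, assuming the other names from the question (ema, batch_mean, isTrain) are defined as above and that the earlier top-level ema.apply call is removed; creating the update op inside the branch function attaches it, and its control dependency, to the True branch only:

def mean_var_with_update():
    # created inside the branch, so the update op and its control
    # dependency belong to the True branch rather than to tf.cond itself
    ema_apply_op = ema.apply([batch_mean])
    with tf.control_dependencies([ema_apply_op]):
        return tf.identity(batch_mean)

mean = tf.cond(isTrain,
               mean_var_with_update,
               lambda: ema.average(batch_mean))

This mirrors the usual batch-normalization pattern, where the moving-average update is only wired into the training path of the conditional.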



Answer 2:

I've added some logging statements, and ema_mean only seems to get evaluated when isTrain is False:

tf.reset_default_graph()

isTrain = tf.placeholder(tf.bool)
user_input = tf.placeholder(tf.float32)

# ema = tf.train.ExponentialMovingAverage(decay=.5)

with tf.device('/cpu:0'):
    beta = tf.Variable(tf.ones([1]))

    batch_mean = beta.assign(user_input)
    ema = tf.train.ExponentialMovingAverage(decay=0.5)
    ema_apply_op = ema.apply([batch_mean])
    ema_mean = ema.average(batch_mean)

    def mean_var_with_update():
        with tf.control_dependencies([ema_apply_op]):
            return tf.Print(tf.identity(batch_mean), ["mean_var_with_update"])
            # return tf.identity(batch_mean)

    mean = tf.Print(tf.cond(isTrain,
                            mean_var_with_update,
                            lambda: (tf.Print(ema_mean, ["ema_mean"]))),
                    ["evaluating mean", isTrain])

# ======= End Here ==========
saver = tf.train.Saver()
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

u_input = [[2], [3], [4]]
for u in u_input:
    aa = sess.run([mean], feed_dict={user_input: u, isTrain: True})
    print("Train", aa)

for u in u_input:
    aa = sess.run([ema_mean], feed_dict={user_input: u, isTrain: False})
    print("Test correct", aa)

for u in u_input:
    aa = sess.run([mean], feed_dict={user_input: u, isTrain: False})
    print("Test", aa)

You see:

[mean_var_with_update]
[evaluating mean][True]
[mean_var_with_update]
[evaluating mean][True]
[mean_var_with_update]
[evaluating mean][True]
[ema_mean]
[evaluating mean][False]
[ema_mean]
[evaluating mean][False]
[ema_mean]
[evaluating mean][False]

Note that tf.Print only fires after all of its inputs have been evaluated, so the outer print statement is printed last.
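As a standalone illustration of that ordering (a minimal sketch, not part of the graph above; graph-mode execution with tf.Print as in the question is assumed), the inner tf.Print fires first because the outer node only runs once its input has been produced:

import tensorflow as tf

x = tf.constant(1.0)
inner = tf.Print(x, ["inner"])            # logged when this node is evaluated
outer = tf.Print(inner + 1.0, ["outer"])  # logged only after `inner` has run

sess = tf.Session()
sess.run(outer)  # stderr shows [inner] before [outer]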


