How to add if condition in a TensorFlow graph?

闹比i 2020-12-07 13:18

Let's say I have the following code:

x = tf.placeholder("float32", shape=[None, ins_size**2*3], name = "x_input")
condition = tf.placeholder("int32", shap

  •  萌比男神i
    2020-12-07 14:03

    You're correct that the if statement doesn't work here: a Python if is evaluated once, at graph construction time, whereas you presumably want the condition to depend on the value fed to the placeholder at run time. (In early TensorFlow releases it always took the first branch, because condition > 0 evaluates to a Tensor, which is "truthy" in Python; later 1.x releases raise a TypeError when a Tensor is used as a Python bool.)
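The construction-time versus run-time distinction can be seen without TensorFlow at all. The sketch below is a toy model in plain Python, not TensorFlow's API: `Placeholder`, `Greater`, and `build_with_python_if` are hypothetical names. The Python `if` runs while the "graph" is being built, before any value is fed, so all it can inspect is the node object itself.

```python
class Placeholder:
    """Toy stand-in for a graph node whose value is supplied only at run time."""

    def __init__(self, name):
        self.name = name

    def __gt__(self, other):
        # Comparing a node builds another node instead of returning a Python bool.
        return Greater(self, other)


class Greater:
    """Toy node for `lhs > rhs`; its value exists only once a feed is supplied."""

    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs

    def evaluate(self, feed):
        return feed[self.lhs.name] > self.rhs


def build_with_python_if(condition):
    # This `if` executes once, at build time. `condition > 0` is a Greater
    # object, and ordinary Python objects are truthy, so the first branch wins.
    if condition > 0:
        return "add_bias"
    return "subtract_bias"


condition = Placeholder("condition")
graph = build_with_python_if(condition)
print(graph)  # add_bias
```

Whatever value is later fed for `condition`, the branch was already fixed when `build_with_python_if` ran.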

    For conditional control flow, TensorFlow provides the tf.cond() operator: it takes a scalar boolean predicate and two functions, and returns the result of whichever branch the predicate selects when the graph runs. To show you how to use it, I'll rewrite your program so that condition is a scalar tf.int32 value for simplicity:

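    import tensorflow.compat.v1 as tf  # TF1-style graph API; also works under TF 2.x
    tf.disable_v2_behavior()

    ins_size, label_option = 4, 10  # example values only; the question defines these itself

    x = tf.placeholder(tf.float32, shape=[None, ins_size**2*3], name="x_input")
    condition = tf.placeholder(tf.int32, shape=[], name="condition")  # scalar, fed at run time
    W = tf.Variable(tf.zeros([ins_size**2 * 3, label_option]), name="weights")
    b = tf.Variable(tf.zeros([label_option]), name="bias")

    # tf.cond selects a branch at run time, based on the value fed for `condition`
    y = tf.cond(condition > 0, lambda: tf.matmul(x, W) + b, lambda: tf.matmul(x, W) - b)
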
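Why tf.cond avoids the problem can be sketched with the same kind of toy model. This is plain Python, not TensorFlow: `Placeholder`, `Fn`, and `Cond` are hypothetical names, and scalars stand in for `tf.matmul(x, W)` and `b`. Like graph-mode `tf.cond`, the toy `Cond` calls both branch functions while the graph is built, so each branch becomes a node, but it evaluates the predicate and runs only the selected branch once values are fed. It mirrors the branch selection only, not TensorFlow's implementation.

```python
class Placeholder:
    """Toy graph node whose value comes from the feed dict at run time."""

    def __init__(self, name):
        self.name = name

    def evaluate(self, feed):
        return feed[self.name]


class Fn:
    """Toy graph node applying a Python function to the values of other nodes."""

    def __init__(self, fn, *inputs):
        self.fn, self.inputs = fn, inputs

    def evaluate(self, feed):
        return self.fn(*(node.evaluate(feed) for node in self.inputs))


class Cond:
    """Toy analogue of tf.cond: both branches are built now, one runs later."""

    def __init__(self, pred, true_fn, false_fn):
        self.pred = pred
        # Both branch functions are called at build time to create their nodes.
        self.true_node = true_fn()
        self.false_node = false_fn()

    def evaluate(self, feed):
        # The predicate is only evaluated here, using the fed values.
        chosen = self.true_node if self.pred.evaluate(feed) else self.false_node
        return chosen.evaluate(feed)


# Hypothetical scalar stand-ins for tf.matmul(x, W), b, and the condition placeholder.
xw = Placeholder("xw")
b = Placeholder("b")
condition = Placeholder("condition")

y = Cond(
    Fn(lambda c: c > 0, condition),
    lambda: Fn(lambda xw_val, b_val: xw_val + b_val, xw, b),
    lambda: Fn(lambda xw_val, b_val: xw_val - b_val, xw, b),
)

print(y.evaluate({"condition": 1, "xw": 10.0, "b": 2.0}))   # 12.0
print(y.evaluate({"condition": -1, "xw": 10.0, "b": 2.0}))  # 8.0
```

In the TensorFlow program, the analogous step is running the graph after initializing the variables, for example `sess.run(y, feed_dict={x: batch, condition: 1})`, where `batch` is a hypothetical array of shape `[n, ins_size**2*3]`; feeding a different `condition` value selects the other branch without rebuilding the graph.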
