Question
In the TensorFlow CIFAR-10 tutorial, a comment in cifar10_input.py (line 174) says you should randomize the order of the random_contrast and random_brightness operations for better data augmentation.
To do so, my first idea is to draw a random variable p_order from the uniform distribution between 0 and 1, and then do:
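# The contrast/brightness ranges below are example values.
if p_order > 0.5:
    distorted_image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
    distorted_image = tf.image.random_brightness(distorted_image, max_delta=63)
else:
    distorted_image = tf.image.random_brightness(image, max_delta=63)
    distorted_image = tf.image.random_contrast(distorted_image, lower=0.2, upper=1.8)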
However, there are two ways to obtain p_order:
1) Using NumPy, which does not satisfy me because I want pure TF, and TF discourages users from mixing NumPy and TensorFlow.
2) Using TF. However, since p_order can only be evaluated in a tf.Session(), I do not really know whether I should do:
with tf.Session() as sess2:
    p_order_tensor = tf.random_uniform([1,], 0., 1.)
    p_order = float(p_order_tensor.eval())
All these operations are inside the body of a function that is called from another script, which has its own session/graph. Alternatively, I could pass the graph from the other script to this function as an argument, but I am confused. Even the fact that TensorFlow functions like this one (or inference, for example) seem to define the graph globally, without explicitly returning it, is a bit hard for me to understand.
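A minimal sketch of why such functions do not need to return a graph: in TF1-style graph mode, every op created without an explicit graph is added to an implicit default graph, so a function can build ops and return only the resulting tensor. This assumes TensorFlow 1.14+ or 2.x accessed through the tf.compat.v1 API; make_p_order is a hypothetical helper, not part of the tutorial.

```python
import tensorflow as tf

# TF1-style graph mode (in TF 2.x this turns off eager execution).
tf.compat.v1.disable_eager_execution()


def make_p_order():
    # Hypothetical helper: it only adds an op to the current default graph
    # and returns the resulting tensor; it never creates or uses a Session.
    return tf.compat.v1.random_uniform(shape=[], minval=0., maxval=1.,
                                       dtype=tf.float32)


p_order = make_p_order()

# The op was placed in the implicit default graph, so the function did not
# need to receive or return a graph object.
graph = tf.compat.v1.get_default_graph()
print(p_order.graph is graph)

# The calling code owns the Session; every run draws a fresh sample.
with tf.compat.v1.Session(graph=graph) as sess:
    samples = [sess.run(p_order) for _ in range(5)]

print(samples)
```

As long as the function is called while the caller's graph is the default one (for example inside `with graph.as_default():`), its ops end up in that graph, so passing the graph around explicitly is usually unnecessary.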
Answer 1:
You can use tf.cond(pred, fn1, fn2, name=None) (see the documentation).
This function lets you use the boolean value of pred inside the TensorFlow graph: there is no need to call .eval() or sess.run() on it, and hence no need for a Session.
Here is an example of how to use it:
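import tensorflow as tf

# `image` is the 3-D image tensor being augmented (defined elsewhere).
# The contrast/brightness ranges are example values.
def fn1():
    # Contrast first, then brightness.
    distorted_image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
    distorted_image = tf.image.random_brightness(distorted_image, max_delta=63)
    return distorted_image

def fn2():
    # Brightness first, then contrast.
    distorted_image = tf.image.random_brightness(image, max_delta=63)
    distorted_image = tf.image.random_contrast(distorted_image, lower=0.2, upper=1.8)
    return distorted_image

# Uniform variable in [0, 1), sampled anew on every sess.run()
p_order = tf.random_uniform(shape=[], minval=0., maxval=1., dtype=tf.float32)
pred = tf.less(p_order, 0.5)
# Only the branch selected by pred is executed at run time.
distorted_image = tf.cond(pred, fn1, fn2)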
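A self-contained version of the answer's approach that can be run end to end. It is a sketch under these assumptions: TensorFlow 1.14+ or 2.x via the tf.compat.v1 API, a random 24x24x3 float image fed through a placeholder as a stand-in for a real CIFAR-10 image, and example contrast/brightness ranges. The two branch functions are hypothetical renames of fn1 and fn2.

```python
import numpy as np
import tensorflow as tf

# TF1-style graph mode (in TF 2.x this turns off eager execution).
tf.compat.v1.disable_eager_execution()

# Stand-in for one decoded image (assumed size 24x24x3, float32).
image = tf.compat.v1.placeholder(tf.float32, shape=[24, 24, 3])


def contrast_then_brightness():
    distorted = tf.image.random_contrast(image, lower=0.2, upper=1.8)
    return tf.image.random_brightness(distorted, max_delta=63)


def brightness_then_contrast():
    distorted = tf.image.random_brightness(image, max_delta=63)
    return tf.image.random_contrast(distorted, lower=0.2, upper=1.8)


# The coin flip lives in the graph, so it is re-drawn on every run.
p_order = tf.compat.v1.random_uniform(shape=[], minval=0., maxval=1.)
pred = tf.less(p_order, 0.5)
distorted_image = tf.cond(pred, contrast_then_brightness,
                          brightness_then_contrast)

dummy_image = np.random.uniform(0., 255., size=(24, 24, 3)).astype(np.float32)
results = []
with tf.compat.v1.Session() as sess:
    for _ in range(3):
        # Fetch pred too, to see which ordering this run used.
        flag, out = sess.run([pred, distorted_image],
                             feed_dict={image: dummy_image})
        results.append((bool(flag), out))
        print("contrast first:", bool(flag), "output shape:", out.shape)
```

Fetching pred and distorted_image in the same sess.run() call is deliberate: p_order is sampled once per run and shared by both fetches, so the printed flag reports the ordering actually applied to that output.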
Source: https://stackoverflow.com/questions/37299345/using-if-conditions-inside-a-tensorflow-graph