Question
I am trying to use tf.while_loop to run loops in parallel. However, in the following toy examples, the loops don't appear to be running in parallel.
import tensorflow as tf

iteration = tf.constant(0)
c = lambda i: tf.less(i, 1000)
def print_fun(iteration):
    print(f"This is iteration {iteration}")
    iteration += 1
    return (iteration,)
r = tf.while_loop(c, print_fun, [iteration], parallel_iterations=10)
Or
i = tf.constant(0)
c = lambda i: tf.less(i, 1000)
b = lambda i: (tf.add(i, 1),)
r = tf.while_loop(c, b, [i])
What is preventing tf.while_loop from parallelizing the loop?
In addition, if anyone who maintains the TensorFlow documentation sees this page, they should fix the bug in the first example. See the discussion here.
Thanks.
Answer 1:
parallel_iterations has no effect when running in eager mode, but you can always use the tf.function decorator and gain significant speedups. This can be seen in this picture: running times
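One way to see the difference between the two modes is to count how often Python itself executes the loop body. The sketch below is only an illustration: the `python_calls` counter and the `make_body` helper are hypothetical names introduced here, not part of the question or of TensorFlow. In eager mode the body should be called by Python on every iteration, while under tf.function it should only be called while the graph is being traced.

```python
import tensorflow as tf

# Hypothetical counters, used only to illustrate where the body runs.
python_calls = {"eager": 0, "graph": 0}

def make_body(mode):
    def body(i):
        # Plain Python side effect: happens only when Python executes the body.
        python_calls[mode] += 1
        return (tf.add(i, 1),)
    return body

def cond(i):
    return tf.less(i, 1000)

# Eager mode: tf.while_loop steps through the loop in Python.
eager_result = tf.while_loop(
    cond, make_body("eager"), [tf.constant(0)], parallel_iterations=10
)

# Graph mode: the body is traced into a graph, then the graph runs the loop.
@tf.function
def graph_loop():
    return tf.while_loop(
        cond, make_body("graph"), [tf.constant(0)], parallel_iterations=10
    )

graph_result = graph_loop()

print("eager:", int(eager_result[0]), "Python calls:", python_calls["eager"])
print("graph:", int(graph_result[0]), "Python calls:", python_calls["graph"])
```

Both versions compute the same final counter; only the number of Python-level calls to the body differs, which is why a Python print inside the body behaves differently in the two modes.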
You can wrap your tf.while_loop with tf.function like this:
@tf.function
def run_graph():
    iteration = tf.constant(0)
    r = tf.while_loop(c, print_fun, [iteration], parallel_iterations=4)
    return r  # return the loop result so the caller can use it
and then call run_graph when required.
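For completeness, here is a self-contained sketch of the wrapped loop. It makes two assumptions that differ from the question: the loop bound is lowered to 5 to keep the output short, and the Python print is swapped for tf.print, since a Python print inside a tf.function runs while the function is traced rather than on each graph iteration. The `cond` and `body` names are illustrative.

```python
import tensorflow as tf

def cond(i):
    # Shortened bound (assumption); the question used 1000.
    return tf.less(i, 5)

def body(i):
    # tf.print is a graph op, so it executes on each iteration of the graph loop.
    tf.print("This is iteration", i)
    return (tf.add(i, 1),)

@tf.function
def run_graph():
    iteration = tf.constant(0)
    return tf.while_loop(cond, body, [iteration], parallel_iterations=4)

# Call the wrapped function whenever the loop is needed.
final = run_graph()[0]
print("final value:", int(final))
```

Calling run_graph again with the same signature reuses the traced graph, so the tracing cost is paid only on the first call.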
Source: https://stackoverflow.com/questions/59299060/tf-2-0-while-loop-and-parallel-iterations