TF 2.0 while_loop and parallel_iterations

落花浮王杯 submitted on 2020-01-16 12:03:10

Question


I am trying to use tf.while_loop to run loops in parallel. However, in the following toy examples, the loops don't appear to be running in parallel.

import tensorflow as tf

iteration = tf.constant(0)
c = lambda i: tf.less(i, 1000)
def print_fun(iteration):
    print(f"This is iteration {iteration}")
    iteration += 1
    return (iteration,)
r = tf.while_loop(c, print_fun, [iteration], parallel_iterations=10)

Or

i = tf.constant(0)
c = lambda i: tf.less(i, 1000)
b = lambda i: (tf.add(i, 1),)
r = tf.while_loop(c, b, [i])

What is preventing the tf.while_loop from parallelizing the loop?

In addition, if anyone who maintains the TensorFlow documentation sees this page, please fix the bug in the first example. See the discussion here.

Thanks.


Answer 1:


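parallel_iterations has no effect in eager mode (the default in TF 2.0), because tf.while_loop then simply runs as a Python loop. You can, however, use the tf.function decorator to run the loop as a graph and gain significant speedups, as this comparison of running times shows:
[Figure: running times]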
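To see the eager-versus-graph difference described above, here is a minimal sketch that counts how often the Python body function actually runs. The names N, python_calls, make_body and graph_loop are hypothetical, chosen only for illustration. In eager mode tf.while_loop calls the body once per iteration, like an ordinary Python loop; inside tf.function the body is traced into a graph, and the graph runtime (where parallel_iterations applies) executes the iterations.

```python
import tensorflow as tf

N = 100
# Hypothetical counters: how many times the Python body function is invoked.
python_calls = {"eager": 0, "graph": 0}

def make_body(mode):
    def body(i):
        python_calls[mode] += 1  # plain Python side effect, not a TF op
        return (i + 1,)
    return body

def cond(i):
    return i < N

# Eager mode: tf.while_loop behaves like a Python while loop,
# so the body function runs once per iteration and parallel_iterations is ignored.
(eager_result,) = tf.while_loop(cond, make_body("eager"), (tf.constant(0),),
                                parallel_iterations=10)

# Graph mode: inside tf.function the body is traced into a graph,
# and the graph (not Python) runs the iterations.
@tf.function
def graph_loop():
    (result,) = tf.while_loop(cond, make_body("graph"), (tf.constant(0),),
                              parallel_iterations=10)
    return result

graph_result = graph_loop()
print(int(eager_result), python_calls["eager"])
print(int(graph_result), python_calls["graph"])
```

The exact number of traces in graph mode may vary between TensorFlow versions, so the sketch only relies on it being far smaller than N.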

You can wrap your tf.while_loop in a tf.function like this:

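@tf.function
def run_graph():
    iteration = tf.constant(0)
    # c and print_fun are the condition and body from the question; inside tf.function,
    # a Python print() in print_fun runs only while tracing, so use tf.print instead
    r = tf.while_loop(c, print_fun, [iteration], parallel_iterations=4)
    return r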

and then call run_graph when required.
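A self-contained version of the answer's snippet, offered as a sketch: it assumes c and print_fun are the question's condition and body, swaps Python print for tf.print (a Python print inside a tf.function body only runs while the function is being traced), and returns the final loop counter. The loop variable is passed as a tuple and unpacked so the result is a plain tensor.

```python
import tensorflow as tf

def c(i):
    return tf.less(i, 1000)

def print_fun(i):
    # tf.print is a TF op, so it executes on every iteration of the graph loop.
    tf.print("This is iteration", i)
    return (i + 1,)

@tf.function
def run_graph():
    iteration = tf.constant(0)
    (r,) = tf.while_loop(c, print_fun, (iteration,), parallel_iterations=4)
    return r

result = run_graph()  # prints one line per iteration, then returns the final counter
print(int(result))
```

Note that in both of the question's loops each iteration needs the previous value of i, so even in graph mode there is little independent work for parallel_iterations to overlap; it helps most when each iteration's heavy computation depends only on the loop index rather than on earlier iterations' results.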



Source: https://stackoverflow.com/questions/59299060/tf-2-0-while-loop-and-parallel-iterations
