tf.assign to variable slice doesn't work inside tf.while_loop

Anonymous (unverified), submitted on 2019-12-03 01:00:01

Question:

What is wrong with the following code? The tf.assign op works just fine when applied to a slice of a tf.Variable if it happens outside of a loop. But, in this context, it gives the error below.

import tensorflow as tf

v = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
n = len(v)
a = tf.Variable(v, name = 'a')

def cond(i, a):
    return i < n

def body(i, a):
    tf.assign(a[i], a[i-1] + a[i-2])
    return i + 1, a

i, b = tf.while_loop(cond, body, [2, a])

results in:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3210, in while_loop
    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2942, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2879, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/home/hrbigelow/ai/lb-wavenet/while_var_test.py", line 11, in body
    tf.assign(a[i], a[i-1] + a[i-2])
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/state_ops.py", line 220, in assign
    return ref.assign(value, name=name)
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 697, in assign
    raise ValueError("Sliced assignment is only supported for variables")
ValueError: Sliced assignment is only supported for variables

Answer 1:

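Your variable is not an output of the operations run inside your loop; it is an external entity that lives outside the loop, so you do not have to pass it as a loop argument. When it is passed in loop_vars, the a seen inside body is a tensor rather than the tf.Variable itself, which is why the sliced assignment is rejected.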

Also, you need to force the update to actually run, for example by using tf.control_dependencies in body.

import tensorflow as tf

v = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
n = len(v)
a = tf.Variable(v, name='a')

def cond(i):
    return i < n

def body(i):
    # a is captured from the enclosing scope, so a[i] is a slice of the variable.
    op = tf.assign(a[i], a[i-1] + a[i-2])
    # Make the loop counter depend on the assignment so the update is executed.
    with tf.control_dependencies([op]):
        return i + 1

i = tf.while_loop(cond, body, [2])

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
i.eval()
print(a.eval())  # [ 1  1  2  3  5  8 13 21 34 55 89]

To be cautious, you may also want to set parallel_iterations=1 so that the loop iterations are forced to run sequentially.
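
The reason ordering matters here can be seen without TensorFlow. The following plain-Python sketch does not model TensorFlow's scheduler, and both function names are hypothetical. It only contrasts the result when each step sees the previous writes with the result when every step reads the initial values, as if no earlier write had landed.

```python
# Plain-Python illustration of the data dependency between iterations.
# It does not model TensorFlow's scheduler; both helper names are hypothetical.

def run_in_order(v):
    """Each step reads the values written by the previous steps."""
    a = list(v)
    for i in range(2, len(a)):
        a[i] = a[i - 1] + a[i - 2]
    return a


def run_with_stale_reads(v):
    """Every step reads the initial values, as if no earlier write had landed."""
    snapshot = list(v)
    a = list(v)
    for i in range(2, len(a)):
        a[i] = snapshot[i - 1] + snapshot[i - 2]
    return a


v = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
in_order = run_in_order(v)
stale = run_with_stale_reads(v)
print(in_order)  # [1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89]
print(stale)     # [1, 1, 2, 1, 0, 0, 0, 0, 0, 0, 0]
```

Each iteration depends on the two entries written just before it, so any reordering of reads and writes changes the result.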



Answer 2:

It makes sense from a CUDA perspective to disallow assignment of individual indices as it negates all performance benefits of heterogeneous parallel computing.

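The workaround below avoids sliced assignment altogether: it adds row i of the identity matrix, scaled by the new value, to the tensor carried through the loop. I know this adds a bit of computational overhead, but it works. Note that the result is the loop output b; the variable a itself is not modified.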

import tensorflow as tf

v = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
n = len(v)
a = tf.Variable(v, name='a', dtype=tf.float32)

def cond(i, a):
    return i < n

def body(i, a1):
    # Row i of the identity matrix: 1 at index i, 0 elsewhere.
    e = tf.eye(n, n)[i]
    # Adds the new value at index i only; a1[i] is still 0 at this point.
    a1 = a1 + e * (a1[i-1] + a1[i-2])
    return i + 1, a1

i, b = tf.while_loop(cond, body, [2, a])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('i: ', sess.run(i))
    print('b: ', sess.run(b))
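
To make the one-hot trick easier to follow, here is a minimal plain-Python sketch of the same arithmetic, with lists standing in for tensors. The helper name one_hot_update is hypothetical and is not a TensorFlow API. The sketch also makes one assumption explicit: adding the masked value equals an assignment only because position i still holds 0 before the update.

```python
# Plain-Python sketch of the one-hot update used in the answer above.
# one_hot_update is a hypothetical helper name, not a TensorFlow API.

def one_hot_update(values, i):
    """Return a new list in which position i receives values[i-1] + values[i-2]."""
    n = len(values)
    # Row i of an n-by-n identity matrix: 1.0 at position i, 0.0 elsewhere.
    e = [1.0 if j == i else 0.0 for j in range(n)]
    increment = values[i - 1] + values[i - 2]
    # Adding e * increment changes only position i. This equals an assignment
    # only because position i still holds 0 before the update.
    return [x + mask * increment for x, mask in zip(values, e)]


a = [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
i = 2
while i < len(a):
    a = one_hot_update(a, i)
    i += 1

print(i)  # 11
print(a)  # [1.0, 1.0, 2.0, 3.0, 5.0, 8.0, 13.0, 21.0, 34.0, 55.0, 89.0]
```

Every step builds a full-length mask and a full-length sum to change one element, which is where the extra overhead mentioned in the answer comes from.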


Answer 3:

I ran this several times and at first the results were not consistent. Variable slices do work inside the while loop, though.

I split the graph inside body into separate steps because the results were sometimes incorrect.

Without the control dependency, the correct answer (11, array([ 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89])) was returned sometimes but not always.

import tensorflow as tf

v = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
n = len(v)
a1 = tf.Variable(v, name='a')

def cond(i, _):
    return i < n

s = tf.InteractiveSession()
s.run(tf.global_variables_initializer())

def body(i, _):
    # Read the two previous entries from the variable captured from the enclosing scope.
    x = a1[i-1]
    y = a1[i-2]
    z = tf.add(x, y)
    op = a1[i].assign(z)
    with tf.control_dependencies([op]):  # Edit: this fixed the inconsistency.
        increment = tf.add(i, 1)
    return increment, op

print(s.run(tf.while_loop(cond, body, [2, a1])))

