TensorFlow while-loop with TensorArray

Submitted by 假如想象 on 2019-12-05 15:10:56

Question


import numpy as np
import tensorflow as tf

B = 3
D = 4
T = 5

tf.reset_default_graph()

xs = tf.placeholder(shape=[T, B, D], dtype=tf.float32)

with tf.variable_scope("RNN"):
    GRUcell = tf.contrib.rnn.GRUCell(num_units = D)
    cell = tf.contrib.rnn.MultiRNNCell([GRUcell]) 

    output_ta = tf.TensorArray(size=T, dtype=tf.float32)
    input_ta = tf.TensorArray(size=T, dtype=tf.float32)
    input_ta.unstack(xs)

    def body(time, output_ta_t, state):
        xt = input_ta.read(time)
        new_output, new_state = cell(xt, state)
        output_ta_t.write(time, new_output)
        return (time+1, output_ta_t, new_state)

    def condition(time, output, state):
        return time < T

    time = 0
    state = cell.zero_state(B, tf.float32)

    time_final, output_ta_final, state_final = tf.while_loop(
          cond=condition,
          body=body,
          loop_vars=(time, output_ta, state))

    output_final = output_ta_final.stack()

I then run it:

x = np.random.normal(size=(T, B, D))
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    output_final_, state_final_ = sess.run(fetches = [output_final, state_final], feed_dict = {xs:x})

I would like to understand how to use TensorArray properly together with tf.while_loop. The sample above raises the following error:

InvalidArgumentError: TensorArray RNN/TensorArray_1_21: Could not read from TensorArray index 0 because it has not yet been written to.

I do not understand this "Could not read from TensorArray index 0" message. As far as I can tell, I write to the TensorArray input_ta with unstack and to output_ta inside the while-loop body. What am I doing wrong? Thanks for your help.


Answer 1:


The solution is to change

input_ta.unstack(xs)

to

input_ta = input_ta.unstack(xs)

and similarly change

output_ta_t.write(time, new_output)

to

output_ta_t = output_ta_t.write(time, new_output)

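With these two changes the code runs as expected. TensorArray methods such as unstack and write do not modify the object in place; each returns a new TensorArray, and that returned object is the one to use for all subsequent operations. In the original code both return values were discarded, so the loop read from an input_ta that had never been written to.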
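
The reason rebinding matters can be illustrated with a toy model in plain Python. This is only a sketch: ToyArray and run_loop are hypothetical names invented for this illustration, and the class is not how TensorFlow implements TensorArray. It merely copies the convention that write and unstack return a new object and leave the original untouched.

```python
class ToyArray:
    """Hypothetical stand-in for tf.TensorArray; NOT TensorFlow's implementation."""

    def __init__(self, size, items=None):
        self.size = size
        # Slots that have not been written yet hold None.
        self._items = tuple(items) if items is not None else (None,) * size

    def write(self, index, value):
        # Return a NEW array with the slot filled; self stays unchanged.
        items = list(self._items)
        items[index] = value
        return ToyArray(self.size, items)

    def unstack(self, values):
        # Return a NEW array holding one value per slot; self stays unchanged.
        values = list(values)
        if len(values) != self.size:
            raise ValueError("expected %d values" % self.size)
        return ToyArray(self.size, values)

    def read(self, index):
        if self._items[index] is None:
            raise LookupError("index %d has not yet been written to" % index)
        return self._items[index]


def run_loop(xs, keep_results):
    """Double every element of xs, reading and writing through ToyArray."""
    input_ta = ToyArray(len(xs))
    output_ta = ToyArray(len(xs))
    if keep_results:
        input_ta = input_ta.unstack(xs)  # rebind, as in the answer
    else:
        input_ta.unstack(xs)             # result discarded, as in the question
    for t in range(len(xs)):
        xt = input_ta.read(t)
        output_ta = output_ta.write(t, xt * 2)  # rebind on every step
    return [output_ta.read(t) for t in range(len(xs))]
```

Calling run_loop([1, 2, 3], keep_results=True) returns the doubled values, while keep_results=False fails on the first read with a "has not yet been written to" error, which mirrors the InvalidArgumentError in the question.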



Source: https://stackoverflow.com/questions/43701631/tensorflow-while-loop-with-tensorarray
