When using TensorFlow's Dataset API Iterator, my goal is to define an RNN that takes the iterator's get_next() tensors as its input (see (1) in the code).
However, simply defining the dynamic_rnn with get_next() as its input results in an error:
ValueError: Initializer for variable rnn/basic_lstm_cell/kernel/ is from inside a control-flow construct, such as a loop or conditional. When creating a variable inside a loop or conditional, use a lambda as the initializer.
I know that one workaround is to create a placeholder for next_batch, eval() the tensor (because you can't pass the tensor itself), and pass the result using feed_dict (see X and (2) in the code). However, if I understand it correctly, this is not an efficient solution, as we first evaluate the tensor and then feed its value back into the graph through the placeholder.
Is there any way to either:
- define the dynamic_rnn directly on top of the output of the Iterator; or
- somehow directly pass the existing get_next() tensor to the placeholder that is the input of dynamic_rnn?
Full working example; the (1) version is what I would like to work but it doesn't, while (2) is the workaround that does work.
    import tensorflow as tf
    from tensorflow.contrib.rnn import BasicLSTMCell
    from tensorflow.python.data import Iterator

    data = [
        [[1], [2], [3]],
        [[4], [5], [6]],
        [[1], [2], [3]]
    ]
    dataset = tf.data.Dataset.from_tensor_slices(data)
    dataset = dataset.batch(2)
    iterator = Iterator.from_structure(dataset.output_types, dataset.output_shapes)
    next_batch = iterator.get_next()
    iterator_init = iterator.make_initializer(dataset)

    # (2):
    X = tf.placeholder(tf.float32, shape=(None, 3, 1))

    cell = BasicLSTMCell(num_units=8)

    # (1):
    # outputs, states = lstm_outputs, lstm_states = tf.nn.dynamic_rnn(cell, next_batch, dtype=tf.float32)

    # (2):
    outputs, states = lstm_outputs, lstm_states = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)

    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        sess.run(iterator_init)

        # (1):
        # o, s = sess.run([outputs, states])
        # o, s = sess.run([outputs, states])

        # (2):
        o, s = sess.run([outputs, states], feed_dict={X: next_batch.eval()})
        o, s = sess.run([outputs, states], feed_dict={X: next_batch.eval()})
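To make the shapes concrete, here is a rough plain-Python stand-in for what the two sess.run calls above consume. It does not use TensorFlow at all; batches and shape are hypothetical helpers I made up for illustration, and the sketch assumes batch(2) keeps the smaller final batch rather than dropping it.

```python
# Plain-Python stand-in for dataset.batch(2) on the example data.
# `batches` and `shape` are hypothetical helpers, not TensorFlow functions.
data = [
    [[1], [2], [3]],
    [[4], [5], [6]],
    [[1], [2], [3]],
]


def batches(items, batch_size):
    """Yield consecutive slices of at most batch_size items, keeping a smaller final slice."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def shape(nested):
    """Return the shape of a rectangular nested list as a tuple."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


# One entry per batch the iterator would hand out.
batch_shapes = [shape(batch) for batch in batches(data, 2)]
print(batch_shapes)
```

Both shapes fit the placeholder's (None, 3, 1), which is why each of the two runs in version (2) can be fed one batch.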
(Using tensorflow 1.4.0, Python 3.6.)
Thank you very much :)