I currently have the following code for a series of chained-together RNNs in TensorFlow. I am not using MultiRNNCell since I want to do something later on with the output of each RNN.
I am now saving the RNN states using tf.control_dependencies. Here is an example:
import numpy as np
import tensorflow as tf

# One non-trainable variable per state tensor, so the state persists across session runs.
saved_states = [tf.get_variable('saved_state_%d' % i, shape=(BATCH_SIZE, sz), trainable=False, initializer=tf.constant_initializer()) for i, sz in enumerate(rnn.state_size)]
W = tf.get_variable('W', shape=(2 * RNN_SIZE, RNN_SIZE), initializer=tf.truncated_normal_initializer(0.0, 1 / np.sqrt(2 * RNN_SIZE)))
b = tf.get_variable('b', shape=(RNN_SIZE,), initializer=tf.constant_initializer())
rnn_output, states = rnn(last_output, saved_states)
# Loop variables renamed so the bias `b` is not shadowed (in Python 2 the comprehension would overwrite it).
with tf.control_dependencies([tf.assign(var, new_state) for var, new_state in zip(saved_states, states)]):
    dense_input = tf.concat((last_output, rnn_output), axis=1)  # tf.concat takes the values first, then the axis
    dense_output = tf.tanh(tf.matmul(dense_input, W) + b)
last_output = dense_output + last_output
This way I make sure that part of my graph depends on saving the state, so the assign ops actually run whenever that part is evaluated.
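For reference, here is a minimal self-contained sketch of the same pattern with a single cell. The sizes, the GRUCell, the placeholder, and the loop count are all made up for illustration (depending on your TF 1.x version the cell may live in tf.contrib.rnn instead of tf.nn.rnn_cell); the point is only that evaluating an op built inside the control_dependencies block forces the assign to run, so the state carries over between session.run calls.

import numpy as np
import tensorflow as tf

BATCH_SIZE, INPUT_SIZE, RNN_SIZE = 4, 8, 16  # made-up sizes for illustration

cell = tf.nn.rnn_cell.GRUCell(RNN_SIZE)
inputs = tf.placeholder(tf.float32, (BATCH_SIZE, INPUT_SIZE))
saved_state = tf.get_variable('saved_state', shape=(BATCH_SIZE, RNN_SIZE),
                              trainable=False, initializer=tf.constant_initializer())

output, new_state = cell(inputs, saved_state)
# Any op created under this context will force the assign to run first.
with tf.control_dependencies([tf.assign(saved_state, new_state)]):
    output = tf.identity(output)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(3):
        # Each run updates saved_state, so the next run continues from it.
        sess.run(output, {inputs: np.zeros((BATCH_SIZE, INPUT_SIZE), np.float32)})

Note the tf.identity: control_dependencies only attaches to ops created inside the block, so you need at least one new op there (here a pass-through on the output) for the dependency to take effect.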