Given a trained LSTM model, I want to perform inference for single timesteps, i.e. seq_length = 1 in the example below. After each timestep the internal LSTM (me
"TensorFlow, best way to save state in RNNs?" was actually my original question. The code below is how I use the state tuples.
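
    import tensorflow as tf  # TF 1.x graph-mode API

    # BATCH_SIZE, WORD_VEC_SIZE, TIME_STEPS, TRAINING, y (one input tensor per
    # timestep) and y_ (training targets) are defined elsewhere in the model.
    with tf.variable_scope('decoder') as scope:
        # Two stacked LSTM layers; num_proj sets the size of each layer's output (h).
        rnn_cell = tf.nn.rnn_cell.MultiRNNCell([
            tf.nn.rnn_cell.LSTMCell(512, num_proj=256, state_is_tuple=True),
            tf.nn.rnn_cell.LSTMCell(512, num_proj=WORD_VEC_SIZE, state_is_tuple=True)
        ], state_is_tuple=True)

        # Zero initial state: one [c, h] list per layer, sized from state_size.
        state = [[tf.zeros((BATCH_SIZE, sz)) for sz in sz_outer]
                 for sz_outer in rnn_cell.state_size]

        for t in range(TIME_STEPS):
            # Feed back the previous step: ground truth while training, own output otherwise.
            if t:
                last = y_[t - 1] if TRAINING else y[t - 1]
            else:
                last = tf.zeros((BATCH_SIZE, WORD_VEC_SIZE))

            # TF >= 1.0 argument order; TF 0.x used tf.concat(1, (y[t], last)).
            y[t] = tf.concat([y[t], last], axis=1)
            y[t], state = rnn_cell(y[t], state)

            # Share the LSTM weights across all timesteps.
            scope.reuse_variables()
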
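For reference, the state initialisation line builds one [c, h] pair of zero tensors per layer. The snippet below reproduces that nested comprehension without TensorFlow, recording shapes instead of calling tf.zeros. Assumptions: LSTMStateTuple here is a collections.namedtuple stand-in with the same (c, h) fields as TensorFlow's, and BATCH_SIZE = 32 and WORD_VEC_SIZE = 300 are hypothetical values. With num_proj set, LSTMCell reports its state size as LSTMStateTuple(num_units, num_proj), so c is 512 wide while h takes the projected width.

```python
from collections import namedtuple

# Stand-in with the same fields as TensorFlow's LSTMStateTuple (c, h),
# so this snippet runs without TensorFlow installed.
LSTMStateTuple = namedtuple("LSTMStateTuple", ("c", "h"))

BATCH_SIZE = 32      # hypothetical value
WORD_VEC_SIZE = 300  # hypothetical value

# What MultiRNNCell.state_size holds for the two LSTMCells in the answer:
# one LSTMStateTuple(num_units, num_proj) per layer.
state_size = (LSTMStateTuple(512, 256), LSTMStateTuple(512, WORD_VEC_SIZE))

# Same comprehension as the answer, with shapes in place of tf.zeros(...).
state_shapes = [[(BATCH_SIZE, sz) for sz in sz_outer] for sz_outer in state_size]
print(state_shapes)
# → [[(32, 512), (32, 256)], [(32, 512), (32, 300)]]
```

Because the comprehension simply iterates each (c, h) pair, a plain list of lists carries the same information without building LSTMStateTuple objects explicitly.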
Rather than using tf.nn.rnn_cell.LSTMStateTuple, I just create a list of lists, which works fine. In this example I am not saving the state between runs. However, you could easily make the state out of variables and use assign to save the new values after each step.
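The variables-plus-assign idea can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, not TensorFlow code: StatefulLSTMCell is a hypothetical class whose c and h attributes stand in for non-trainable state variables, and overwriting them at the end of step() stands in for the assign ops. It is a plain single-layer LSTM with fixed random weights and no num_proj projection, so it only demonstrates the state handling, not a trained model. Each call consumes exactly one timestep, which is the seq_length = 1 inference case from the question.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


class StatefulLSTMCell:
    """Hypothetical single-layer LSTM whose (c, h) state persists between calls.

    self.c / self.h play the role of non-trainable state variables; the end
    of step() plays the role of assigning the new state to them.
    """

    def __init__(self, input_size, num_units, seed=0):
        rng = random.Random(seed)
        self.num_units = num_units
        n_in = input_size + num_units
        # One weight matrix and bias per gate: input (i), forget (f),
        # candidate (g), output (o); rows: num_units, cols: input + hidden.
        self.W = {g: [[rng.uniform(-0.1, 0.1) for _ in range(n_in)]
                      for _ in range(num_units)] for g in "ifgo"}
        self.b = {g: [0.0] * num_units for g in "ifgo"}
        self.reset()

    def reset(self):
        # Zero state, like the tf.zeros initial state in the answer.
        self.c = [0.0] * self.num_units
        self.h = [0.0] * self.num_units

    def _gate(self, g, z):
        # Affine transform W[g] @ z + b[g] for one gate.
        return [sum(w * v for w, v in zip(row, z)) + bias
                for row, bias in zip(self.W[g], self.b[g])]

    def step(self, x):
        """Consume ONE timestep and persist the resulting state."""
        z = list(x) + self.h  # concatenate input with the previous output
        i = [sigmoid(v) for v in self._gate("i", z)]
        f = [sigmoid(v) for v in self._gate("f", z)]
        g = [math.tanh(v) for v in self._gate("g", z)]
        o = [sigmoid(v) for v in self._gate("o", z)]
        c_new = [fv * cv + iv * gv for fv, cv, iv, gv in zip(f, self.c, i, g)]
        h_new = [ov * math.tanh(cv) for ov, cv in zip(o, c_new)]
        # "Assign": overwrite the stored state so the next call continues from it.
        self.c, self.h = c_new, h_new
        return h_new


# Usage: feed a sequence one timestep at a time; state carries over between calls.
cell = StatefulLSTMCell(input_size=3, num_units=4)
for x in ([1.0, 0.0, -1.0], [0.5, 0.5, 0.5]):
    print(cell.step(x))
```

Calling reset() between independent sequences corresponds to assigning zeros back to the state variables. In a TensorFlow graph, the analogous design would keep c and h in tf.Variable objects and run the assign ops alongside the output on every step; treat that as the general shape of the approach rather than a tested recipe.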