TensorFlow: Remember LSTM state for next batch (stateful LSTM)

Asked by 遇见更好的自我, 2020-11-28 04:48

Given a trained LSTM model I want to perform inference for single timesteps, i.e. seq_length = 1 in the example below. After each timestep the internal LSTM (memory and hidden) states need to be remembered for the next batch.

2 Answers
  •  Answered by -上瘾入骨i, 2020-11-28 05:27

    Tensorflow, best way to save state in RNNs? was actually my original question. The code below shows how I use the state tuples.

    with tf.variable_scope('decoder') as scope:
        rnn_cell = tf.nn.rnn_cell.MultiRNNCell \
        ([
            tf.nn.rnn_cell.LSTMCell(512, num_proj = 256, state_is_tuple = True),
            tf.nn.rnn_cell.LSTMCell(512, num_proj = WORD_VEC_SIZE, state_is_tuple = True)
        ], state_is_tuple = True)
    
        # Zero initial state, built as a list of lists that mirrors the
        # nesting of rnn_cell.state_size (one (c, h) pair per layer).
        state = [[tf.zeros((BATCH_SIZE, sz)) for sz in sz_outer] for sz_outer in rnn_cell.state_size]
    
        for t in range(TIME_STEPS):
            if t:
                # Teacher forcing: feed the ground truth y_ while training,
                # the model's own previous output y at inference time.
                last = y_[t - 1] if TRAINING else y[t - 1]
            else:
                last = tf.zeros((BATCH_SIZE, WORD_VEC_SIZE))
    
            # Note: since TF 1.0 the argument order is tf.concat(values, axis).
            y[t] = tf.concat((y[t], last), 1)
            y[t], state = rnn_cell(y[t], state)
    
            scope.reuse_variables()
    
    

    Rather than using tf.nn.rnn_cell.LSTMStateTuple I just create a list of lists, which works fine. In this example I am not saving the state, but you could easily make the state out of tf.Variable objects and use assign to persist the values between runs.
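    The carry-the-state-forward idea is framework-agnostic. As a minimal sketch (plain NumPy with a hand-rolled LSTM cell and made-up weights, not the TF API), keeping the (c, h) pair outside the step function and feeding it back in gives you stateful single-timestep inference:

    ```python
    import numpy as np

    def lstm_step(x, c, h, W, U, b):
        """One LSTM timestep; returns the updated (c, h).
        Gate layout in W/U/b columns: [input, forget, candidate, output]."""
        n = c.shape[-1]
        z = x @ W + h @ U + b
        i = 1 / (1 + np.exp(-z[:, :n]))          # input gate
        f = 1 / (1 + np.exp(-z[:, n:2*n]))       # forget gate
        g = np.tanh(z[:, 2*n:3*n])               # candidate cell state
        o = 1 / (1 + np.exp(-z[:, 3*n:]))        # output gate
        c = f * c + i * g
        h = o * np.tanh(c)
        return c, h

    rng = np.random.default_rng(0)
    batch, in_dim, hid = 2, 4, 3
    W = rng.normal(size=(in_dim, 4 * hid))       # hypothetical weights
    U = rng.normal(size=(hid, 4 * hid))
    b = np.zeros(4 * hid)

    # "Stateful" inference: the state lives outside the call and is fed
    # back in each time, just like the list-of-lists state in the TF code.
    c = np.zeros((batch, hid))
    h = np.zeros((batch, hid))
    for t in range(5):                           # seq_length = 1 per call
        x = rng.normal(size=(batch, in_dim))
        c, h = lstm_step(x, c, h, W, U, b)
    ```

    Because nothing inside `lstm_step` retains state, whatever you do with (c, h) between calls (store it in variables, pickle it, ship it to another process) is entirely up to the caller.
    
    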
