TensorFlow: Remember LSTM state for next batch (stateful LSTM)

遇见更好的自我 2020-11-28 04:48

Given a trained LSTM model, I want to perform inference for single timesteps, i.e. seq_length = 1 in the example below. After each timestep the internal LSTM (memory and hidden) states need to be remembered for the next batch.

2 Answers
栀梦 2020-11-28 05:51

    I found out it was easiest to save the whole state for all layers in a placeholder.

    init_state = np.zeros((num_layers, 2, batch_size, state_size))
    
    ...
    
    state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
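
The layout of that array is one slice per layer, and each slice holds the layer's cell state `c` and hidden state `h`. As a quick sanity check of the shape (a NumPy-only sketch with made-up sizes; the TensorFlow session itself is not needed for this), packing a returned per-layer `(c, h)` state tuple back into the placeholder layout is a pair of stacks:

```python
import numpy as np

num_layers, batch_size, state_size = 2, 4, 8

# Zero initial state: one (c, h) slice per layer.
init_state = np.zeros((num_layers, 2, batch_size, state_size), dtype=np.float32)

# A session.run returns the state as a tuple of per-layer (c, h) pairs;
# here we fake one with random arrays just to check the packing.
returned_state = tuple(
    (np.random.randn(batch_size, state_size).astype(np.float32),   # c
     np.random.randn(batch_size, state_size).astype(np.float32))   # h
    for _ in range(num_layers)
)

# Pack it back into the placeholder layout for the next feed_dict.
packed = np.stack([np.stack([c, h]) for (c, h) in returned_state])
print(packed.shape)  # (num_layers, 2, batch_size, state_size)
```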
    

Then unstack it and create a tuple of LSTMStateTuples before using the native TensorFlow RNN API.

l = tf.unstack(state_placeholder, axis=0)
rnn_tuple_state = tuple(
    tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1])
    for idx in range(num_layers)
)
    

Then pass it as the initial state through the RNN API:

# Create a separate cell per layer; [cell] * num_layers would make all
# layers share the same weights.
cells = [tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
         for _ in range(num_layers)]
cell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)
outputs, state = tf.nn.dynamic_rnn(cell, x_input_batch, initial_state=rnn_tuple_state)
    

The returned state can then be fed back in through the placeholder for the next batch.
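
To see why carrying the state forward works, here is a NumPy-only sketch of a single LSTM cell (standard gate equations, made-up sizes and random weights). Feeding each timestep separately while saving `(c, h)` between calls, exactly what the placeholder trick does across `session.run` calls, reproduces processing the whole sequence at once:

```python
import numpy as np

def lstm_step(x, c, h, W, b):
    """One step of a standard LSTM cell (gates i, f, g, o)."""
    z = np.concatenate([x, h], axis=1) @ W + b
    i, f, g, o = np.split(z, 4, axis=1)
    sig = lambda a: 1.0 / (1.0 + np.exp(-a))
    c_new = sig(f) * c + sig(i) * np.tanh(g)
    h_new = sig(o) * np.tanh(c_new)
    return c_new, h_new

rng = np.random.default_rng(0)
batch, input_size, state_size, seq_len = 2, 3, 5, 4
W = rng.normal(size=(input_size + state_size, 4 * state_size)) * 0.1
b = np.zeros(4 * state_size)
xs = rng.normal(size=(seq_len, batch, input_size))

# (a) Reference: process the whole sequence in one run.
c = h = np.zeros((batch, state_size))
for t in range(seq_len):
    c, h = lstm_step(xs[t], c, h, W, b)
full_h = h

# (b) seq_length = 1 "batches": save (c, h) after each call and
# feed it back in as the initial state of the next call.
saved_state = (np.zeros((batch, state_size)), np.zeros((batch, state_size)))
for t in range(seq_len):
    c, h = lstm_step(xs[t], *saved_state, W, b)
    saved_state = (c, h)  # remember state for the next batch

print(np.allclose(full_h, saved_state[1]))  # True: the results match
```

If the state were instead reset to zeros before every single-timestep call, the cell would forget everything between steps and the outputs would diverge from the full-sequence run.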
