Is RNN initial state reset for subsequent mini-batches?

时光取名叫无心 2020-12-04 12:58

Could someone please clarify whether the initial state of the RNN in TensorFlow (TF) is reset for subsequent mini-batches, or whether the last state of the previous mini-batch is used as the initial state instead, as mentioned in …?
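To make the two options concrete, here is a toy scalar RNN in plain Python. It is a hypothetical stand-in, not TensorFlow code, and `run_rnn`, `w` and `u` are illustrative names. For reference, TF 1.x `tf.nn.dynamic_rnn` called without an `initial_state` starts from `cell.zero_state` on every `session.run`, which corresponds to Option 1 below.

```python
import math


def run_rnn(inputs, h0, w=0.5, u=1.0):
    """Toy scalar RNN: h_t = tanh(w * h_{t-1} + u * x_t).

    Returns (outputs, last_state), loosely mirroring the (outputs, state)
    pair returned by tf.nn.dynamic_rnn.
    """
    h = h0
    outputs = []
    for x in inputs:
        h = math.tanh(w * h + u * x)
        outputs.append(h)
    return outputs, h


sequence = [0.5, 0.4, 0.3, 0.2, 0.1, 0.0]
mini_batches = [sequence[:3], sequence[3:]]  # two consecutive chunks

# Option 1: every mini-batch starts again from the zero state.
reset_state = 0.0
for batch in mini_batches:
    _, reset_state = run_rnn(batch, h0=0.0)

# Option 2: the last state of one mini-batch is the next initial state.
carried_state = 0.0
for batch in mini_batches:
    _, carried_state = run_rnn(batch, h0=carried_state)

# Reference: the whole sequence processed in a single pass.
_, full_state = run_rnn(sequence, h0=0.0)

print(carried_state == full_state)  # True
print(reset_state == full_state)    # False
```

Only Option 2 reproduces the single-pass result, because it performs exactly the same arithmetic in the same order. The answer below sets up that behaviour in TensorFlow by keeping the final state in variables.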

2 Answers
  •  無奈伤痛
    2020-12-04 13:35

    In addition to danijar's answer, here is the code for an LSTM whose state is a tuple (state_is_tuple=True). It also supports multiple layers.

    We define two functions: one that creates state variables holding an initial zero state, and one that returns an operation which we can pass to session.run to update those state variables with the LSTM's last state (both the cell state and the hidden state).
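The same two-part idea (persistent state initialised to zeros, plus a deferred "assign" step) can be mimicked without TensorFlow. In this hypothetical sketch, mutable Python lists stand in for the non-trainable `tf.Variable`s, a `namedtuple` stands in for `LSTMStateTuple`, and closures stand in for graph ops. As in TF 1.x graph mode, nothing changes until the ops are actually run.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.contrib.rnn.LSTMStateTuple: (cell state, hidden state).
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def get_state_variables(batch_size, num_units, num_layers):
    """One (c, h) pair of zero-filled, mutable buffers per layer."""
    def zeros():
        return [[0.0] * num_units for _ in range(batch_size)]
    return tuple(LSTMStateTuple(zeros(), zeros()) for _ in range(num_layers))


def get_state_update_op(state_variables, new_states):
    """Return deferred 'assign' ops; the buffers change only when they are called."""
    def make_assign(target, value):
        def assign():
            target[:] = [list(row) for row in value]  # overwrite in place
        return assign

    update_ops = []
    for state_variable, new_state in zip(state_variables, new_states):
        update_ops.extend([make_assign(state_variable.c, new_state.c),
                           make_assign(state_variable.h, new_state.h)])
    return update_ops


states = get_state_variables(batch_size=2, num_units=3, num_layers=2)
new_states = tuple(
    LSTMStateTuple([[1.0] * 3] * 2, [[2.0] * 3] * 2) for _ in range(2))

ops = get_state_update_op(states, new_states)
print(states[0].h[0])  # [0.0, 0.0, 0.0] (ops have not run yet)
for op in ops:         # plays the role of sess.run(update_op)
    op()
print(states[0].h[0])  # [2.0, 2.0, 2.0]
print(len(ops))        # 4: c and h for each of the 2 layers
```

Overwriting the buffers in place, rather than rebinding a name, is what matters here. Like a `tf.Variable`, the storage outlives a single run, so the next mini-batch sees the state the previous one left behind.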

    import tensorflow as tf  # TensorFlow 1.x (tf.contrib was removed in 2.x)
    
    
    def get_state_variables(batch_size, cell):
        # For each layer, get the initial state and make a variable out of it
        # to enable updating its value.
        state_variables = []
        for state_c, state_h in cell.zero_state(batch_size, tf.float32):
            state_variables.append(tf.contrib.rnn.LSTMStateTuple(
                tf.Variable(state_c, trainable=False),
                tf.Variable(state_h, trainable=False)))
        # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
        return tuple(state_variables)
    
    
    def get_state_update_op(state_variables, new_states):
        # Add an operation to update the train states with the last state tensors
        update_ops = []
        for state_variable, new_state in zip(state_variables, new_states):
            # Assign the new state to the state variables on this layer
            update_ops.extend([state_variable[0].assign(new_state[0]),
                               state_variable[1].assign(new_state[1])])
        # Return a tuple in order to combine all update_ops into a single operation.
        # The tuple's actual value should not be used.
        return tf.tuple(update_ops)
    

    As in danijar's answer, we can use these functions to update the LSTM's state after each batch:

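    # Example sizes (hypothetical values, chosen only for illustration)
    batch_size, max_length, frame_size, num_layers = 32, 50, 40, 2
    
    data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
    # The helpers above unpack each layer's state as (c, h), so the cells must be LSTMs;
    # a GRUCell's state is a single tensor and would break that unpacking.
    cells = [tf.contrib.rnn.LSTMCell(256, state_is_tuple=True) for _ in range(num_layers)]
    cell = tf.contrib.rnn.MultiRNNCell(cells)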
    
    # For each layer, get the initial state. states will be a tuple of LSTMStateTuples.
    states = get_state_variables(batch_size, cell)
    
    # Unroll the LSTM
    outputs, new_states = tf.nn.dynamic_rnn(cell, data, initial_state=states)
    
    # Add an operation to update the train states with the last state tensors.
    update_op = get_state_update_op(states, new_states)
    
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    # Feed one mini-batch; update_op stores its final state for the next run.
    sess.run([outputs, update_op], {data: ...})
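Carrying the state over only makes sense if row i of each mini-batch continues row i of the previous mini-batch. The answer does not show its input pipeline. One common layout, assumed here, is to split one long stream into `batch_size` contiguous rows and then cut those rows into consecutive windows of `max_length` steps. The helper `make_stateful_batches` is hypothetical, and scalars stand in for the `frame_size`-dimensional frames.

```python
def make_stateful_batches(stream, batch_size, max_length):
    """Yield (batch_size, max_length) windows whose rows continue across batches.

    Row i of batch k+1 starts exactly where row i of batch k ended, so the
    final state for row i is the right initial state for the next batch.
    Trailing items that do not fill a whole window are dropped.
    """
    row_length = len(stream) // batch_size
    rows = [stream[i * row_length:(i + 1) * row_length] for i in range(batch_size)]
    num_batches = row_length // max_length
    for k in range(num_batches):
        start = k * max_length
        yield [row[start:start + max_length] for row in rows]


stream = list(range(24))
for batch in make_stateful_batches(stream, batch_size=2, max_length=4):
    print(batch)
# [[0, 1, 2, 3], [12, 13, 14, 15]]
# [[4, 5, 6, 7], [16, 17, 18, 19]]
# [[8, 9, 10, 11], [20, 21, 22, 23]]
```

When an unrelated stream starts (for example, at a new epoch), the state variables would have to be set back to zero, for instance by re-running an initializer for just those variables.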
    

    The main difference is that state_is_tuple=True makes the LSTM's state an LSTMStateTuple with two parts (the cell state and the hidden state), each stored in its own variable here, instead of a single variable. With multiple layers, the LSTM's state becomes a tuple of LSTMStateTuples, one per layer.
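For comparison, with state_is_tuple=False the TF 1.x `LSTMCell` (without a projection layer) keeps c and h concatenated along the feature axis as one tensor of shape (batch_size, 2 * num_units). The sketch below uses plain lists instead of tensors, and its helper names are hypothetical. It shows that both layouts hold the same numbers. The tuple form is what lets `get_state_variables` unpack each layer as `state_c, state_h`.

```python
def concat_state(c, h):
    """state_is_tuple=False layout: each row is the c row followed by the h row."""
    return [c_row + h_row for c_row, h_row in zip(c, h)]


def split_state(state, num_units):
    """Recover (c, h) from the concatenated layout."""
    c = [row[:num_units] for row in state]
    h = [row[num_units:] for row in state]
    return c, h


num_units = 2
c = [[0.1, 0.2], [0.3, 0.4]]  # batch_size = 2
h = [[0.5, 0.6], [0.7, 0.8]]

flat = concat_state(c, h)
print(flat)                                    # [[0.1, 0.2, 0.5, 0.6], [0.3, 0.4, 0.7, 0.8]]
print(split_state(flat, num_units) == (c, h))  # True
```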
