TensorFlow: How to pass output from previous time-step as input to next timestep

Submitted by 旧巷老猫 on 2019-11-27 21:31:00

One way to do this is to write your own RNN cell, together with your own multi-RNN cell. That way you can store the output of the last RNN cell internally and access it in the next time step. Check this blogpost for more info. You can also add, for example, encoders or decoders directly in the cell, so that you can process the data before feeding it to the cell or after retrieving it from the cell.
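
The idea of a cell that remembers its own last output can be sketched without TensorFlow. The snippet below is pure Python with scalar values; toy_cell, FeedbackCell, run, and the weights are hypothetical names chosen for illustration and are not part of any TensorFlow API. The wrapper keeps the previous output inside the state it passes along and ignores the external input after the first step.

```python
import math


def toy_cell(x, h, w_x=0.5, w_h=0.8):
    """Scalar stand-in for an RNN cell: returns (output, new_state)."""
    h_new = math.tanh(w_x * x + w_h * h)
    return h_new, h_new


class FeedbackCell:
    """Wraps a cell so that its previous output replaces the external input.

    The wrapped state is a tuple (inner_state, last_output); last_output is
    None until the cell has been called once.
    """

    def __init__(self, cell):
        self.cell = cell

    def zero_state(self):
        return (0.0, None)

    def __call__(self, x, state):
        inner_state, last_output = state
        # First step: use the real input. Later steps: use the stored output.
        cell_input = x if last_output is None else last_output
        output, new_inner = self.cell(cell_input, inner_state)
        return output, (new_inner, output)


def run(cell, inputs):
    """Unroll the cell over a list of inputs and collect the outputs."""
    state = cell.zero_state()
    outputs = []
    for x in inputs:
        out, state = cell(x, state)
        outputs.append(out)
    return outputs


feedback_outputs = run(FeedbackCell(toy_cell), [1.0, 0.0, 0.0])
```

Because the wrapper only reads the external input at the first step, changing the later entries of the input list does not change the outputs. A real implementation would subclass the TensorFlow RNN cell base class and work on batched tensors instead of scalars.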

Another possibility is to use the function tf.nn.raw_rnn, which lets you control what happens before and after the calls to the RNN cell. The following code snippet shows how to use this function; credits go to this article.

from tensorflow.python.ops.rnn import _transpose_batch_time
import tensorflow as tf


def sampling_rnn(self, cell, initial_state, input_, seq_lengths):

    # raw_rnn expects time major inputs as TensorArrays
    max_time = ...  # this is the max time step per batch
    inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time, clear_after_read=False)
    inputs_ta = inputs_ta.unstack(_transpose_batch_time(input_))  # input_ is the batch-major input tensor
    input_dim = input_.get_shape()[-1].value  # the dimensionality of the input to each time step
    output_dim = ...  # the dimensionality of the model's output at each time step

    def loop_fn(time, cell_output, cell_state, loop_state):
        """
        Loop function that lets you control the input to the rnn cell and manipulate cell outputs.
        :param time: current time step
        :param cell_output: output from previous time step or None if time == 0
        :param cell_state: cell state from previous time step
        :param loop_state: custom loop state to share information between different iterations of this loop fn
        :return: tuple consisting of
          elements_finished: tensor of size [batch_size] which is True for sequences that have reached their end,
            needed because of variable sequence size
          next_input: input to next time step
          next_cell_state: cell state forwarded to next time step
          emit_output: The first return argument of raw_rnn. This is not necessarily the output of the RNN cell,
            but could e.g. be the output of a dense layer attached to the rnn layer.
          next_loop_state: loop state forwarded to the next time step
        """
        if cell_output is None:
            # time == 0, used for initialization before first call to cell
            next_cell_state = initial_state
            # the emit_output in this case tells TF how future emits look
            emit_output = tf.zeros([output_dim])
        else:
            # t > 0, called right after call to cell, i.e. cell_output is the output from time t-1.
            # here you can do whatever you want with cell_output before assigning it to emit_output.
            # In this case, we don't do anything
            next_cell_state = cell_state
            emit_output = cell_output

        # check which elements are finished
        elements_finished = (time >= seq_lengths)
        finished = tf.reduce_all(elements_finished)

        # assemble cell input for upcoming time step
        current_output = emit_output if cell_output is not None else None
        input_original = inputs_ta.read(time)  # tensor of shape (None, input_dim)

        if current_output is None:
            # this is the initial step, i.e. there is no output from a previous time step, what we feed here
            # can highly depend on the data. In this case we just assign the actual input in the first time step.
            next_in = input_original
        else:
            # time > 0, so just use previous output as next input
            # here you could do fancier things, whatever you want to do before passing the data into the rnn cell
            # if you were to pass input_original here, then you would get the normal behaviour of dynamic_rnn
            next_in = current_output

        next_input = tf.cond(finished,
                             lambda: tf.zeros([self.batch_size, input_dim], dtype=tf.float32),  # copy through zeros
                             lambda: next_in)  # if not finished, feed the previous output as next input

        # set shape manually, otherwise it is not defined for the last dimensions
        next_input.set_shape([None, input_dim])

        # loop state not used in this example
        next_loop_state = None
        return (elements_finished, next_input, next_cell_state, emit_output, next_loop_state)

    outputs_ta, last_state, _ = tf.nn.raw_rnn(cell, loop_fn)
    outputs = _transpose_batch_time(outputs_ta.stack())
    final_state = last_state

    return outputs, final_state
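
To make the calling order of loop_fn easier to follow, here is a simplified pure-Python emulation of the loop that tf.nn.raw_rnn runs, for a single sequence with scalar values. The names toy_raw_rnn, toy_cell, and make_loop_fn are hypothetical, and batching, TensorArrays, and the zeroing of finished batch entries are left out. It is a sketch of the control flow under those assumptions, not TensorFlow's implementation.

```python
import math


def toy_cell(x, h):
    """Scalar stand-in for an RNN cell: returns (output, new_state)."""
    h_new = math.tanh(0.5 * x + 0.8 * h)
    return h_new, h_new


def toy_raw_rnn(cell, loop_fn):
    """Simplified single-sequence version of the raw_rnn control flow."""
    time = 0
    # Initial call: no cell output yet; loop_fn supplies first input and state.
    finished, next_input, state, _, loop_state = loop_fn(time, None, None, None)
    emits = []
    while not finished:
        output, cell_state = cell(next_input, state)
        # loop_fn sees the output of step `time` and decides the next input.
        finished, next_input, state, emit, loop_state = loop_fn(
            time + 1, output, cell_state, loop_state)
        emits.append(emit)
        time += 1
    return emits, state, loop_state


def make_loop_fn(first_input, seq_length, initial_state=0.0):
    def loop_fn(time, cell_output, cell_state, loop_state):
        if cell_output is None:
            # time == 0: start from the real input and the initial state.
            return time >= seq_length, first_input, initial_state, None, None
        # time > 0: feed the previous output back in as the next input.
        return time >= seq_length, cell_output, cell_state, cell_output, None
    return loop_fn


emits, final_state, _ = toy_raw_rnn(toy_cell, make_loop_fn(1.0, seq_length=3))
```

With seq_length=3 the cell is called three times: once on the real input and twice on its own previous output. Returning the original input instead of cell_output in the second branch would give the ordinary behaviour where every step reads from the input sequence.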

As a side note: it is not clear whether relying on the model's outputs during training is a good idea. Especially in the beginning, the outputs of the model can be quite bad, so training might never converge or might not learn anything meaningful.

Define an init_state together with your network layers:

init_state = tf.placeholder(tf.float32, [batch_size,hidden])
basic_cell = tf.contrib.rnn.BasicRNNCell(num_units = hidden)
state_series, current_state = tf.nn.dynamic_rnn(basic_cell, x, dtype=tf.float32, initial_state = init_state)

Then, outside your training_steps_loop, initialize the zero state:

import numpy as np

_init_state = np.zeros([batch_size,hidden], dtype=np.float32)

Inside your training_steps_loop, run the session with _init_state in the feed_dict, and make the returned _current_state the new _init_state for the next step:

_training_op, _state_series, _current_state = sess.run(
                [training_op, state_series, current_state],  feed_dict={x: xdb, y: ydb, init_state:_init_state})

_init_state = _current_state
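
Why carrying the final state over works can be checked with a toy recurrence: running a long sequence in chunks, and feeding each chunk's final state in as the next chunk's initial state, gives the same result as one uninterrupted pass. The sketch below is pure Python; run_chunk is a hypothetical stand-in for the sess.run call, and the weights and input values are arbitrary.

```python
import math


def run_chunk(xs, init_state):
    """Stand-in for one sess.run call: returns (state_series, final_state)."""
    state = init_state
    series = []
    for x in xs:
        state = math.tanh(0.5 * x + 0.8 * state)
        series.append(state)
    return series, state


sequence = [0.1, -0.4, 0.7, 0.2, -0.9, 0.3]

# One uninterrupted pass over the whole sequence.
full_series, _ = run_chunk(sequence, 0.0)

# The same sequence in chunks of two, carrying the state between chunks.
chunked_series = []
state = 0.0  # plays the role of _init_state
for start in range(0, len(sequence), 2):
    series, state = run_chunk(sequence[start:start + 2], state)
    chunked_series.extend(series)
```

If the state were reset to zero at the start of every chunk, the two series would only agree on the first chunk.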

I think one tricky way is to use tf.contrib.seq2seq.InferenceHelper, because this helper simply passes each step's output on as the input to the next time step, as this issue and this question discuss. Here is my own code (inspired by this question) that works:

"""
construct Decoder
"""
cell = tf.contrib.rnn.LSTMCell(rnn_size, initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2))

# use the same start token for both training and inference
start_tokens = tf.tile(tf.constant([START_ARRAY], dtype=tf.float32), [BATCH_SIZE, 1], name='start_tokens')

# training decoder
with tf.variable_scope("decoder"):
    # construct a helper that passes the output on to the next time step
    training_helper = tf.contrib.seq2seq.InferenceHelper(
        sample_fn=lambda outputs: outputs,
        sample_shape=[decoder_hidden_units],
        sample_dtype=tf.float32,
        start_inputs=start_tokens,
        end_fn=lambda sample_ids: False)

    training_decoder = tf.contrib.seq2seq.BasicDecoder(cell, training_helper,
                                                       initial_state=cell.zero_state(dtype=tf.float32,
                                                                                     batch_size=[BATCH_SIZE]).
                                                       clone(cell_state=encoder_state))

    training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(training_decoder,
                                                                      impute_finished=True,
                                                                      maximum_iterations=max_iters)

The prediction version of the decoder is identical to this training decoder, so you can run inference with it directly.
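
The role the helper plays in the decoding loop can be sketched in plain Python. The loop below is a simplified, single-sequence stand-in for what dynamic_decode does with an InferenceHelper whose sample_fn is the identity and whose end_fn always returns False. The names toy_decode and toy_cell are hypothetical, and batching and impute_finished are omitted, so treat it as an illustration of the data flow only.

```python
import math


def toy_cell(x, h):
    """Scalar stand-in for an RNN cell: returns (output, new_state)."""
    h_new = math.tanh(0.5 * x + 0.8 * h)
    return h_new, h_new


def toy_decode(cell, start_input, initial_state, sample_fn, end_fn,
               maximum_iterations):
    """Feed each step's sample back in until end_fn fires or the cap is hit."""
    next_input, state = start_input, initial_state
    outputs = []
    for _ in range(maximum_iterations):
        output, state = cell(next_input, state)
        sample = sample_fn(output)
        outputs.append(output)
        if end_fn(sample):
            break
        next_input = sample  # the sample becomes the next step's input
    return outputs, state


decoded, last_state = toy_decode(
    toy_cell,
    start_input=1.0,
    initial_state=0.0,
    sample_fn=lambda outputs: outputs,  # identity, as in the helper above
    end_fn=lambda sample_ids: False,    # never stop early
    maximum_iterations=4,
)
```

Since end_fn never signals completion here, the iteration cap is the only thing that ends decoding, which matches the need for maximum_iterations in the dynamic_decode call above.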
