Get the last output of a dynamic_rnn in TensorFlow

Submitted anonymously (unverified) on 2019-12-03 01:26:01

Question:

I have a 3-D tensor of shape [batch, None, dim] where the second dimension, i.e. the timesteps, is unknown. I use dynamic_rnn to process such an input, as in the following snippet:

import numpy as np
import tensorflow as tf

batch = 2
dim = 3
hidden = 4

lengths = tf.placeholder(dtype=tf.int32, shape=[batch])
inputs = tf.placeholder(dtype=tf.float32, shape=[batch, None, dim])
cell = tf.nn.rnn_cell.GRUCell(hidden)
cell_state = cell.zero_state(batch, tf.float32)
output, _ = tf.nn.dynamic_rnn(cell, inputs, lengths, initial_state=cell_state)

Running this snippet with some actual numbers, I get reasonable results:

inputs_ = np.asarray([[[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3]],
                      [[6, 6, 6], [7, 7, 7], [8, 8, 8], [9, 9, 9]]],
                     dtype=np.int32)
lengths_ = np.asarray([3, 1], dtype=np.int32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    output_ = sess.run(output, {inputs: inputs_, lengths: lengths_})
    print(output_)

And the output is:

[[[ 0.          0.          0.          0.        ]
  [ 0.02188676 -0.01294564  0.05340237 -0.47148666]
  [ 0.0343586  -0.02243731  0.0870839  -0.89869428]
  [ 0.          0.          0.          0.        ]]

 [[ 0.00284752 -0.00315077  0.00108094 -0.99883419]
  [ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]]]

Is there a way to get a 3-D tensor of shape [batch, 1, hidden] with the last relevant output of the dynamic RNN? Thanks!

Answer 1:

This is what gather_nd is for!

def extract_axis_1(data, ind):
    """
    Get specified elements along the first axis of a tensor.
    :param data: TensorFlow tensor that will be subsetted.
    :param ind: Indices to take (one for each element along axis 0 of data).
    :return: Subsetted tensor.
    """
    batch_range = tf.range(tf.shape(data)[0])
    indices = tf.stack([batch_range, ind], axis=1)
    res = tf.gather_nd(data, indices)
    return res

In your case:

output = extract_axis_1(output, lengths - 1) 

Now output is a tensor of shape [batch_size, num_cells].
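Since the question asked for a [batch, 1, hidden] result, you can restore the middle axis with tf.expand_dims. A minimal sketch building on the helper above:

last = extract_axis_1(output, lengths - 1)  # shape [batch, hidden]
last = tf.expand_dims(last, axis=1)         # shape [batch, 1, hidden]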



Answer 2:

From the following two sources:

http://www.wildml.com/2016/08/rnns-in-tensorflow-a-practical-guide-and-undocumented-features/

outputs, last_states = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float64,
    sequence_length=X_lengths,
    inputs=X)

Or https://github.com/ageron/handson-ml/blob/master/14_recurrent_neural_networks.ipynb,

it is clear that last_states can be extracted directly from the SECOND return value of the dynamic_rnn call. It gives you the final state across all layers (for an LSTM it is composed of LSTMStateTuples), while outputs contains the per-timestep outputs of the last layer.
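As a minimal sketch of this, assuming a two-layer stack of LSTMCells (the layer sizes here are illustrative, not from the original post): the final state is one LSTMStateTuple per layer, and the top layer's h field equals the last relevant row of outputs for each sequence.

import tensorflow as tf

# Illustrative two-layer LSTM stack; sizes are arbitrary.
cells = [tf.nn.rnn_cell.LSTMCell(4) for _ in range(2)]
stacked_cell = tf.nn.rnn_cell.MultiRNNCell(cells)

X = tf.placeholder(tf.float32, shape=[None, None, 3])
X_lengths = tf.placeholder(tf.int32, shape=[None])

outputs, last_states = tf.nn.dynamic_rnn(
    cell=stacked_cell,
    dtype=tf.float32,
    sequence_length=X_lengths,
    inputs=X)

# last_states is a tuple of LSTMStateTuple(c, h), one per layer;
# the top layer's h is the last relevant output of each sequence.
final_h = last_states[-1].h  # shape [batch, 4]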



Answer 3:

Actually, the solution was not that hard. I implemented the following code:

slices = []
for index, l in enumerate(tf.unstack(lengths)):
    # Note: size=[1, 1, 3] keeps only the first 3 of the 4 hidden units;
    # use size=[1, 1, hidden] to keep the full last output.
    slice = tf.slice(rnn_out, begin=[index, l - 1, 0], size=[1, 1, 3])
    slices.append(slice)
last = tf.concat(slices, 0)  # tf.concat takes (values, axis) in TF >= 1.0

So, the full snippet would be the following:

import numpy as np
import tensorflow as tf

batch = 2
dim = 3
hidden = 4

lengths = tf.placeholder(dtype=tf.int32, shape=[batch])
inputs = tf.placeholder(dtype=tf.float32, shape=[batch, None, dim])
cell = tf.nn.rnn_cell.GRUCell(hidden)
cell_state = cell.zero_state(batch, tf.float32)
output, _ = tf.nn.dynamic_rnn(cell, inputs, lengths, initial_state=cell_state)

inputs_ = np.asarray([[[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3]],
                      [[6, 6, 6], [7, 7, 7], [8, 8, 8], [9, 9, 9]]],
                     dtype=np.int32)
lengths_ = np.asarray([3, 1], dtype=np.int32)

slices = []
for index, l in enumerate(tf.unstack(lengths)):
    # size=[1, 1, 3] truncates the 4-dim hidden state (see the note above).
    slice = tf.slice(output, begin=[index, l - 1, 0], size=[1, 1, 3])
    slices.append(slice)
last = tf.concat(slices, 0)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    outputs = sess.run([output, last], {inputs: inputs_, lengths: lengths_})
    print('RNN output:')
    print(outputs[0])
    print()
    print('last relevant output:')
    print(outputs[1])

And the output:

RNN output:
[[[ 0.          0.          0.          0.        ]
  [-0.06667092 -0.09284072  0.01098599 -0.03676109]
  [-0.09101103 -0.19828682  0.03546784 -0.08721405]
  [ 0.          0.          0.          0.        ]]

 [[-0.00025157 -0.05704876  0.05527233 -0.03741353]
  [ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]]]

last relevant output:
[[[-0.09101103 -0.19828682  0.03546784]]

 [[-0.00025157 -0.05704876  0.05527233]]]
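Note that this builds a separate slice op per batch element. A sketch of the same extraction vectorized with the gather_nd approach from Answer 1, using the names defined in the snippet above:

batch_range = tf.range(tf.shape(output)[0])
indices = tf.stack([batch_range, lengths - 1], axis=1)
# One gather instead of `batch` slices; also keeps the full hidden dimension.
last = tf.expand_dims(tf.gather_nd(output, indices), axis=1)  # [batch, 1, hidden]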


Answer 4:

Okay, so it looks like there actually is an easier solution. As @Shao Tang and @Rahul mentioned, the preferred way to do this is to access the final cell state. Here's why:

  • If you look at the GRUCell source code (below), you'll see that the "state" the cell maintains is the hidden state itself, the very same tensor the cell emits as its output. So when tf.nn.dynamic_rnn returns the final state, it is returning exactly the final hidden state you are interested in. To prove this, I just tweaked your setup and got the results:

GRUCell Call (rnn_cell_impl.py):

def call(self, inputs, state):
    """Gated recurrent unit (GRU) with nunits cells."""
    if self._gate_linear is None:
        bias_ones = self._bias_initializer
        if self._bias_initializer is None:
            bias_ones = init_ops.constant_initializer(1.0, dtype=inputs.dtype)
        with vs.variable_scope("gates"):  # Reset gate and update gate.
            self._gate_linear = _Linear(
                [inputs, state],
                2 * self._num_units,
                True,
                bias_initializer=bias_ones,
                kernel_initializer=self._kernel_initializer)

    value = math_ops.sigmoid(self._gate_linear([inputs, state]))
    r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)

    r_state = r * state
    if self._candidate_linear is None:
        with vs.variable_scope("candidate"):
            self._candidate_linear = _Linear(
                [inputs, r_state],
                self._num_units,
                True,
                bias_initializer=self._bias_initializer,
                kernel_initializer=self._kernel_initializer)
    c = self._activation(self._candidate_linear([inputs, r_state]))
    new_h = u * state + (1 - u) * c
    return new_h, new_h

Solution:

import numpy as np
import tensorflow as tf

batch = 2
dim = 3
hidden = 4

lengths = tf.placeholder(dtype=tf.int32, shape=[batch])
inputs = tf.placeholder(dtype=tf.float32, shape=[batch, None, dim])
cell = tf.nn.rnn_cell.GRUCell(hidden)
cell_state = cell.zero_state(batch, tf.float32)
output, state = tf.nn.dynamic_rnn(cell, inputs, lengths, initial_state=cell_state)

inputs_ = np.asarray([[[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3]],
                      [[6, 6, 6], [7, 7, 7], [8, 8, 8], [9, 9, 9]]],
                     dtype=np.int32)
lengths_ = np.asarray([3, 1], dtype=np.int32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    output_, state_ = sess.run([output, state], {inputs: inputs_, lengths: lengths_})
    print(output_)
    print(state_)

Output:

[[[ 0.          0.          0.          0.        ]
  [-0.24305521 -0.15512943  0.06614969  0.16873555]
  [-0.62767833 -0.30741733  0.14819752  0.44313088]
  [ 0.          0.          0.          0.        ]]

 [[-0.99152333 -0.1006391   0.28767768  0.76360202]
  [ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]]]
[[-0.62767833 -0.30741733  0.14819752  0.44313088]
 [-0.99152333 -0.1006391   0.28767768  0.76360202]]
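As a quick sanity check on the fetched arrays (a sketch using the names from the snippet above): for a GRU, the returned final state should equal the output at each sequence's last valid timestep.

# Verify that the final state matches the last relevant output row.
for i, l in enumerate(lengths_):
    np.testing.assert_allclose(state_[i], output_[i, l - 1])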
  • For other readers who are working with the LSTMCell (another popular option), things work a little differently. The LSTMCell maintains its state in a different way: the state is either a tuple or a concatenated version of the actual cell state and the hidden state. So, to access the final hidden state, set state_is_tuple=True during cell initialization, and the final state will be an LSTMStateTuple: (final cell state, final hidden state). In that case,

    _, (_, h) = tf.nn.dynamic_rnn(cell, inputs, lengths, initial_state=cell_state)

will give you the final hidden state h.
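For completeness, a hedged sketch of the non-tuple case: with state_is_tuple=False (deprecated), the final state is the cell state c and the hidden state h concatenated along axis 1, so h is the second half. This assumes a fresh graph with the same placeholders as in the snippet above; reusing the old graph would clash with the variables created by the earlier dynamic_rnn call.

cell = tf.nn.rnn_cell.LSTMCell(hidden, state_is_tuple=False)
cell_state = cell.zero_state(batch, tf.float32)
_, final_state = tf.nn.dynamic_rnn(cell, inputs, lengths,
                                   initial_state=cell_state)
# final_state is [c, h] concatenated along axis 1, each half of width
# `hidden`, so the final hidden state is the second half.
h = final_state[:, hidden:]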

References:

  • c_state and m_state in Tensorflow LSTM
  • https://github.com/tensorflow/tensorflow/blob/438604fc885208ee05f9eef2d0f2c630e1360a83/tensorflow/python/ops/rnn_cell_impl.py#L308
  • https://github.com/tensorflow/tensorflow/blob/438604fc885208ee05f9eef2d0f2c630e1360a83/tensorflow/python/ops/rnn_cell_impl.py#L415


