Tensorflow: how to obtain intermediate cell states (c) from LSTMCell using dynamic_rnn?

Submitted by 大兔子大兔子 on 2019-12-11 03:08:06

Question


By default, the function dynamic_rnn outputs only the hidden states (known as m) for each time step, which can be obtained as follows:

cell = tf.contrib.rnn.LSTMCell(100)
rnn_outputs, _ = tf.nn.dynamic_rnn(cell,
                                   inputs=inputs,
                                   sequence_length=sequence_lengths,
                                   dtype=tf.float32)

Is there a way to get the intermediate (not final) cell states (c) as well?
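To make concrete what those intermediate cell states are, here is a minimal NumPy sketch (not TensorFlow code; the random weights and gate ordering are purely illustrative) of an LSTM unrolled by hand, keeping c at every step:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h_prev, c_prev, W, b):
    """One LSTM step: returns the hidden state h (what dynamic_rnn emits)
    and the cell state c (what the question wants to collect)."""
    z = np.concatenate([x, h_prev], axis=-1) @ W + b
    i, f, g, o = np.split(z, 4, axis=-1)            # input, forget, candidate, output
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)
    h = sigmoid(o) * np.tanh(c)
    return h, c

rng = np.random.default_rng(0)
n_in, n_units, batch, steps = 3, 5, 2, 4
W = rng.standard_normal((n_in + n_units, 4 * n_units)) * 0.1
b = np.zeros(4 * n_units)
h = np.zeros((batch, n_units))
c = np.zeros((batch, n_units))
all_c = []
for t in range(steps):
    x = rng.standard_normal((batch, n_in))
    h, c = lstm_step(x, h, c, W, b)
    all_c.append(c)                                 # keep every intermediate cell state

print(np.stack(all_c, axis=1).shape)                # (2, 4, 5): batch x time x units
```

dynamic_rnn only emits h per step and discards each intermediate c after the next step consumes it, which is why c has to be smuggled out through the cell's output.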

A TensorFlow contributor mentions that it can be done with a cell wrapper:

class Wrapper(tf.nn.rnn_cell.RNNCell):
  def __init__(self, inner_cell):
     super(Wrapper, self).__init__()
     self._inner_cell = inner_cell
  @property
  def state_size(self):
     return self._inner_cell.state_size
  @property
  def output_size(self):
    return (self._inner_cell.state_size, self._inner_cell.output_size)
  def call(self, input, state):
    output, next_state = self._inner_cell(input, state)
    emit_output = (next_state, output)
    return emit_output, next_state

However, it doesn't seem to work. Any ideas?


Answer 1:


The proposed solution works for me, but the Layer.call method spec is more general, so the following Wrapper should be more robust to API changes. Try this:

class Wrapper(tf.nn.rnn_cell.RNNCell):
  def __init__(self, inner_cell):
     super(Wrapper, self).__init__()
     self._inner_cell = inner_cell

  @property
  def state_size(self):
     return self._inner_cell.state_size

  @property
  def output_size(self):
    return (self._inner_cell.state_size, self._inner_cell.output_size)

  def call(self, input, *args, **kwargs):
    output, next_state = self._inner_cell(input, *args, **kwargs)
    emit_output = (next_state, output)
    return emit_output, next_state

Here's the test:

import numpy as np
import tensorflow as tf

n_steps = 2
n_inputs = 3
n_neurons = 5

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
basic_cell = Wrapper(tf.nn.rnn_cell.LSTMCell(num_units=n_neurons, state_is_tuple=False))
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)
print(outputs, states)

X_batch = np.array([
  # t = 0      t = 1
  [[0, 1, 2], [9, 8, 7]], # instance 0
  [[3, 4, 5], [0, 0, 0]], # instance 1
  [[6, 7, 8], [6, 5, 4]], # instance 2
  [[9, 0, 1], [3, 2, 1]], # instance 3
])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  outputs_val = outputs[0].eval(feed_dict={X: X_batch})
  print(outputs_val)

The returned outputs is a tuple of (?, 2, 10) and (?, 2, 5) tensors, which are all the LSTM states and outputs, respectively. Note that I'm using the "graduated" version of LSTMCell, from the tf.nn.rnn_cell package, not tf.contrib.rnn. Also note state_is_tuple=False to avoid dealing with LSTMStateTuple.
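Since state_is_tuple=False concatenates c and h along the last axis, the first element of outputs can be split back into the two halves after evaluation. A small NumPy sketch (the array below is a dummy stand-in for the evaluated (?, 2, 10) tensor, not real LSTM output):

```python
import numpy as np

# With state_is_tuple=False, the per-step state is concat([c, h], axis=-1),
# so a (batch, n_steps, 2 * n_neurons) array splits cleanly into c and h.
n_neurons = 5
states_val = np.arange(4 * 2 * 2 * n_neurons, dtype=np.float32).reshape(4, 2, 2 * n_neurons)
c_states, h_states = np.split(states_val, 2, axis=-1)
print(c_states.shape, h_states.shape)   # (4, 2, 5) (4, 2, 5)
```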




Answer 2:


Based on Maxim's idea, I ended up with the following solution:

LSTMCell = tf.nn.rnn_cell.LSTMCell  # assumes `import tensorflow as tf`

class StatefulLSTMCell(LSTMCell):
    def __init__(self, *args, **kwargs):
        super(StatefulLSTMCell, self).__init__(*args, **kwargs)

    @property
    def output_size(self):
        return (self.state_size, super(StatefulLSTMCell, self).output_size)

    def call(self, input, state):
        output, next_state = super(StatefulLSTMCell, self).call(input, state)
        emit_output = (next_state, output)
        return emit_output, next_state
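The wrapper pattern itself is framework-agnostic: the cell's per-step output is replaced by the pair (next_state, output), so whatever loop drives the cell ends up recording every intermediate state. A toy sketch with a made-up cell (ToyCell and its recurrence are illustrative only, not TensorFlow API):

```python
import numpy as np

class ToyCell:
    """Stand-in for an RNN cell: calling it returns (output, next_state)."""
    def __call__(self, x, state):
        next_state = state + x          # trivial recurrence for illustration
        output = np.tanh(next_state)
        return output, next_state

class ToyWrapper:
    """Same trick as StatefulLSTMCell: emit (next_state, output) as the
    per-step output, so the driving loop sees every intermediate state."""
    def __init__(self, inner):
        self._inner = inner
    def __call__(self, x, state):
        output, next_state = self._inner(x, state)
        return (next_state, output), next_state

cell = ToyWrapper(ToyCell())
state = np.zeros(3)
emitted = []
for x in [np.ones(3), 2 * np.ones(3)]:  # a two-step "dynamic_rnn" loop
    emit, state = cell(x, state)
    emitted.append(emit)

all_states = np.stack([e[0] for e in emitted])  # states after step 1 and step 2
```

Here all_states holds [1, 1, 1] after the first step and [3, 3, 3] after the second, i.e. the full state history rather than only the final state.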


来源:https://stackoverflow.com/questions/47745027/tensorflow-how-to-obtain-intermediate-cell-states-c-from-lstmcell-using-dynam
