Question
By default, the function dynamic_rnn outputs only the hidden states (known as m) for each time step, which can be obtained as follows:
cell = tf.contrib.rnn.LSTMCell(100)
rnn_outputs, _ = tf.nn.dynamic_rnn(cell,
                                   inputs=inputs,
                                   sequence_length=sequence_lengths,
                                   dtype=tf.float32)
Is there a way to get the intermediate (not final) cell states (c) as well?
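For context, the second value returned by dynamic_rnn (discarded above) holds only the state of the final time step; with the default state_is_tuple=True it is an LSTMStateTuple, so the final c is already accessible. A minimal illustration under that assumption:
rnn_outputs, final_state = tf.nn.dynamic_rnn(cell,
                                             inputs=inputs,
                                             sequence_length=sequence_lengths,
                                             dtype=tf.float32)
final_c = final_state.c  # cell state of the last time step only; intermediate steps are not exposed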
A TensorFlow contributor mentions that it can be done with a cell wrapper:
class Wrapper(tf.nn.rnn_cell.RNNCell):
    def __init__(self, inner_cell):
        super(Wrapper, self).__init__()
        self._inner_cell = inner_cell

    @property
    def state_size(self):
        return self._inner_cell.state_size

    @property
    def output_size(self):
        return (self._inner_cell.state_size, self._inner_cell.output_size)

    def call(self, input, state):
        output, next_state = self._inner_cell(input, state)
        emit_output = (next_state, output)
        return emit_output, next_state
However, it doesn't seem to work. Any ideas?
Answer 1:
The proposed solution works for me, but the Layer.call method spec is more general, so the following Wrapper should be more robust to API changes. Try this:
class Wrapper(tf.nn.rnn_cell.RNNCell):
    def __init__(self, inner_cell):
        super(Wrapper, self).__init__()
        self._inner_cell = inner_cell

    @property
    def state_size(self):
        return self._inner_cell.state_size

    @property
    def output_size(self):
        return (self._inner_cell.state_size, self._inner_cell.output_size)

    def call(self, input, *args, **kwargs):
        output, next_state = self._inner_cell(input, *args, **kwargs)
        emit_output = (next_state, output)
        return emit_output, next_state
Here's the test:
import numpy as np
import tensorflow as tf

n_steps = 2
n_inputs = 3
n_neurons = 5

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
basic_cell = Wrapper(tf.nn.rnn_cell.LSTMCell(num_units=n_neurons, state_is_tuple=False))
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)
print(outputs, states)

X_batch = np.array([
    # t = 0      t = 1
    [[0, 1, 2], [9, 8, 7]],  # instance 0
    [[3, 4, 5], [0, 0, 0]],  # instance 1
    [[6, 7, 8], [6, 5, 4]],  # instance 2
    [[9, 0, 1], [3, 2, 1]],  # instance 3
])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    outputs_val = outputs[0].eval(feed_dict={X: X_batch})
    print(outputs_val)
The returned outputs is a tuple of (?, 2, 10) and (?, 2, 5) tensors, which are all of the LSTM states and outputs, respectively. Note that I'm using the "graduated" version of LSTMCell, from the tf.nn.rnn_cell package, not tf.contrib.rnn. Also note state_is_tuple=False to avoid dealing with LSTMStateTuple.
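With state_is_tuple=False, each per-step state is the concatenation [c, h] along the last axis, so the intermediate cell states can be sliced out of the first element of outputs. A short sketch continuing the test above (the slicing convention, c before h, follows the TF 1.x LSTMCell implementation; variable names are mine):
# Inside the session from the test:
states_val, h_val = sess.run(outputs, feed_dict={X: X_batch})
c_states = states_val[:, :, :n_neurons]  # intermediate cell states c, shape (batch, n_steps, n_neurons)
h_states = states_val[:, :, n_neurons:]  # hidden states h; matches h_val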
Answer 2:
Based on Maxim's idea, I ended up with the following solution:
import tensorflow as tf

LSTMCell = tf.nn.rnn_cell.LSTMCell  # base class to extend

class StatefulLSTMCell(LSTMCell):
    def __init__(self, *args, **kwargs):
        super(StatefulLSTMCell, self).__init__(*args, **kwargs)

    @property
    def output_size(self):
        return (self.state_size, super(StatefulLSTMCell, self).output_size)

    def call(self, input, state):
        output, next_state = super(StatefulLSTMCell, self).call(input, state)
        emit_output = (next_state, output)
        return emit_output, next_state
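For completeness, a usage sketch under the same TF 1.x assumptions as the test in Answer 1 (the placeholder and sizes are illustrative):
X = tf.placeholder(dtype=tf.float32, shape=[None, 2, 3])
cell = StatefulLSTMCell(num_units=5, state_is_tuple=False)
(step_states, step_h), final_state = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)
# step_states: (?, 2, 10) per-step [c, h]; the intermediate c is step_states[:, :, :5]
# step_h:      (?, 2, 5)  per-step hidden states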
Source: https://stackoverflow.com/questions/47745027/tensorflow-how-to-obtain-intermediate-cell-states-c-from-lstmcell-using-dynam