How to extract the cell state and hidden state from an RNN model in tensorflow?

Submitted by 。_饼干妹妹 on 2019-12-03 08:19:39

You may simply collect the values of the states in the same way accuracy is collected.

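I guess pred_val, states_val, acc = sess.run([pred, states, accuracy], feed_dict={x: batch_x, y: batch_y}) should work perfectly fine. Note that the fetches have to be passed to sess.run as a single list, and the results are bound to new names so that the graph tensors pred and states are not overwritten.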

One comment about your assumption: "states" holds only the values of the "hidden state" and the "memory cell" from the last time step.

The "outputs" contain the "hidden state" from each time step (the size of outputs is [batch_size, seq_len, hidden_size]). So I assume that you want the "outputs" variable, not "states". See the documentation.

I have to disagree with the answer of user3480922. For the code:

outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)

to extract the hidden state for each time_step in a prediction, you have to use outputs, because outputs holds the hidden state value for each time_step. However, I am not sure whether there is any way to store the values of the cell state for each time_step as well, because the states tuple provides the cell state values only for the last time_step.
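One way to see what "cell state for each time_step" would involve is to step an LSTM by hand and record both h and c at every step. The following is a minimal sketch in plain NumPy, not TensorFlow code: the gate ordering, the weight names W and b, and the helper names lstm_step and run_lstm are all assumptions made for illustration, and they do not reflect how any TensorFlow cell lays out its weights. The sizes (5 time steps, batch of 3, 4 hidden units) are chosen to match the sample in this answer.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, b):
    """One LSTM step. x_t: [batch, input]; h_prev, c_prev: [batch, hidden]."""
    hidden = h_prev.shape[1]
    # A single matmul yields the pre-activations of all four gates.
    z = np.concatenate([x_t, h_prev], axis=1) @ W + b
    i = sigmoid(z[:, 0 * hidden:1 * hidden])  # input gate (assumed ordering)
    f = sigmoid(z[:, 1 * hidden:2 * hidden])  # forget gate
    o = sigmoid(z[:, 2 * hidden:3 * hidden])  # output gate
    g = np.tanh(z[:, 3 * hidden:4 * hidden])  # candidate cell values
    c = f * c_prev + i * g                    # new cell state
    h = o * np.tanh(c)                        # new hidden state
    return h, c

def run_lstm(x, W, b, hidden):
    """x is time-major: [seq_len, batch, input]."""
    seq_len, batch, _ = x.shape
    h = np.zeros((batch, hidden))
    c = np.zeros((batch, hidden))
    all_h, all_c = [], []
    for t in range(seq_len):
        h, c = lstm_step(x[t], h, c, W, b)
        all_h.append(h)  # hidden state at step t (what "outputs" holds)
        all_c.append(c)  # cell state at step t (not returned by "states")
    return np.stack(all_h), np.stack(all_c), (c, h)

rng = np.random.default_rng(0)
seq_len, batch, n_in, hidden = 5, 3, 2, 4
x = rng.normal(size=(seq_len, batch, n_in))
W = rng.normal(scale=0.1, size=(n_in + hidden, 4 * hidden))
b = np.zeros(4 * hidden)

outputs, cell_states, (c_last, h_last) = run_lstm(x, W, b, hidden)
print(outputs.shape, cell_states.shape)
```

Here outputs[-1] equals h_last and cell_states[-1] equals c_last, which mirrors the relationship between outputs and the states tuple described above. The per-step cell states only exist because the loop keeps them; a final-state tuple alone cannot recover them.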

For example, in the following sample with 5 time_steps, outputs[t,:,:] for t = 0,...,4 holds the hidden state values for time_step=t, whereas the h of the states tuple holds the hidden state values only for time_step=4, so it is equal to outputs[4,:,:]. The c of the states tuple holds the cell state values, again only for time_step=4.

  outputs = [[[ 0.0589103 -0.06925126 -0.01531546 0.06108122]
              [ 0.00861215 0.06067181 0.03790079 -0.04296958]
              [ 0.00597713 0.03916606 0.02355802 -0.0277683 ]]

             [[ 0.06252582 -0.07336216 -0.01607122 0.05024602]
              [ 0.05464711 0.03219429 0.06635305 0.00753127]
              [ 0.05385715 0.01259535 0.0524035 0.01696803]]

             [[ 0.0853352 -0.06414541 0.02524283 0.05798233]
              [ 0.10790729 -0.05008117 0.03003334 0.07391824]
              [ 0.10205664 -0.04479517 0.03844892 0.0693808 ]]

             [[ 0.10556188 0.0516542 0.09162509 -0.02726674]
              [ 0.11425048 -0.00211394 0.06025286 0.03575509]
              [ 0.11338984 0.02839304 0.08105748 0.01564003]]

             [[ 0.10072514 0.14767936 0.12387902 -0.07391471]
              [ 0.10510238 0.06321315 0.08100517 -0.00940042]
              [ 0.10553667 0.0984127 0.10094948 -0.02546882]]]

  states = LSTMStateTuple(c=array([[ 0.23870754, 0.24315512, 0.20842518, -0.12798975],
                                   [ 0.23749796, 0.10797793, 0.14181322, -0.01695861],
                                   [ 0.2413336 , 0.16692916, 0.17559692, -0.0453596 ]], dtype=float32),
                          h=array([[ 0.10072514, 0.14767936, 0.12387902, -0.07391471],
                                   [ 0.10510238, 0.06321315, 0.08100517, -0.00940042],
                                   [ 0.10553667, 0.0984127 , 0.10094948, -0.02546882]], dtype=float32))
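The match between the last block of outputs and the h of the states tuple can be checked numerically. The sketch below copies the values for time_step=4 from the printed sample into NumPy arrays; the variable names outputs_last, state_h and state_c are made up for this illustration and do not come from the original code.

```python
import numpy as np

# Last time step of outputs (time_step=4), copied from the sample.
outputs_last = np.array([
    [0.10072514, 0.14767936, 0.12387902, -0.07391471],
    [0.10510238, 0.06321315, 0.08100517, -0.00940042],
    [0.10553667, 0.0984127, 0.10094948, -0.02546882],
], dtype=np.float32)

# h and c of the printed LSTMStateTuple.
state_h = np.array([
    [0.10072514, 0.14767936, 0.12387902, -0.07391471],
    [0.10510238, 0.06321315, 0.08100517, -0.00940042],
    [0.10553667, 0.0984127, 0.10094948, -0.02546882],
], dtype=np.float32)
state_c = np.array([
    [0.23870754, 0.24315512, 0.20842518, -0.12798975],
    [0.23749796, 0.10797793, 0.14181322, -0.01695861],
    [0.2413336, 0.16692916, 0.17559692, -0.0453596],
], dtype=np.float32)

h_matches = np.allclose(outputs_last, state_h)  # last output vs. final hidden state
c_matches = np.allclose(outputs_last, state_c)  # last output vs. final cell state
print(h_matches, c_matches)  # prints: True False
```

So the final hidden state is already available as the last entry of outputs, while the cell state in the states tuple is a separate quantity that outputs does not contain.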