tf.nn.dynamic_rnn的输出outputs和state含义

好久不见. 提交于 2019-12-16 21:53:45

tf.nn.dynamic_rnn的输出outputs和state含义

一、 tf.nn.dynamic_rnn的输出
tf.nn.dynamic_rnn(
    cell,
    inputs,
    sequence_length=None,
    initial_state=None,
    dtype=None,
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)

一、 tf.nn.dynamic_rnn的输出

tf.nn.dynamic_rnn的输入参数如下

  1. tf.nn.dynamic_rnn(
  2.     cell,
  3.     inputs,
  4.     sequence_length=None,
  5.     initial_state=None,
  6.     dtype=None,
  7.     parallel_iterations=None,
  8.     swap_memory=False,
  9.     time_major=False,
  10.     scope=None
  11. )

 

 tf.nn.dynamic_rnn的返回值有两个:outputs和state

为了描述输出的形状,先介绍几个变量,batch_size是输入的这批数据的数量,max_time就是这批数据中序列的最长长度,如果输入的三个句子,那max_time对应的就是最长句子的单词数量,cell.output_size其实就是rnn cell中神经元的个数。

  • outputs. outputs是一个tensor
    <ul><li>如果time_major==True,outputs形状为 [max_time, batch_size, cell.output_size&nbsp;](要求rnn输入与rnn输出形状保持一致)</li>
    	<li>如果time_major==False(默认),outputs形状为 [ batch_size, max_time, cell.output_size&nbsp;]</li>
    </ul></li>
    <li><strong>state. </strong>state是一个tensor。state是最终的状态,也就是序列中最后一个cell输出的状态。一般情况下state的形状为 [batch_size,&nbsp;cell.output_size ],但当输入的cell为BasicLSTMCell时,state的形状为[2,batch_size,&nbsp;cell.output_size ],其中2也对应着LSTM中的cell state和hidden state</li>
    

那为什么state输出形状会有变化呢?state和output又有什么关系呢?

二、state含义

对于第一问题“state”形状为什么会发生变化呢?

我们以LSTM和GRU分别为tf.nn.dynamic_rnn的输入cell类型为例,当cell为LSTM,state形状为[2,batch_size, cell.output_size ];当cell为GRU时,state形状为[batch_size, cell.output_size ]。其原因是因为LSTM和GRU的结构本身不同,如下面两个图所示,这是LSTM的cell结构,每个cell会有两个输出:

 和 

,上面这个图是输出

,代表哪些信息应该被记住哪些应该被遗忘; 下面这个图是输出

,代表这个cell的最终输出,LSTM的state是由

 和 

组成的。

当cell为GRU时,state就只有一个了,原因是GRU将

 和 

进行了简化,将其合并成了

,如下图所示,GRU将遗忘门和输入门合并成了更新门,另外cell不在有细胞状态cell state,只有hidden state。

对于第二个问题outputs和state有什么关系?

结论上来说,如果cell为LSTM,那 state是个tuple,分别代表

 和 

,其中 

与outputs中的对应的最后一个时刻的输出相等,假设state形状为[ 2,batch_size, cell.output_size ],outputs形状为 [ batch_size, max_time, cell.output_size ],那么state[ 1, batch_size, : ] == outputs[ batch_size, -1, : ];如果cell为GRU,那么同理,state其实就是 

,state ==outputs[ -1 ]

 

三、实验

我们写点代码来具体感觉下outputs和state是什么,代码如下

  1. import tensorflow as tf
  2. import numpy as np
  3. def dynamic_rnn(rnn_type='lstm'):
  4. # 创建输入数据,3代表batch size,6代表输入序列的最大步长(max time),8代表每个序列的维度
  5. X = np.random.randn(3, 6, 4)
  6. # 第二个输入的实际长度为4
  7. X[1, 4:] = 0
  8. #记录三个输入的实际步长
  9. X_lengths = [6, 4, 6]
  10. rnn_hidden_size = 5
  11. if rnn_type == 'lstm':
  12. cell = tf.contrib.rnn.BasicLSTMCell(num_units=rnn_hidden_size, state_is_tuple=True)
  13. else:
  14. cell = tf.contrib.rnn.GRUCell(num_units=rnn_hidden_size)
  15. outputs, last_states = tf.nn.dynamic_rnn(
  16. cell=cell,
  17. dtype=tf.float64,
  18. sequence_length=X_lengths,
  19. inputs=X)
  20. with tf.Session() as session:
  21. session.run(tf.global_variables_initializer())
  22. o1, s1 = session.run([outputs, last_states])
  23. print(np.shape(o1))
  24. print(o1)
  25. print(np.shape(s1))
  26. print(s1)
  27. if __name__ == '__main__':
  28. dynamic_rnn(rnn_type='lstm')

实验一:cell类型为LSTM,我们看看输出是什么样子,如下图所示,输入的形状为 [ 3, 6, 4 ],经过tf.nn.dynamic_rnn后outputs的形状为 [ 3, 6, 5 ],state形状为 [ 2, 3, 5 ],其中state第一部分为c,代表cell state;第二部分为h,代表hidden state。可以看到hidden state 与 对应的outputs的最后一行是相等的。另外需要注意的是输入一共有三个序列,但第二个序列的长度只有4,可以看到outputs中对应的两行值都为0,所以hidden state对应的是最后一个不为0的部分。tf.nn.dynamic_rnn通过设置sequence_length来实现这一逻辑。

  1. (3, 6, 5)
  2. [[[ 0.0146346 -0.04717453 -0.06930042 -0.06065602 0.02456717]
  3. [-0.05580321 0.08770171 -0.04574306 -0.01652854 -0.04319528]
  4. [ 0.09087799 0.03535907 -0.06974291 -0.03757408 -0.15553619]
  5. [ 0.10003044 0.10654698 0.21004055 0.13792148 -0.05587583]
  6. [ 0.13547596 -0.014292 -0.0211154 -0.10857875 0.04461256]
  7. [ 0.00417564 -0.01985144 0.00050634 -0.13238986 0.14323784]]
  8. [[ 0.04893576 0.14289175 0.17957205 0.09093887 -0.0507192 ]
  9. [ 0.17696126 0.09929577 0.21185635 0.20386451 0.11664373]
  10. [ 0.15658667 0.03952745 -0.03425637 0.00773833 -0.03546742]
  11. [-0.14002582 -0.18578786 -0.08373584 -0.25964601 0.04090167]
  12. [ 0. 0. 0. 0. 0. ]
  13. [ 0. 0. 0. 0. 0. ]]
  14. [[ 0.18564152 0.01531695 0.13752453 0.17188506 0.19555427]
  15. [ 0.13703949 0.14272294 0.21313036 0.07417354 0.0477547 ]
  16. [ 0.23021792 0.04455495 0.10204565 0.17159792 0.34148467]
  17. [ 0.0386402 0.0387848 0.02134559 0.00110381 0.08414687]
  18. [ 0.01386241 -0.02629686 -0.0733538 -0.03194245 0.13606553]
  19. [ 0.01859433 -0.00585316 -0.04007138 0.03811594 0.21708331]]]
  20. (2, 3, 5)
  21. LSTMStateTuple(c=array([[ 0.00909146, -0.03747076, 0.0008946 , -0.23459786, 0.29565899],
  22. [-0.18409266, -0.30463044, -0.28033809, -0.49032542, 0.12597639],
  23. [ 0.04494702, -0.01359631, -0.06706629, 0.06766361, 0.40794032]]), h=array([[ 0.00417564, -0.01985144, 0.00050634, -0.13238986, 0.14323784],
  24. [-0.14002582, -0.18578786, -0.08373584, -0.25964601, 0.04090167],
  25. [ 0.01859433, -0.00585316, -0.04007138, 0.03811594, 0.21708331]]))

实验二:cell类型为GRU,我们看看输出是什么样子,如下图所示,输入的形状为 [ 3, 6, 4 ],经过tf.nn.dynamic_rnn后outputs的形状为 [ 3, 6, 5 ],state形状为 [ 3, 5 ]。可以看到 state 与 对应的outputs的最后一行是相等的。

  1. (3, 6, 5)
  2. [[[-0.05190962 -0.13519617 0.02045928 -0.0821183 0.28337528]
  3. [ 0.0201574 0.03779418 -0.05092804 0.02958051 0.12232347]
  4. [ 0.14884441 -0.26075898 0.1821795 -0.03454954 0.18424161]
  5. [-0.13854156 -0.26565378 0.09567164 -0.03960079 0.14000589]
  6. [-0.2605973 -0.39901657 0.12495693 -0.19295695 0.52423598]
  7. [-0.21596414 -0.63051687 0.20837501 -0.31775378 0.77519457]]
  8. [[-0.1979659 -0.30253523 0.0248779 -0.17981144 0.41815343]
  9. [ 0.34481129 -0.05256187 0.1643036 0.00739746 0.27384158]
  10. [ 0.49703664 0.22241165 0.27344766 0.00093435 0.09854949]
  11. [ 0.23312444 0.156997 0.25482553 0.0138156 -0.02302272]
  12. [ 0. 0. 0. 0. 0. ]
  13. [ 0. 0. 0. 0. 0. ]]
  14. [[-0.06401732 0.08605342 -0.03936866 -0.02287695 0.16947652]
  15. [-0.1775206 -0.2801672 -0.0387468 -0.20264583 0.58125297]
  16. [ 0.39408762 -0.44066425 0.25826641 -0.18851604 0.36172166]
  17. [ 0.0536013 -0.29902928 0.08891931 -0.03930039 0.0743423 ]
  18. [ 0.02304702 -0.0612499 0.09113458 -0.05169013 0.29876455]
  19. [-0.06711324 0.014125 -0.05856332 -0.05632359 -0.00390189]]]
  20. (3, 5)
  21. [[-0.21596414 -0.63051687 0.20837501 -0.31775378 0.77519457]
  22. [ 0.23312444 0.156997 0.25482553 0.0138156 -0.02302272]
  23. [-0.06711324 0.014125 -0.05856332 -0.05632359 -0.00390189]]

 

觉得原作者总结的很好,就转载过来了。

Reference

http://colah.github.io/posts/2015-08-Understanding-LSTMs/

https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn

一、 tf.nn.dynamic_rnn的输出

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!