参数
Input
- encoder_inputs:encoder的输入,int32型 id tensor list
- decoder_inputs:decoder的输入,int32型id tensor list
- cell: RNN_Cell的实例
- num_encoder_symbols, num_decoder_symbols: 分别是编码和解码的符号数,即词表大小
- embedding_size: 词向量的维度
- num_heads:attention的抽头数量,一个抽头算一种加权求和方式,后面会进一步介绍
- output_projection:decoder的output向量投影到词表空间时,用到的投影矩阵和偏置项(W, B);W的shape是[output_size, num_decoder_symbols],B的shape是[num_decoder_symbols];若此参数存在且feed_previous=True,上一个decoder的输出先乘W再加上B作为下一个decoder的输入
- feed_previous:若为True, 只有第一个decoder的输入(“GO"符号)有用,所有的decoder输入都依赖于上一步的输出;一般在测试时用(当然源码也提到,可以在训练时用于模拟测试的环境,比如Scheduled Sampling)
- initial_state_attention: 默认为False, 初始的attention是零;若为True,将从initial state和attention states开始attention
Output
- (outputs, state) tuple pair,outputs是 2D Tensors list, 每个Tensor的shape是[batch_size, cell.state_size];state是 最后一个时间步,decoder cell的state,shape是[batch_size, cell.state_size]
Encoder
- 创建了一个embedding matrix.
- 计算encoder的output和state
- 生成attention states,用于计算attention
文章来源: Attention Seq2Seq模型