Scaled Dot-Product Attention(transformer)
Scaled Dot-Product Attention是transformer的encoder的multi-head attention的组成部分。 由于Scaled Dot-Product Attention是multi-head的构成部分,因此Scaled Dot-Product Attention的数据的输入q,k,v的shape通常我们会变化为如下: (batch, n_head, seqLen, dim) 其中n_head表示multi-head的个数,且n_head*dim = embedSize 整个输入到输出,数据的维度保持不变。 temperature表示Scaled,即dim**0.5 mask表示每个batch对应样本中如果sequence为pad,则对应的mask为False,因此mask的初始维度为(batchSize, seqLen),为了计算,mask的维度会扩充为(batchSize, 1, 1, seqLen)。 class ScaledDotProductAttention(nn.Module): ''' Scaled Dot-Product Attention ''' def __init__(self, temperature, attn_dropout=0.1): super().__init__() self.temperature =