1. 简介
LSTM模型作为一种经典的RNN网络结构,常用于NLP任务当中。在本篇工作中,我们进一步拓展了原始LSTM模型。注意到原始LSTM中输入x和之前状态h_prev是完全独立的,可能导致上下文信息的流失。我们提出一种形变LSTM,将输入x和之前状态h_prev进行交互,再输入进各个门里面运算。最后实验表明,改进过后的Mogrifier LSTM在各项任务均优于传统LSTM
2. 回顾传统LSTM
LSTM模型结构如下所示

它一共有4个门控系统,分别是遗忘门,输入门,候选记忆细胞,输出门
各个门的计算公式如下
-
遗忘门: -
输入门: -
候选记忆细胞: -
输出门: -
记忆细胞: -
新一轮的隐藏状态:
其中 代表的是sigmoid运算
各个门作用及机理如下
-
遗忘门: 主要控制是否遗忘上一层的记忆细胞状态, 输入分别是 当前时间步序列数据,上一时间步的隐藏状态,进行矩阵相乘,经sigmoid激活后,获得一个 值域在[0, 1]的输出F,再跟上一层记忆细胞进行对应元素相乘,输出F中越接近0,代表需要遗忘上层记忆细胞的元素。 -
候选记忆细胞:这里的区别在于将sigmoid函数换成tanh激活函数,因此输出的值域在[-1, 1]。 -
输入门:与遗忘门类似,也是经过sigmoid激活后,获得一个值域在[0, 1]的输出。 它用于控制当前输入X经过候选记忆细胞如何流入当前时间步的记忆细胞。 如果输入门输出接近为0,而遗忘门接近为1,则当前记忆细胞一直保存过去状态 -
输出门:也是通过sigmoid激活,获得一个值域在[0,1]的输出。主要控制记忆细胞到下一时间步隐藏状态的信息流动
相较于传统的RNN,LSTM引入了门机制,记忆细胞的设计使其能保存一定信息,在时间步进行传递,更好地捕捉时间序列较长的信息。而遗忘门的设计,更是能判断上一时刻信息是否对当前时刻产生影响,进而优化梯度流在整个网路的传递
但我们可以注意到,作为各个门的输入,X和隐藏状态H是完全独立的,这也是该研究的动机,如果输入前我能让X和隐藏状态H做交互,那性能是否能得到提升?
3. Mogrifier LSTM
Mogrifier LSTM引入以下两个公式


为了分别交互X 和 H,作者额外设置了两个矩阵Q,R
并且设定了一个超参数i,该参数分别控制X和H应该如何进行交互计算
当 ,整个模型就退化成原始的LSTM
最后乘以一个常数2,这是因为经过sigmoid运算后,其值分布在(0, 1),这样反复乘下去,值是会越来越小的。因此乘以一个2保证其数值的稳定性。
4. 实验
我们来简单看下实验结果

经过简单的改进,Mogrifier LSTM在各数据集上的表现均好于传统的LSTM
此外作者还探索了Mogrify中的超参数 设置,对模型性能的影响

文中也对Mogrify这种交互方式给了相应的示意图

5. 代码解析
作者也开源了相关代码在github上:https://github.com/RMichaelSwan/MogrifierLSTM
class MogLSTM(nn.Module):
def __init__(self, input_sz, hidden_sz, mog_iteration):
super(MogLSTM, self).__init__()
self.input_size = input_sz
self.hidden_size = hidden_sz
self.mog_iterations = mog_iteration
# 这里hiddensz乘4,是将四个门的张量运算都合并到一个矩阵当中,后续再通过张量分块给每个门
self.Wih = Parameter(torch.Tensor(input_sz, hidden_sz*4))
self.Whh = Parameter(torch.Tensor(hidden_sz, hidden_sz*4))
self.bih = Parameter(torch.Tensor(hidden_sz*4))
self.bhh = Parameter(torch.Tensor(hidden_sz*4))
# Mogrifiers
self.Q = Parameter(torch.Tensor(hidden_sz, input_sz))
self.R = Parameter(torch.Tensor(input_sz, hidden_sz))
self.init_weights()
def init_weights(self):
"""
权重初始化,对于W,Q,R使用xavier
对于偏置b则使用0初始化
:return:
"""
for p in self.parameters():
if p.data.ndimension() >= 2:
nn.init.xavier_uniform_(p.data)
else:
nn.init.zeros_(p.data)
def mogrify(self, xt, ht):
"""
计算mogrify
:param xt:
:param ht:
:return:
"""
for i in range(1, self.mog_iterations+1):
if(i % 2 == 0):
ht = (2*torch.sigmoid(xt @ self.R)*ht)
else:
xt = (2*torch.sigmoid(ht @ self.Q)*xt)
return xt, ht
def forward(self, x:torch.Tensor, init_states:Optional[Tuple[torch.Tensor, torch.Tensor]]=None) -> \
Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
batch_sz, seq_sz, _ = x.size()
hidden_seq = []
if init_states is None:
ht = torch.zeros((batch_sz, self.hidden_size)).to(x.device)
Ct = torch.zeros((batch_sz, self.hidden_size)).to(x.device)
else:
ht, Ct = init_states
for t in range(seq_sz):
xt = x[:, t, :]
xt, ht = self.mogrify(xt, ht)
gates = (xt @ self.Wih + self.bih) + (ht @ self.Whh + self.bhh)
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) # chunk方法将tensor分块
# LSTM
ft = torch.sigmoid(forgetgate)
it = torch.sigmoid(ingate)
Ct_candidate = torch.tanh(cellgate)
ot = torch.sigmoid(outgate)
# outputs
Ct = (ft*Ct) + (it*Ct_candidate)
ht = ot * torch.tanh(Ct)
hidden_seq.append(ht.unsqueeze(Dim.batch)) # unsqueeze是给指定位置加上维数为1的维度
hidden_seq = torch.cat(hidden_seq, dim=Dim.batch)
hidden_seq = hidden_seq.transpose(Dim.batch, Dim.seq).contiguous()
return hidden_seq, (ht, Ct)
-
首先输入分别表示 输入维度,隐层维度,Mogrify的计算次数(也就是前面提到的超参数i) -
然后分别初始化 权重Wih, Whh,Bih,Bhh。注意这里要乘4,这是因为LSTM里面有4个门,它将其合并为一个矩阵运算,最后再分配给4个门,提高速度 -
同样也是随机初始化用于Mogrify计算的两个矩阵Q ,R -
init_weights是进行参数初始化 -
方法mogrify里面,就是mogrify计算的部分了,根据计算次数设定一个for循环,根据奇偶性,分别对X和H进行交互计算,并返回 -
在forward前向计算中,先对隐层和输入X做初始化。然后进行矩阵运算,通过chunks方法将张量分成4部分,分别给四个门,再根据我们的前面的公式,分别进行sigmoid和tanh计算。然后更新细胞状态和隐藏状态。将隐藏状态连结成序列,最终返回隐藏状态序列,隐藏状态和细胞状态
6. 总结
本文的动机还是比较朴素的,从现有的LSTM缺陷出发,创新的引出mogrify计算方式,将原本互相独立的X和H进行了交互运算。并通过实验探讨了超参数的设置,最后的实验也表明,改造过后的Mogrifier LSTM相较于传统LSTM有着不小的提升
欢迎关注GiantPandaCV, 在这里你将看到独家的深度学习分享,坚持原创,每天分享我们学习到的新鲜知识。( • ̀ω•́ )✧
有对文章相关的问题,或者想要加入交流群,欢迎添加BBuf微信:

为了方便读者获取资料以及我们公众号的作者发布一些Github工程的更新,我们成立了一个QQ群,二维码如下,感兴趣可以加入。

本文分享自微信公众号 - GiantPandaCV(BBuf233)。
如有侵权,请联系 support@oschina.cn 删除。
本文参与“OSC源创计划”,欢迎正在阅读的你也加入,一起分享。
来源:oschina
链接:https://my.oschina.net/u/4580321/blog/4428669