Generative Adversarial Nets[AAE]

我只是一个虾纸丫 提交于 2020-12-16 10:24:50

本文来自《Adversarial Autoencoders》,时间线为2015年11月。是大神Goodfellow的作品。本文还有些部分未能理解完全,不过代码在AAE_LabelInfo,这里实现了文中2.3小节,当然实现上有点差别,其中one-hot并不是11个类别,只是10个类别。

本文提出“对抗自动编码器(AAE)”,其本质上是自动编码器和GAN架构的合体,通过将AE隐藏层编码向量的聚合后验与任意先验分布进行匹配完成变分推论(variational inference)。将聚合后验与先验进行匹配确保从该先验任何部分都能够生成有意义的样本。AAE的解码层可以看成是一个深度生成模型,可以将强加的先验映射到数据分布上。本文并介绍如何将AAE用在如半监督分类,图像分类,无监督聚类,维度约间和数据可视化。 本文主要是介绍了几种AAE的应用:

  • Basic AAE (文中2到2.1之间的部分)
  • Incorporatiing Label Information in the Adversarial Regularization (文中2.3小节)
  • Supervised AAE (文中4小节)
  • Semi-supervised AAE (文中5小节)
  • Unsupervised Clustering with AAE (文中6小节)
  • Dimensionality Reduction with AAE (文中7小节)

0 引言

构建一个可伸缩的生成模型,能够抓取如语音,图像,视频等分布是ML中一个核心问题。近些年的模型如RBM,DBN,DBM都是通过MCMC算法进行训练的,而MCMC算法是通过计算log-似然的梯度去完成的,该方法在训练阶段并不实用,因为从马尔可夫链中采样的样本无法在模式之间快速混合。近些年,生成模型主要都是通过BP进行训练,避免MCMC的训练困难。如变分自动编码器(variational autoencoders,VAE)或者是重要性权重自动编码器(importance weighted autoencoders)都是采用一个识别网络(基于潜在变量基础上)去预测后验概率。GAN使用对抗训练过程直接塑造网络的输出,如生成时刻匹配网络(generative moment matching networks,GMMN)使用一个时刻匹配损失函数去学习数据的分布。

本文提出一个通用性算法,对抗自动编码器,将一个自动编码器变成生成模型。本文中AE通过两个目标函数进行训练,一个是传统的重构误差函数,另一个是对抗训练函数,意在将AE隐藏层向量表示的聚合后验分布与任意的先验分布进行匹配。可以发现这个训练准则和VAE有很强的联系。训练的结果是:

  • 编码器学到将数据分布转换成该先验分布;
  • 解码器学到一个深度生成模型,可以将强加的先验映射到数据分布上。
    '''update AE network''' 
    _, loss_likehood = sess.run([ae_optim, neg_marginal_likelihood], feed_dict=feed_dict_input)
    '''update discriminator network'''
    _, d_loss = sess.run([d_optim, D_loss], feed_dict=feed_dict_input)
    '''update generator network, run 2 times'''
    _, g_loss = sess.run([g_optim, G_loss], feed_dict=feed_dict_input)
    _, g_loss = sess.run([g_optim, G_loss], feed_dict=feed_dict_input)

1 对抗自动编码器(Adversarial Autoencoders,AAE)

假设$\mathbf{x}$是输入向量,$\mathbf{z}$是AE的隐藏层编码向量。令$p(\mathbf{z})$表示想要加在编码上的先验分布,$q(\mathbf{z}|\mathbf{x})$是一个编码分布,$p(\mathbf{x}|\mathbf{z})$是一个解码分布。同时,令$p_d(\mathbf{x})$表示数据分布,$p(\mathbf{x})$表示模型分布。基于AE的编码函数 $q(\mathbf{z}|\mathbf{x})$,定义$q(\mathbf{z})$的聚合后验分布如下: $$q(\mathbf{z})=\int_{\mathbf{x}}q(\mathbf{z}|\mathbf{x})p_d(\mathbf{x})d\mathbf{x} \tag{1}$$ AAE是在AE基础上,通过将聚合后验$q(\mathbf{z})$与任意先验$p(\mathbf{z})$进行匹配来完成正则化。为了完成这样的目标,对抗网络与AE的隐藏编码向量相关联,如图1.

<center/>![](https://img2018.cnblogs.com/blog/441382/201812/441382-20181227190200967-1197417573.png) ![](https://img2018.cnblogs.com/blog/441382/201812/441382-20181227190042670-269711675.png)</center> 让对抗网络指导$q(\mathbf{z})$去匹配$p(\mathbf{z})$。同时该AE也尝试最小化重构误差。该对抗网络的生成器同时也是AE的编码器 $q(\mathbf{z}|\mathbf{x})$。该编码器确保聚合的后验分布可以愚弄对抗网络的判别器,让其误认为隐藏编码$q(\mathbf{z})$来自真实先验分布$p(\mathbf{z})$。

对抗网络和AE是通过SGD基于两个阶段联合训练的:基于mini-batch执行重构阶段和正则阶段

  • 在重构阶段,AE更新编码器和解码器,并最小化输入的重构误差;
  • 在正则阶段,对抗网络首先更新判别网络,以区分真实样本(使用先验生成)和生成样本(通过AE计算隐藏编码);然后,对抗网络更新生成器(AE的编码器)去混乱判别器。

一旦训练阶段完成了,AE的解码器就可以看成是一个生成模型,可以将强加的先验$p(\mathbf{z})$映射回数据分布上。 关于AE的编码器 $q(\mathbf{z}|\mathbf{x})$,有几种可能的选择:

确定性(Deterministic) 假设$q(\mathbf{z}|\mathbf{x})$是$\mathbf{x}$的确定性函数。这种情况下,编码器就类似标准AE的编码器,$q(\mathbf{z})$中的随机来源就是数据分布$q_d(\mathbf{x})$。

高斯后验(Gaussian posterior) 假设$q(\mathbf{z}|\mathbf{x})$是一个高斯分布,其均值和方差是通过编码网络预测的:$z_i\sim \mathcal{N}(\mu_i(\mathbf{x}),\sigma_i(\mathbf{x}))$。在这种情况下,$q(\mathbf{z})$中的随机性同时来自数据分布和编码器输出的高斯分布随机性。在网络的BP过程中,可以使用《Auto-encoding variational bayes》同样的重新参数技巧。

通用近似后验(Universal approximator posterior) AAE可以训练$q(\mathbf{z}|\mathbf{x})$成通用近似后验。假设AAE的编码网络是函数$f(\mathbf{x},\eta )$,其输入是$\mathbf{x}$和一个固定分布(如高斯)的随机噪音$\eta$。通过在$\eta$的不同样本上评估$f(\mathbf{x},\eta )$,从而从任意的后验分布$q(\mathbf{z}|\mathbf{x})$中进行采样。换句话说,假设$q(\mathbf{z}|\mathbf{x},\eta)=\delta (\mathbf{z}-f(\mathbf{x},\eta))$,那么后验$q(\mathbf{z}|\mathbf{x})$和聚合后验$q(\mathbf{z})$定义如下: $$q(\mathbf{z}|\mathbf{x})=\int_{\eta}q(\mathbf{z}|\mathbf{x},\eta)p_{\eta}(\eta)d\eta\Rightarrow q(\mathbf{z})=\int_{\mathbf{x}}\int_{\eta}q(\mathbf{z}|\mathbf{x},\eta)p_d(\mathbf{x})p_{\eta}(\eta)d\eta d\mathbf{x} $$ 在该情况下,$q(\mathbf{z})$的随机性同时来自数据分布和编码器输入上的随机噪音$\eta$。注意到该情况中后验分布$q(\mathbf{z}|\mathbf{x})$不再是受限于高斯,且编码器可以基于给定输入$\mathbf{x}$学到任意的后验分布。因为从聚合后验$q(\mathbf{z})$上采样是一个高效的方法,对抗训练过程可以通过在编码网络$f(\mathbf{x},\eta)$上进行BP,让$q(\mathbf{z})$去匹配$p(\mathbf{z})$。

从上述三种策略,选择不同类型的$q(\mathbf{z}|\mathbf{x})$可以生成不同类型的模型。例如,在$q(\mathbf{z}|\mathbf{x})$的确定情况中,网络只能让$q(\mathbf{z})$去匹配$p(\mathbf{z})$,此时只利用了数据分布的随机性。但是因为数据的经验性分布是被训练集固定的,映射是确定的,这可能生成一个不是很平滑的$q(\mathbf{z})$;然而,在高斯或者通用近似情况中,网络需要额外的随机性来源,以帮助在对抗正则阶段中对$q(\mathbf{z})$进行平滑惩罚。 然而,在多次试验后,作者发现每个$q(\mathbf{z}|\mathbf{x})$策略上得到结果大同小异。所以在剩下部分中,只介绍$q(\mathbf{z}|\mathbf{x})$的确定性策略。

1.1 与VAE的关系

本文的想法类似《Auto-encoding variational bayes》中变分自动编码器(variational autoencoders,VAE),然而他们使用的是KL散度惩罚的方法在隐藏层编码向量上强加一个先验分布,本文使用的是对抗训练方法去实现该目的,即让隐藏层编码向量的聚合后验能够匹配先验分布。VAE是最小化关于$\mathbf{x}$的负log似然上边界

<center/>![](https://img2018.cnblogs.com/blog/441382/201812/441382-20181228135833595-915372434.png)</center> 这里聚合后验$q(\mathbf{z})$定义与式子1中一样,假设$q(\mathbf{z}|\mathbf{x})$是高斯分布,$p(\mathbf{z})$是任意分布。变分边界包含三个部分:第一项可以被认为是AE的重构项,第二项和第三项可以看成是正则项。在没有正则项的时候,该模型简单就是个AE。然而在有正则项的时候,VAE学到的隐藏层表征是与$p(\mathbf{z})$兼容的。损失函数的第二项鼓励后验分布有较大变化,而第三项是最小化聚合后验$q(\mathbf{z})$和先验$p(\mathbf{z})$之间的交叉熵。式子2中KL散度或者交叉熵鼓励$q(\mathbf{z})$能与$p(\mathbf{z})$相匹配。而在AAE中,作者将后面两项替换成一个对抗学习的过程,从而鼓励$q(\mathbf{z})$能与$p(\mathbf{z})$整个分布相匹配。

该部分中,将AAE与VAE在编码分布$p(\mathbf{z})$上插入特定先验的能力做对比。

<center/>![](https://img2018.cnblogs.com/blog/441382/201812/441382-20181228140142463-1222595479.png)</center> 如图2a,展示的是在测试数据上的2维编码空间$\mathbf{z}$,基于MNIST数据集上训练AAE,并在隐藏层编码$\mathbf{z}$上强加一个高斯分布。学到的流行显示不同类别之间明显的过渡,编码空间被填充并且没有空洞存在。实际上,编码空间中明显的过渡指的是在位于数据流行上的$\mathbf{z}$内插值生成的图像(图2e)。图2c显示VAE的编码空间与AAE有相同结构。我们可以发现这种情况下VAE大致是匹配2D的高斯分布形态。然而,没有数据点映射到几个编码空间的局部区域意味着VAE不能如AAE一样很好的抓取数据流行。

图2b和图2d表现的是AAE和AVE的编码空间,其中插入的分布是10个2D高斯混合分布。AAE成功抓取带有先验分布的聚合后验(图2b);而VAE表现出与10个组件高斯混合的强烈差别,即VAE更多强调匹配的分布模式(图2d)。 基于VAE和AAE之间一个重要的差别是在VAE中,为了通过MC采样对KL散度进行BP,需要得到准确的先验分布的函数形式。而在AAE中,只需要能从先验分布中进行采样就能让$q(\mathbf{z})$匹配$p(\mathbf{z})$。后面会介绍AAE还能插入复杂的分布(如swiss roll分布)而并不需要该函数的准确表现形式。

1.2 与GAN和GMMN的关系

1.3 在对抗正则中插入标签信息

在该场景中,数据是标注过的,可以将标签信息插入到对抗训练过程中,以更好的塑造隐藏层编码的分布。在该部分中,介绍如何使用部分或者所有标签信息来更好的正则化AE的潜在表征。为了介绍该结构,先返回图2b,其中AAE是拟合10个成分2维的混合高斯分布。现在让混合高斯中每个成分表示MNIST中每个标签。

<center/>![](https://img2018.cnblogs.com/blog/441382/201812/441382-20181228174351553-251512682.png) ![](https://img2018.cnblogs.com/blog/441382/201812/441382-20181228174357846-486190958.png)</center> 图3是半监督方法的训练过程。这里增加了一个one-hot向量到判别网络的输入部分,以将标签与分布模式相结合。该one-hot向量扮演着(给定类别标签基础上)该判别网络的决策面。该one-hot向量有一个额外的无标签样本类别。例如在图2b和4a中,一个10成分的2D高斯混合模型,one-hot向量有11个类别。前面10个类别对应混合模型中每个独立的决策面。额外的one-hot向量对应无标签训练样本点(如生成器生成的样本)。 <center/>![](https://img2018.cnblogs.com/blog/441382/201812/441382-20181229090536435-922669299.png)</center> 当一个无标签样本点出现在该模型中,额外的类别就会得到响应,以选择整个高斯混合分布的决策面: > * 在对抗训练的正阶段,通过one-hot将高斯混合模型生成的样本的标签传给判别器。这些正样本来自混合高斯模型,而不是某个具体的类别; > * 在对抗训练的负阶段,通过one-hot将生成器生成的样本的标签给判别器。这些负样本来自生成器。

图4a中展现的是基于一个AAE的隐藏层表征,该AAE是基于10k个标记的MNIST样本可40K个无标签的MNIST样本,10个成分的2D高斯混合模型上训练的。此时,先验中第i个混合惩罚以半监督方式与第i个类别相关。图4b展示的是前三个混合成分的流行。注意到每个混合成分的类型表征是很一致的,且与各自的类相独立。例如,图4b中所有的左上区域对应于直立书写样式,右下区域对应于数字的倾斜书写样式。

该方法可以扩展到任意分布而不需要参数控制,如将MNIST数据映射到一个“swiss roll”(如条件高斯分布,其均值是均匀分布的,其长度为一个swiss roll的轴)。图4c是编码空间$\mathbf{z}$的展示,图4d是沿着swiss roll轴前进生成的图像。

2 AAE的似然分析

本节使用《Generative adversarial nets》中描述的评估方法,比较该模型在MNIST和toronto人脸数据集(TFD)上生成图像的能力来测量AAE作为生成模型捕获数据分布的能力。

<center/>![](https://img2018.cnblogs.com/blog/441382/201812/441382-20181229095234653-2022461310.png)</center> 图5展示的就是基于训练好的AAE生成的样本。在[tfd.gif](http://www.comm.utoronto.ca/~makhzani/adv_ae/tfd.gif)这里是学到的TFD流行。为了鉴定模型是否过拟合,在最后一列展现的是以欧式距离计算最近的训练集样本。

通过在测试集上计算AAE的log似然来对其性能进行评估。不过因为使用似然函数不直观,不能直接计算图片的概率,所以这里使用先《Generative adversarial nets》中描述的方法计算真实对数似然的下界。用高斯Parzen窗口(核密度估计器)去拟合10000个 从模型生成的样本,并计算此分布下的测试数据的可能性。parzen窗口中自由参数$\sigma$是通过交叉验证选择的。

<center/>![](https://img2018.cnblogs.com/blog/441382/201812/441382-20181229101751995-1757801910.png)</center> 表1计算在真实数据MNIST和TFD上,AAE和其他如DBN,堆叠CAE,深度GSN,GAN和GMMN+AE模型的对比结果。 注意到parzen窗口是在真实log似然上评估下边界,略。。。

3 有监督AAE

4 半监督AAE

5 基于AAE的无监督聚类

6 基于AAE的维度约间

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!