直观理解为什么分类问题用交叉熵损失而不用均方误差损失?

谁说我不能喝 提交于 2020-01-24 03:55:15

Acknowledge

这篇文章来自:https://www.cnblogs.com/shine-lee/archive/2019/12/12/12032066.html,作者:@进击的小学生。从这篇文章来看,博主是个有科研情怀的人。我对这篇文章进行重编辑,以便阅读起来更清晰。

乍一看到某个问题,你会觉得很简单,其实你并没有理解其复杂性。当你把问题搞清楚之后,又会发现真的很复杂,于是你就拿出一套复杂的方案来。实际上,你的工作只做了一半,大多数人也都会到此为止……。但是,真正伟大的人还会继续向前,直至找到问题的关键和深层次原因,然后再拿出一个优雅的、堪称完美的有效方案。
—— from 乔布斯

摘要

损失函数的选择和设计要能表达你所期望的模型所具有的性质与倾向,本文分别从损失函数角度softmax反向传播角度两个角度来直观地解释为何多分类问题选用交叉熵损失而不使用均方误差损失,具有一定的启发意义。

交叉熵损失与均方误差损失

常规分类网络最后的softmax层如下图所示,传统机器学习方法以此类比,
在这里插入图片描述
一共有kk个类。令网络的输出为[y^1,,y^k][\hat{y}_1, \dots, \hat{y}_k],对应每个类别的预测概率。令label为[y1,,yk][y_1, \dots, y_k],采用one-hot热编码形式,对于某个类别为pp的样例,其label中的yp=1y_p=1,其余y1,,yp1,yp+1,,yky_1, \dots, y_{p-1}, y_{p+1}, \dots, y_k均为0。

对这个样本,交叉熵 cross entropy损失为
L=(y1logy^1++yklogy^k)=yplogy^p=logy^p \begin{aligned} L &= -(y_1 \log \hat{y}_1 + \cdots + y_k \log \hat{y}_k)\\ &=-y_p \log \hat{y}_p\\ &=-\log \hat{y}_p, \end{aligned}
均方误差损失 mean squared error,简称MSE
L=(y1y^1)2++(yky^k)2=(1y^p)2+(y^12++y^p12+y^p+12+y^k2) \begin{aligned} L &= (y_1 - \hat{y}_1)^2 + \cdots + (y_k - \hat{y}_k)^2\\ &= (1 - \hat{y}_p)^2 + (\hat{y}_1^2 + \cdots + \hat{y}_{p-1}^2 + \hat{y}_{p+1}^2 + \hat{y}_k^2) \end{aligned}。

损失函数角度

对比交叉熵和均方误差损失,可以发现,

  • 交叉熵只与label类别有关,y^p\hat{y}_p越趋近于11越好。
  • 均方误差不仅与y^p\hat{y}_p有关,还与其它项有关,它希望(y^12++y^p12+y^p+12+y^k2)(\hat{y}_1^2 + \cdots + \hat{y}_{p-1}^2 + \hat{y}_{p+1}^2 + \hat{y}_k^2)越小越好。

多分类问题中,对于类别之间的相关性,我们缺乏先验。 例如,

尽管与“狗”相比,“猫”和“老虎”之间的相似度更高,但是这种关系在样本标记时是难以量化的,所给label只能是one-hot形式。

在这个前提下,均方误差损失可能会给出错误的指示,比如猫、老虎、狗的3分类问题,label为[1,0,0][1,0,0],在均方误差看来,预测为[0.8,0.1,0.1][0.8,0.1,0.1]要比[0.8,0.15,0.05][0.8,0.15,0.05]要好,老虎和狗居然获得同等的分数,这有悖我们的常识。

而对交叉熵损失,既然类别间复杂的相似度矩阵是难以量化的,那么就索性只关注样本所属的类别,这显得更加合理。

softmax反向传播角度

softmax的作用是将(,+)(-\infty, +\infty)的几个实数映射到(0,1)(0,1)之间且之和为11,以获得某种概率解释。

令softmax函数的输入为zz,输出为y^\hat{y},对节点pp
y^p=ezpi=1kezi \hat{y}_p = \frac{e^{z_p}}{\sum^k_{i=1} e^{z_i}},

y^p\hat{y}_p不仅与zpz_p有关,还与{ziip}\{z_i | i \neq p \}有关,这里仅看zpz_p,则有
y^pzp=y^p(1y^p) \frac{\partial{\hat{y}_p}}{\partial{z_p}} = \hat{y}_p (1 - \hat{y}_p)。

对均方误差损失,有
Ly^p=2(1y^p)=2(y^p1) \frac{\partial{L}}{\partial{\hat{y}_p}} = -2(1-\hat{y}_p) = 2(\hat{y}_p-1),
根据链式法则,进一步有
Lz^p=Ly^py^pz^p=2y^p(1y^p)2 \frac{\partial{L}}{\partial{\hat{z}_p}} = \frac{\partial{L}}{\partial{\hat{y}_p}} \cdot \frac{\partial{\hat{y}_p}}{\partial{\hat{z}_p}} = -2\hat{y}_p ( 1-\hat{y}_p)^2,

y^p=0\hat{y}_p = 0时分类错误,但偏导数为00,权重不会更新,这显然不对——分类越错误越需要对权重进行更新。

而对交叉熵损失,有
Ly^p=1y^p \frac{\partial{L}}{\partial{\hat{y}_p}} = - \frac{1}{\hat{y}_p},
进一步有
Lz^p=Ly^py^pz^p=y^p1 \frac{\partial{L}}{\partial{\hat{z}_p}} = \frac{L}{\partial{\hat{y}_p}} \cdot \frac{\partial{\hat{y}_p}}{\partial{\hat{z}_p}} = \hat{y}_p - 1
恰巧将y^p(1y^p)\hat{y}_p(1-\hat{y}_p)中的y^p\hat{y}_p消掉,避免了上述情形的发生,且y^p\hat{y}_p越接近于11,偏导越接近于00,即分类越正确越不需要更新权重,这与我们的期望相符。

总结

综上,对分类问题而言,无论从损失函数角度还是softmax反向传播角度,交叉熵都比均方误差要好。这篇文章思路清晰,分析详细,对入门者有较好的启发意义。

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!