Acknowledge
这篇文章来自:https://www.cnblogs.com/shine-lee/archive/2019/12/12/12032066.html ,作者:@进击的小学生。从这篇文章来看,博主是个有科研情怀的人。我对这篇文章进行重编辑,以便阅读起来更清晰。
乍一看到某个问题,你会觉得很简单,其实你并没有理解其复杂性。当你把问题搞清楚之后,又会发现真的很复杂,于是你就拿出一套复杂的方案来。实际上,你的工作只做了一半,大多数人也都会到此为止……。但是,真正伟大的人还会继续向前,直至找到问题的关键和深层次原因,然后再拿出一个优雅的、堪称完美的有效方案。
—— from 乔布斯
摘要
损失函数的选择和设计要能表达你所期望的模型所具有的性质与倾向 ,本文分别从损失函数角度 和softmax反向传播角度 两个角度来直观地解释为何多分类问题选用交叉熵损失而不使用均方误差损失,具有一定的启发意义。
交叉熵损失与均方误差损失
常规分类网络最后的softmax层如下图所示,传统机器学习方法以此类比,
一共有k k k 个类。令网络的输出为[ y ^ 1 , … , y ^ k ] [\hat{y}_1, \dots, \hat{y}_k] [ y ^ 1 , … , y ^ k ] ,对应每个类别的预测概率。令label为[ y 1 , … , y k ] [y_1, \dots, y_k] [ y 1 , … , y k ] ,采用one-hot热编码形式,对于某个类别为p p p 的样例,其label中的y p = 1 y_p=1 y p = 1 ,其余y 1 , … , y p − 1 , y p + 1 , … , y k y_1, \dots, y_{p-1}, y_{p+1}, \dots, y_k y 1 , … , y p − 1 , y p + 1 , … , y k 均为0。
对这个样本,交叉熵 cross entropy 损失为L = − ( y 1 log y ^ 1 + ⋯ + y k log y ^ k ) = − y p log y ^ p = − log y ^ p ,
\begin{aligned}
L &= -(y_1 \log \hat{y}_1 + \cdots + y_k \log \hat{y}_k)\\
&=-y_p \log \hat{y}_p\\
&=-\log \hat{y}_p,
\end{aligned}
L = − ( y 1 log y ^ 1 + ⋯ + y k log y ^ k ) = − y p log y ^ p = − log y ^ p ,
而均方误差损失 mean squared error,简称MSE 为L = ( y 1 − y ^ 1 ) 2 + ⋯ + ( y k − y ^ k ) 2 = ( 1 − y ^ p ) 2 + ( y ^ 1 2 + ⋯ + y ^ p − 1 2 + y ^ p + 1 2 + y ^ k 2 ) 。
\begin{aligned}
L &= (y_1 - \hat{y}_1)^2 + \cdots + (y_k - \hat{y}_k)^2\\
&= (1 - \hat{y}_p)^2 + (\hat{y}_1^2 + \cdots + \hat{y}_{p-1}^2 + \hat{y}_{p+1}^2 + \hat{y}_k^2)
\end{aligned}。
L = ( y 1 − y ^ 1 ) 2 + ⋯ + ( y k − y ^ k ) 2 = ( 1 − y ^ p ) 2 + ( y ^ 1 2 + ⋯ + y ^ p − 1 2 + y ^ p + 1 2 + y ^ k 2 ) 。
损失函数角度
对比交叉熵和均方误差损失,可以发现,
交叉熵只与label类别有关,y ^ p \hat{y}_p y ^ p 越趋近于1 1 1 越好。
均方误差不仅与y ^ p \hat{y}_p y ^ p 有关,还与其它项有关,它希望( y ^ 1 2 + ⋯ + y ^ p − 1 2 + y ^ p + 1 2 + y ^ k 2 ) (\hat{y}_1^2 + \cdots + \hat{y}_{p-1}^2 + \hat{y}_{p+1}^2 + \hat{y}_k^2) ( y ^ 1 2 + ⋯ + y ^ p − 1 2 + y ^ p + 1 2 + y ^ k 2 ) 越小越好。
多分类问题中,对于类别之间的相关性,我们缺乏先验。 例如,
尽管与“狗”相比,“猫”和“老虎”之间的相似度更高,但是这种关系在样本标记时是难以量化的,所给label只能是one-hot形式。
在这个前提下,均方误差损失可能会给出错误的指示,比如猫、老虎、狗的3分类问题,label为[ 1 , 0 , 0 ] [1,0,0] [ 1 , 0 , 0 ] ,在均方误差看来,预测为[ 0.8 , 0.1 , 0.1 ] [0.8,0.1,0.1] [ 0 . 8 , 0 . 1 , 0 . 1 ] 要比[ 0.8 , 0.15 , 0.05 ] [0.8,0.15,0.05] [ 0 . 8 , 0 . 1 5 , 0 . 0 5 ] 要好,老虎和狗居然获得同等的分数,这有悖我们的常识。
而对交叉熵损失,既然类别间复杂的相似度矩阵是难以量化的,那么就索性只关注样本所属的类别,这显得更加合理。
softmax反向传播角度
softmax的作用是将( − ∞ , + ∞ ) (-\infty, +\infty) ( − ∞ , + ∞ ) 的几个实数映射到( 0 , 1 ) (0,1) ( 0 , 1 ) 之间且之和为1 1 1 ,以获得某种概率解释。
令softmax函数的输入为z z z ,输出为y ^ \hat{y} y ^ ,对节点p p p 有y ^ p = e z p ∑ i = 1 k e z i ,
\hat{y}_p = \frac{e^{z_p}}{\sum^k_{i=1} e^{z_i}},
y ^ p = ∑ i = 1 k e z i e z p ,
y ^ p \hat{y}_p y ^ p 不仅与z p z_p z p 有关,还与{ z i ∣ i ≠ p } \{z_i | i \neq p \} { z i ∣ i = p } 有关,这里仅看z p z_p z p ,则有∂ y ^ p ∂ z p = y ^ p ( 1 − y ^ p ) 。
\frac{\partial{\hat{y}_p}}{\partial{z_p}} = \hat{y}_p (1 - \hat{y}_p)。
∂ z p ∂ y ^ p = y ^ p ( 1 − y ^ p ) 。
对均方误差损失,有∂ L ∂ y ^ p = − 2 ( 1 − y ^ p ) = 2 ( y ^ p − 1 ) ,
\frac{\partial{L}}{\partial{\hat{y}_p}} = -2(1-\hat{y}_p) = 2(\hat{y}_p-1),
∂ y ^ p ∂ L = − 2 ( 1 − y ^ p ) = 2 ( y ^ p − 1 ) ,
根据链式法则,进一步有∂ L ∂ z ^ p = ∂ L ∂ y ^ p ⋅ ∂ y ^ p ∂ z ^ p = − 2 y ^ p ( 1 − y ^ p ) 2 ,
\frac{\partial{L}}{\partial{\hat{z}_p}} = \frac{\partial{L}}{\partial{\hat{y}_p}} \cdot \frac{\partial{\hat{y}_p}}{\partial{\hat{z}_p}} = -2\hat{y}_p ( 1-\hat{y}_p)^2,
∂ z ^ p ∂ L = ∂ y ^ p ∂ L ⋅ ∂ z ^ p ∂ y ^ p = − 2 y ^ p ( 1 − y ^ p ) 2 ,
当 y ^ p = 0 \hat{y}_p = 0 y ^ p = 0 时分类错误,但偏导数为0 0 0 ,权重不会更新,这显然不对——分类越错误越需要对权重进行更新。
而对交叉熵损失,有∂ L ∂ y ^ p = − 1 y ^ p ,
\frac{\partial{L}}{\partial{\hat{y}_p}} = - \frac{1}{\hat{y}_p},
∂ y ^ p ∂ L = − y ^ p 1 ,
进一步有∂ L ∂ z ^ p = L ∂ y ^ p ⋅ ∂ y ^ p ∂ z ^ p = y ^ p − 1
\frac{\partial{L}}{\partial{\hat{z}_p}} = \frac{L}{\partial{\hat{y}_p}} \cdot \frac{\partial{\hat{y}_p}}{\partial{\hat{z}_p}} = \hat{y}_p - 1
∂ z ^ p ∂ L = ∂ y ^ p L ⋅ ∂ z ^ p ∂ y ^ p = y ^ p − 1
恰巧将y ^ p ( 1 − y ^ p ) \hat{y}_p(1-\hat{y}_p) y ^ p ( 1 − y ^ p ) 中的y ^ p \hat{y}_p y ^ p 消掉,避免了上述情形的发生,且y ^ p \hat{y}_p y ^ p 越接近于1 1 1 ,偏导越接近于0 0 0 ,即分类越正确越不需要更新权重,这与我们的期望相符。
总结
综上,对分类问题而言,无论从损失函数角度还是softmax反向传播角度,交叉熵都比均方误差要好。这篇文章思路清晰,分析详细,对入门者有较好的启发意义。