PyTorch中的Loss Fucntion

匿名 (未验证) 提交于 2019-12-03 00:34:01

转载:http://sshuair.com/2017/10/21/pytorch-loss/


PyTorch中的Loss Fucntion


f(x)f(x)yyf(x)f(x)WWy^y^yy

Cross Entropy

Cross Entropy(也就是交叉熵)来自香农的信息论,简单来说,交叉熵是用来衡量在给定的真实分布pkpk下,使用非真实分布qkqkf(x)f(x)

H(p,q)=k=1N(pklogqk)

最大似然估计、Negative Log Liklihood(NLL)、KL散度与Cross Entropy其实是等价的,都可以进行互相推导,当然MSE也可以用Cross Entropy进行对到出(详见Deep Learning Book P132)。

Cross Entropy可以用于分类问题,也可以用于语义分割,对于分类问题,其输出层通常为Sigmoid或者Softmax,当然也有可能直接输出加权之后的,而pytorch中与Cross Entropy相关的loss Function包括:

  • CrossEntropyLoss: combines LogSoftMax and NLLLoss in one single class,也就是说我们的网络不需要在最后一层加任何输出层,该loss Function为我们打包好了;
  • NLLLoss: 也就是negative log likelihood loss,如果需要得到log分布,则需要在网络的最后一层加上LogSoftmax
  • NLLLoss2d: 二维的negative log likelihood loss,多用于分割问题
  • BCELoss: Binary Cross Entropy,常用于二分类问题,当然也可以用于多分类问题,通常需要在网络的最后一层添加sigmoid进行配合使用,其target也就是yy值需要进行one hot编码,另外BCELoss还可以用于Multi-label classification
  • BCEWithLogitsLoss: 把Sigmoid layer 和 the BCELoss整合到了一起
  • KLDivLoss: TODO
  • PoissonNLLLoss: TODO

下面就用PyTorch对上面的Loss Function进行说明

CrossEntropyLoss

pytorch中CrossEntropyLoss是通过两个步骤计算出来的,第一步是计算log softmax,第二步是计算cross entropy(或者说是negative log likehood),CrossEntropyLoss不需要在网络的最后一层添加softmax和log层,直接输出全连接层即可。而NLLLoss则需要在定义网络的时候在最后一层添加softmax和log层

1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 
import torch import torch.nn as nn import torch.nn.functional as F import torch.autograd as autograd import numpy as np  # 预测值f(x) 构造样本,神经网络输出层 inputs_tensor = torch.FloatTensor( [  [10, 2, 1,-2,-3],  [-1,-6,-0,-3,-5],  [-5, 4, 8, 2, 1]  ])  # 真值y targets_tensor = torch.LongTensor([1,3,2]) # targets_tensor = torch.LongTensor([1])  inputs_variable = autograd.Variable(inputs_tensor, requires_grad=True)  targets_variable = autograd.Variable(targets_tensor) print('input tensor(nBatch x nClasses): {}'.format(inputs_tensor.shape)) print('target tensor shape: {}'.format(targets_tensor.shape)) 
input tensor(nBatch x nClasses): torch.Size([3, 5]) target tensor shape: torch.Size([3]) 
1 2 3 4 
loss = nn.CrossEntropyLoss() output = loss(inputs_variable, targets_variable) # output.backward() print('pytorch 内部实现的CrossEntropyLoss: {}'.format(output)) 
pytorch 内部实现的CrossEntropyLoss: Variable containing:  3.7925 [torch.FloatTensor of size 1] 

手动计算

1.log softmax

1 2 3 4 5 6 7 8 9 
# 手动计算log softmax, 计算结果的值域是[0, 1] softmax_result = F.softmax(inputs_variable) #.sum() #计算softmax print(('softmax_result(sum=1):{} \n'.format(softmax_result))) logsoftmax_result = np.log(softmax_result.data)  # 计算log,以e为底, 计算后所有的值都小于0 print('手动计算 calculate logsoftmax_result:{} \n'.format(logsoftmax_result))  # 直接调用F.log_softmax softmax_result = F.log_softmax(inputs_variable) print('F.log_softmax calculate logsoftmax_result:{} \n'.format(logsoftmax_result)) 
softmax_result(sum=1):Variable containing:  9.9953e-01  3.3531e-04  1.2335e-04  6.1413e-06  2.2593e-06  2.5782e-01  1.7372e-03  7.0083e-01  3.4892e-02  4.7221e-03  2.2123e-06  1.7926e-02  9.7875e-01  2.4261e-03  8.9251e-04 [torch.FloatTensor of size 3x5]   手动计算 calculate logsoftmax_result: -4.6717e-04 -8.0005e+00 -9.0005e+00 -1.2000e+01 -1.3000e+01 -1.3555e+00 -6.3555e+00 -3.5549e-01 -3.3555e+00 -5.3555e+00 -1.3021e+01 -4.0215e+00 -2.1476e-02 -6.0215e+00 -7.0215e+00 [torch.FloatTensor of size 3x5]   F.log_softmax calculate logsoftmax_result: -4.6717e-04 -8.0005e+00 -9.0005e+00 -1.2000e+01 -1.3000e+01 -1.3555e+00 -6.3555e+00 -3.5549e-01 -3.3555e+00 -5.3555e+00 -1.3021e+01 -4.0215e+00 -2.1476e-02 -6.0215e+00 -7.0215e+00 [torch.FloatTensor of size 3x5] 

2.手动计算loss

pytorch中NLLLoss定义如下:

loss(x,class)=x[class]

这里为什么可以这么写呢?下面用第三个样本进行解释

我们用one-hot编码后,得到真实分布概率的值px(orpmodel)为(这里一共有5类):[0,0,1,0,0]

而模型预测的每一类分布概率,也就是非真实分布的概率qx(orppred)注意:概率要求其结果为1,这里使用的是softmax计算出来的结果,而不是log softmax

Nk=1(pklogqk)

mi=1log(pmodel(yixi;θ))

将对应项目相乘即可得到最终的loss结果:

0×log(2.57821001)+0×log(1.73721003)+0×log(7.00831001)+1×log(3.48921002)+0×log(4.72211003)=1×log(3.48921002)

也就恒等于

0×1.355510+00+0×6.355510+00+1×3.55491001+0×3.355510+00+0×5.355510+00=1×3.355510+00

由于其他类别都是0,而且真实概率的一定是1,因此可以简化表示为loss(x,class)=x[class]q

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!