Cross Entropy in PyTorch

时光说笑 · 2020-12-13 02:47

I'm a bit confused by the cross entropy loss in PyTorch.

Considering this example:

import torch
import torch.nn as nn

output = torch.tensor([[0.0, 0.0, 0.0, 1.0]])  # predicted logits, x = [0, 0, 0, 1]
target = torch.tensor([3])                      # ground-truth class index

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)
print(loss)  # tensor(0.7437)

3 Answers
  •  误落风尘
    2020-12-13 03:06

    Your understanding is correct, but PyTorch doesn't compute cross entropy that way. PyTorch uses the following formula:

    loss(x, class) = -log(exp(x[class]) / (\sum_j exp(x[j])))
                   = -x[class] + log(\sum_j exp(x[j]))
    

    Since, in your scenario, x = [0, 0, 0, 1] and class = 3, evaluating the above expression gives:

    loss(x, class) = -1 + log(exp(0) + exp(0) + exp(0) + exp(1))
                   = 0.7437
    

    Note that PyTorch uses the natural logarithm in this formula.
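
    As a quick sanity check, you can reproduce this value by evaluating the formula by hand with torch.logsumexp and comparing it against nn.CrossEntropyLoss. This is just a minimal sketch; the variable names are illustrative, not from the original post:

    import torch
    import torch.nn as nn

    x = torch.tensor([[0.0, 0.0, 0.0, 1.0]])  # logits, x = [0, 0, 0, 1]
    target = torch.tensor([3])                # class = 3

    # Manual formula: -x[class] + log(sum_j exp(x[j]))
    manual = -x[0, 3] + torch.logsumexp(x[0], dim=0)

    # PyTorch's built-in cross entropy on the same inputs
    builtin = nn.CrossEntropyLoss()(x, target)

    print(manual.item(), builtin.item())  # both ~0.7437

    Both values agree, confirming that nn.CrossEntropyLoss applies log-softmax to the raw logits internally rather than expecting probabilities.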
