I'm a bit confused by the cross entropy loss in PyTorch.
Considering this example:
import torch
import torch.nn as nn

loss = nn.CrossEntropyLoss()(torch.tensor([[0., 0., 0., 1.]]), torch.tensor([3]))
print(loss)  # tensor(0.7437)
Your understanding is correct, but PyTorch doesn't compute cross entropy that way. PyTorch uses the following formula:
loss(x, class) = -log(exp(x[class]) / sum_j exp(x[j]))
               = -x[class] + log(sum_j exp(x[j]))
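In code this is just a log-sum-exp over the logits; here is a minimal sketch (manual_cross_entropy is a hypothetical helper name for illustration, not part of the PyTorch API):

import torch

def manual_cross_entropy(x, cls):
    # loss(x, class) = -x[class] + log(sum_j exp(x[j]))
    return -x[cls] + torch.logsumexp(x, dim=0)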
Since in your scenario x = [0, 0, 0, 1] and class = 3, evaluating the above expression gives:
loss(x, class) = -1 + log(exp(0) + exp(0) + exp(0) + exp(1))
               = 0.7437
Note that PyTorch uses the natural logarithm here.
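As a quick check, the built-in loss agrees with the manual computation (a minimal sketch assuming a batch of one sample):

import torch
import torch.nn as nn

x = torch.tensor([[0., 0., 0., 1.]])  # logits for one sample
target = torch.tensor([3])            # the true class index

builtin = nn.CrossEntropyLoss()(x, target)
manual = -x[0, 3] + torch.log(torch.exp(x[0]).sum())  # -x[class] + log(sum_j exp(x[j]))

print(builtin.item(), manual.item())  # both print 0.7437...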