Cross Entropy in PyTorch


I'm a bit confused by the cross entropy loss in PyTorch.

Considering this example:

import torch
import torch.nn as nn

output = torch.tensor([[0.0, 0.0, 0.0, 1.0]])
target = torch.tensor([3])

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)
print(loss)  # tensor(0.7437)
3 Answers

    In your example you are treating the output [0, 0, 0, 1] as probabilities, as required by the mathematical definition of cross entropy. But PyTorch treats these values as raw scores (logits) that don't need to sum to 1; it first converts them into probabilities, and for that it uses the softmax function.

    So H(p, q) becomes:

    H(p, softmax(output))
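
    In PyTorch terms, here is a minimal sketch of what that means, assuming the same output/target tensors as in your example: nn.CrossEntropyLoss applies log-softmax to the raw outputs and then takes the negative log-likelihood, so it matches nn.NLLLoss applied to torch.log_softmax(output).

    import torch
    import torch.nn as nn

    output = torch.tensor([[0.0, 0.0, 0.0, 1.0]])
    target = torch.tensor([3])

    # CrossEntropyLoss == log-softmax followed by negative log-likelihood
    ce = nn.CrossEntropyLoss()(output, target)
    nll = nn.NLLLoss()(torch.log_softmax(output, dim=1), target)
    print(ce, nll)  # both tensor(0.7437)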
    

    Translating the output [0, 0, 0, 1] into probabilities:

    softmax([0, 0, 0, 1]) = [0.1749, 0.1749, 0.1749, 0.4754]
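
    A quick sketch reproducing those numbers with torch.softmax (the scores are just the ones from your example):

    import torch

    scores = torch.tensor([0.0, 0.0, 0.0, 1.0])
    probs = torch.softmax(scores, dim=0)  # normalize raw scores into probabilities
    print(probs)  # tensor([0.1749, 0.1749, 0.1749, 0.4754])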
    

    whence:

    -log(0.4754) = 0.7437
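
    And the same value computed by hand, assuming the target class is index 3 (the position whose probability is 0.4754):

    import torch

    probs = torch.softmax(torch.tensor([0.0, 0.0, 0.0, 1.0]), dim=0)
    loss = -torch.log(probs[3])  # negative log-probability of the target class
    print(loss)  # tensor(0.7437), the value nn.CrossEntropyLoss returns above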
    
