How can I implement a weighted cross entropy loss in tensorflow using sparse_softmax_cross_entropy_with_logits

Asked by 广开言路, 2020-12-08 05:32

I am starting to use TensorFlow (coming from Caffe), and I am using the loss sparse_softmax_cross_entropy_with_logits. The function accepts integer class labels like 0, 1, ..., C-1 rather than one-hot vectors, and it has no parameter for per-class weights. How can I weight the loss for each class?

3 Answers
    Answered by 暗喜, 2020-12-08 06:18

    Specifically for binary classification, there is weighted_cross_entropy_with_logits, which computes weighted sigmoid cross entropy (despite the name, it is the sigmoid variant, not softmax): its pos_weight argument lets you up- or down-weight the cost of an error on the positive class relative to the negative class.
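
    A minimal sketch of the binary case, assuming TensorFlow 2.x (the logits and labels are made up for illustration):

        import tensorflow as tf

        # Made-up binary logits and 0/1 labels, just for illustration.
        logits = tf.constant([[0.5], [-1.2], [2.0]])
        labels = tf.constant([[1.0], [0.0], [1.0]])

        # pos_weight > 1 penalizes false negatives more heavily;
        # pos_weight < 1 penalizes false positives more heavily.
        per_example = tf.nn.weighted_cross_entropy_with_logits(
            labels=labels, logits=logits, pos_weight=2.0)
        loss = tf.reduce_mean(per_example)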

    sparse_softmax_cross_entropy_with_logits is tailored for a highly efficient non-weighted operation (see SparseSoftmaxXentWithLogitsOp, which uses SparseXentEigenImpl under the hood), so it's not "pluggable".

    In the multi-class case, your options are either to switch to one-hot encoding or to use the tf.losses.sparse_softmax_cross_entropy loss function in a hacky way, as already suggested, where you pass per-example weights computed from the labels in the current batch; see the sketch below.
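
    A minimal sketch of that label-dependent weighting (the class weights, labels, and shapes are made up for illustration; in TensorFlow 2.x this loss lives under tf.compat.v1.losses):

        import tensorflow as tf

        # Hypothetical per-class weights for a 3-class problem.
        class_weights = tf.constant([1.0, 2.0, 0.5])

        labels = tf.constant([0, 2, 1, 2])   # integer class ids
        logits = tf.random.normal([4, 3])    # batch of 4 examples, 3 classes

        # Look up one weight per example from its label, then pass those
        # as the per-example `weights` argument of the loss.
        sample_weights = tf.gather(class_weights, labels)
        loss = tf.compat.v1.losses.sparse_softmax_cross_entropy(
            labels=labels, logits=logits, weights=sample_weights)

    Equivalently, you can take the per-example losses from plain tf.nn.sparse_softmax_cross_entropy_with_logits, multiply them by sample_weights, and reduce them yourself.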
