How can I implement a weighted cross entropy loss in TensorFlow using sparse_softmax_cross_entropy_with_logits?

Submitted by 只谈情不闲聊 on 2019-11-28 05:33:50
import tensorflow as tf  # TensorFlow 1.x API (tf.InteractiveSession, tf.losses)
import numpy as np

np.random.seed(123)
sess = tf.InteractiveSession()

# let's say we have the logits and labels of a batch of size 6 with 5 classes
logits = tf.constant(np.random.randint(0, 10, 30).reshape(6, 5), dtype=tf.float32)
labels = tf.constant(np.random.randint(0, 5, 6), dtype=tf.int32)

# specify some class weightings
class_weights = tf.constant([0.3, 0.1, 0.2, 0.3, 0.1])

# specify the weights for each sample in the batch (without having to compute the onehot label matrix)
weights = tf.gather(class_weights, labels)

# compute the loss and print it (a bare .eval() only displays its value in an interactive shell)
print(tf.losses.sparse_softmax_cross_entropy(labels, logits, weights).eval())
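
To make the arithmetic behind that call easier to check, here is a minimal plain-Python sketch of a weighted sparse softmax cross entropy. It is not TensorFlow's implementation. It assumes the weighted per-example losses are summed and divided by the number of nonzero weights, which is how the default `SUM_BY_NONZERO_WEIGHTS` reduction of `tf.losses` is documented. The helper names and the toy batch are hypothetical.

```python
import math

def sparse_softmax_xent(logits_row, label):
    """Cross entropy of one example: log-sum-exp of the logits minus the true-class logit."""
    m = max(logits_row)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits_row))
    return log_sum_exp - logits_row[label]

def weighted_sparse_softmax_xent(logits, labels, class_weights):
    """Weighted batch loss; ASSUMES division by the number of nonzero weights."""
    weights = [class_weights[y] for y in labels]  # same lookup as tf.gather(class_weights, labels)
    losses = [w * sparse_softmax_xent(row, y)
              for row, y, w in zip(logits, labels, weights)]
    num_nonzero = sum(1 for w in weights if w != 0)
    return sum(losses) / num_nonzero if num_nonzero else 0.0

# hypothetical toy batch: 3 examples, 2 classes
logits = [[0.0, 0.0], [2.0, 0.0], [0.0, 3.0]]
labels = [0, 1, 1]
class_weights = [0.5, 2.0]

loss = weighted_sparse_softmax_xent(logits, labels, class_weights)
print(loss)
```

Each example's loss is scaled by the weight of its true class, so a mistake on a heavily weighted class contributes more to the batch loss.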

For binary classification specifically, there is weighted_cross_entropy_with_logits, which computes a weighted sigmoid cross entropy (not softmax): its pos_weight argument scales the loss on the positive examples.
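
The TensorFlow documentation gives the per-element formula for weighted_cross_entropy_with_logits as `labels * -log(sigmoid(logits)) * pos_weight + (1 - labels) * -log(1 - sigmoid(logits))`, which is a sigmoid-based loss. The sketch below evaluates that formula in plain Python for a single element. The function name `weighted_sigmoid_xent` is hypothetical, and this is an illustration of the formula rather than the library's code.

```python
import math

def softplus(x):
    """log(1 + exp(x)), written so that large |x| does not overflow."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def weighted_sigmoid_xent(label, logit, pos_weight):
    """Weighted sigmoid cross entropy for one element (label is 0 or 1)."""
    # -log(sigmoid(z)) == softplus(-z) and -log(1 - sigmoid(z)) == softplus(z)
    positive_term = pos_weight * label * softplus(-logit)
    negative_term = (1.0 - label) * softplus(logit)
    return positive_term + negative_term

# With pos_weight = 3, an error on a positive example costs three times as much
print(weighted_sigmoid_xent(1.0, 0.0, 3.0))
print(weighted_sigmoid_xent(0.0, 0.0, 3.0))
```

A pos_weight above 1 penalizes missed positives more heavily, and a value below 1 penalizes them less.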

sparse_softmax_cross_entropy_with_logits is tailored for a highly efficient non-weighted operation (see SparseSoftmaxXentWithLogitsOp, which uses SparseXentEigenImpl under the hood), so it's not "pluggable".

In the multi-class case, your options are either to switch to one-hot encoding or to use the tf.losses.sparse_softmax_cross_entropy loss function in a hacky way, as already suggested, where you have to pass weights that depend on the labels in the current batch.
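
The two options produce the same per-sample weights. As a small plain-Python sketch (the labels below are hypothetical, and the class weights are the ones from the first answer), multiplying the class weights by a one-hot label row and summing picks out exactly the entry that a direct lookup by label returns:

```python
class_weights = [0.3, 0.1, 0.2, 0.3, 0.1]
labels = [4, 0, 2, 2]  # hypothetical sparse labels for a batch of 4

def one_hot(label, num_classes):
    """Return a list with 1.0 at position `label` and 0.0 elsewhere."""
    return [1.0 if k == label else 0.0 for k in range(num_classes)]

# Option 1: one-hot encoding, then a weighted sum over the class axis
onehot_labels = [one_hot(y, len(class_weights)) for y in labels]
weights_from_onehot = [sum(w * t for w, t in zip(class_weights, row))
                       for row in onehot_labels]

# Option 2: look the weight up by label (what tf.gather does on tensors)
weights_from_gather = [class_weights[y] for y in labels]

print(weights_from_onehot)
print(weights_from_gather)
```

The lookup avoids materializing a [batch_size, num_classes] matrix, which matters when the number of classes is large.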

Alexa Greenberg

The class weights are multiplied by the logits, so this still works with sparse_softmax_cross_entropy_with_logits. Refer to the solution for "Loss function for class imbalanced binary classifier in TensorFlow."

As a side note, you can pass weights directly into sparse_softmax_cross_entropy:

tf.contrib.losses.sparse_softmax_cross_entropy(logits, labels, weight=1.0, scope=None)

This method computes the cross-entropy loss using tf.nn.sparse_softmax_cross_entropy_with_logits.

Weight acts as a coefficient for the loss. If a scalar is provided, then the loss is simply scaled by the given value. If weight is a tensor of size [batch_size], then the loss weights apply to each corresponding sample.
