How can I implement a weighted cross entropy loss in tensorflow using sparse_softmax_cross_entropy_with_logits

广开言路 2020-12-08 05:32

I am starting to use TensorFlow (coming from Caffe), and I am using the loss sparse_softmax_cross_entropy_with_logits. The function accepts integer labels like 0, 1, ..., C-1 rather than one-hot encodings. Is there an easy way to weight the loss depending on the class label?
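
For concreteness, a rough sketch of the unweighted loss I have now (shapes and values are made up for illustration):

    import numpy as np
    import tensorflow as tf

    # toy batch: 4 examples, 3 classes
    logits = tf.constant(np.random.randn(4, 3), dtype=tf.float32)
    labels = tf.constant([0, 2, 1, 0], dtype=tf.int32)

    # one cross-entropy value per example, then the plain (unweighted) mean
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    loss = tf.reduce_mean(per_example)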

3 Answers
  •  小蘑菇 (OP)
     2020-12-08 06:32

    import tensorflow as tf
    import numpy as np
    
    np.random.seed(123)
    sess = tf.InteractiveSession()  # TF 1.x API; lets the .eval() calls below run
    
    # let's say we have the logits and labels of a batch of size 6 with 5 classes
    logits = tf.constant(np.random.randint(0, 10, 30).reshape(6, 5), dtype=tf.float32)
    labels = tf.constant(np.random.randint(0, 5, 6), dtype=tf.int32)
    
    # specify some class weightings
    class_weights = tf.constant([0.3, 0.1, 0.2, 0.3, 0.1])
    
    # per-sample weights: tf.gather picks class_weights[label] for each of the
    # 6 labels, so there is no need to build a one-hot label matrix
    weights = tf.gather(class_weights, labels)
    
    # compute the loss
    tf.losses.sparse_softmax_cross_entropy(labels, logits, weights).eval()
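
    The same weighting can also be written with the lower-level op from the question title. A minimal sketch, reusing the tensors above: tf.nn.sparse_softmax_cross_entropy_with_logits returns one loss value per example, so the reduction to a scalar is up to you (a plain weighted mean here, which can differ slightly from the default reduction used by tf.losses.sparse_softmax_cross_entropy):

    # per-example cross entropy, shape (6,)
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)

    # scale each example's loss by its class weight, then take the weighted mean
    (tf.reduce_sum(weights * per_example) / tf.reduce_sum(weights)).eval()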
    
