Implement custom loss function in Tensorflow 2.0

拟墨画扇 submitted on 2019-12-12 23:57:02

Question


I'm building a model for time series classification. The data is highly imbalanced, so I've decided to use a weighted cross-entropy function as my loss.

TensorFlow provides tf.nn.weighted_cross_entropy_with_logits, but I'm not sure how to use it in TF 2.0. Because my model is built with the tf.keras API, I was thinking about creating a custom loss function like this:

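import tensorflow as tf

pos_weight = 10.0  # weight applied to the positive class

def weighted_cross_entropy_with_logits(y_true, y_pred):
    # y_pred must be raw logits, i.e. the last layer has no sigmoid activation.
    y_true = tf.cast(y_true, y_pred.dtype)
    return tf.nn.weighted_cross_entropy_with_logits(
        labels=y_true, logits=y_pred, pos_weight=pos_weight)

# .....
model.compile(loss=weighted_cross_entropy_with_logits, optimizer="adam", metrics=["acc"])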

My question is: is there a way to use tf.nn.weighted_cross_entropy_with_logits with tf.keras API directly?
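To make the effect of pos_weight concrete, here is a minimal plain-Python sketch of the per-element value that tf.nn.weighted_cross_entropy_with_logits is documented to compute, written in its numerically stable form. The function name weighted_bce_from_logit is hypothetical and used only for illustration; it is not part of TensorFlow.

```python
import math


def weighted_bce_from_logit(label, logit, pos_weight):
    """Per-element weighted sigmoid cross-entropy for a single example.

    Equivalent to:
        pos_weight * label * -log(sigmoid(logit))
        + (1 - label) * -log(1 - sigmoid(logit))
    """
    # Numerically stable softplus(-logit), which equals -log(sigmoid(logit)).
    softplus_neg = math.log1p(math.exp(-abs(logit))) + max(-logit, 0.0)
    # Positive labels are scaled by pos_weight; negative labels keep weight 1.
    scale = 1.0 + (pos_weight - 1.0) * label
    return (1.0 - label) * logit + scale * softplus_neg


# Illustrative values: an undecided model (logit 0) on one positive
# and one negative example, with pos_weight = 10.
loss_positive = weighted_bce_from_logit(1.0, 0.0, 10.0)
loss_negative = weighted_bce_from_logit(0.0, 0.0, 10.0)
```

With pos_weight = 10 and a logit of 0, the positive example costs 10 * ln 2 while the negative example costs ln 2, so a misclassified positive is penalized ten times more heavily.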


Answer 1:


You can pass class weights directly to model.fit through its class_weight argument. The Keras documentation describes it as follows:

class_weight: Optional dictionary mapping class indices (integers) to a weight (float) value, used for weighting the loss function (during training only). This can be useful to tell the model to "pay more attention" to samples from an under-represented class.

For example:

{
    0: 0.31, 
    1: 0.33, 
    2: 0.36, 
    3: 0.42, 
    4: 0.48
}
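A dictionary like the one above can be derived from the training labels rather than written by hand. The sketch below uses a common inverse-frequency heuristic, n_samples / (n_classes * class_count), which is an assumption for illustration and not something the answer prescribes. The helper name balanced_class_weights, the sample labels, and the names model, x_train, and y_train in the commented call are all hypothetical.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Map each class index to n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {
        cls: n_samples / (n_classes * count)
        for cls, count in sorted(counts.items())
    }


# Hypothetical unbalanced binary labels: 90 negatives, 10 positives.
labels = [0] * 90 + [1] * 10
class_weight = balanced_class_weights(labels)

# Assumed usage with an already compiled tf.keras model named `model`:
# model.fit(x_train, y_train, epochs=10, class_weight=class_weight)
```

Under this heuristic the rare class receives the larger weight (here 5.0 for class 1 versus roughly 0.56 for class 0), so each class contributes about equally to the total loss.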


Source: https://stackoverflow.com/questions/57414114/implement-custom-loss-function-in-tensorflow-2-0
