Different loss function for validation set in Keras

淺唱寂寞╮ submitted on 2019-11-30 18:00:59

Question


I have an unbalanced training dataset, which is why I built a custom weighted categorical cross-entropy loss function. The problem is that my validation set is balanced, and I want to use the regular categorical cross-entropy loss for it. Can I pass a different loss function for the validation set in Keras? That is, the weighted one for training and the regular one for validation?

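def weighted_loss(y_true, y_pred):  # Keras calls a loss as loss(y_true, y_pred)
    ...  # weighted categorical cross-entropy (body elided in the original post)
    return loss

model.compile(loss=weighted_loss, metrics=['accuracy'])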

Answer 1:


You can try the backend function K.in_train_phase(), which is used by the Dropout and BatchNormalization layers to implement different behaviors in training and validation.

def custom_loss(y_true, y_pred):
    weighted_loss = ... # your implementation of weighted crossentropy loss
    unweighted_loss = K.sparse_categorical_crossentropy(y_true, y_pred)
    return K.in_train_phase(weighted_loss, unweighted_loss)

The first argument of K.in_train_phase() is the tensor used in training phase, and the second is the one used in test phase.

For example, if we set weighted_loss to 0 (just to verify the effect of K.in_train_phase() function):

import numpy as np
from keras import backend as K
from keras.layers import Dense
from keras.models import Sequential

def custom_loss(y_true, y_pred):
    # Zero out the training-phase loss so the phase switch is visible in the log
    weighted_loss = 0 * K.sparse_categorical_crossentropy(y_true, y_pred)
    unweighted_loss = K.sparse_categorical_crossentropy(y_true, y_pred)
    return K.in_train_phase(weighted_loss, unweighted_loss)

model = Sequential([Dense(100, activation='relu', input_shape=(100,)), Dense(1000, activation='softmax')])
model.compile(optimizer='adam', loss=custom_loss)
model.outputs[0]._uses_learning_phase = True  # required if no dropout or batch norm in the model

X = np.random.rand(1000, 100)
y = np.random.randint(1000, size=1000)
model.fit(X, y, validation_split=0.1, epochs=10)  # epochs=10 matches the "Epoch 1/10" log

Epoch 1/10
900/900 [==============================] - 1s 868us/step - loss: 0.0000e+00 - val_loss: 6.9438

As you can see, the loss in training phase is indeed the one multiplied by 0.

Note that if there's no dropout or batch norm in your model, you'll need to manually "turn on" the _uses_learning_phase boolean switch, otherwise K.in_train_phase() will have no effect by default.
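As a rough illustration of what the switch does, here is a plain-Python sketch, independent of Keras, in which a boolean `training` flag stands in for the learning phase. The helper names, the one-hot label format, and the example class weights are assumptions made for illustration; they are not part of the original answer or of the Keras API.

```python
import math


def categorical_crossentropy(y_true, y_pred, class_weights=None, eps=1e-7):
    """Mean cross-entropy over a batch of one-hot labels.

    If class_weights is given, each class's term is scaled by its weight.
    """
    total = 0.0
    for true_row, pred_row in zip(y_true, y_pred):
        for c, (t, p) in enumerate(zip(true_row, pred_row)):
            w = 1.0 if class_weights is None else class_weights[c]
            total -= w * t * math.log(max(p, eps))  # clip to avoid log(0)
    return total / len(y_true)


def phase_dependent_loss(y_true, y_pred, class_weights, training):
    """Plain-Python analogue of K.in_train_phase(weighted, unweighted)."""
    weighted = categorical_crossentropy(y_true, y_pred, class_weights)
    unweighted = categorical_crossentropy(y_true, y_pred)
    return weighted if training else unweighted


# Two samples, two classes, uniform predictions (hypothetical data).
y_true = [[1, 0], [0, 1]]
y_pred = [[0.5, 0.5], [0.5, 0.5]]
class_weights = [1.0, 3.0]  # hypothetical: the second class counts 3x

train_loss = phase_dependent_loss(y_true, y_pred, class_weights, training=True)
val_loss = phase_dependent_loss(y_true, y_pred, class_weights, training=False)
print(train_loss, val_loss)
```

With uniform predictions over two classes, the unweighted loss is ln 2 for every sample, while a weight of 3 on the second class doubles the batch mean in the training branch.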




Answer 2:


The validation loss is just a metric and is not actually needed for training. It's there because it makes sense to compare the metrics your network is actually optimizing. So you can add any other loss function as a metric during compilation, and you'll see it during training.
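A minimal sketch of this suggestion, assuming TensorFlow's bundled Keras (`tf.keras`), one-hot labels, and made-up class weights; the weighted loss shown is only a stand-in for the asker's own function. The unweighted cross-entropy is passed as a metric by name, so Keras should report it for the validation data under the key `val_categorical_crossentropy` while training still minimizes the weighted loss.

```python
import numpy as np
import tensorflow as tf

# Hypothetical class weights: errors on class 1 count three times as much.
CLASS_WEIGHTS = tf.constant([1.0, 3.0])


def weighted_loss(y_true, y_pred):
    """Stand-in weighted categorical cross-entropy for one-hot labels."""
    y_true = tf.cast(y_true, y_pred.dtype)
    per_sample = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    # Pick each sample's weight from its one-hot label.
    sample_weights = tf.reduce_sum(y_true * CLASS_WEIGHTS, axis=-1)
    return per_sample * sample_weights


model = tf.keras.Sequential([tf.keras.layers.Dense(2, activation='softmax')])
model.compile(
    optimizer='adam',
    loss=weighted_loss,                   # optimized during training
    metrics=['categorical_crossentropy'], # unweighted, reported only
)

# Random toy data, only to show which quantities end up in the history.
rng = np.random.default_rng(0)
X = rng.random((40, 4)).astype('float32')
y = tf.keras.utils.to_categorical(rng.integers(0, 2, size=40), num_classes=2)

history = model.fit(X, y, validation_split=0.25, epochs=1, verbose=0)
print(sorted(history.history.keys()))
```

Here `val_loss` is still the weighted loss evaluated on the validation split; the unweighted figure to compare across a balanced validation set is the metric entry.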



Source: https://stackoverflow.com/questions/52107555/different-loss-function-for-validation-set-in-keras
