Keras: weighted binary crossentropy

佛祖请我去吃肉 2021-01-31 08:39

I tried to implement a weighted binary crossentropy in Keras, but I am not sure whether the code is correct. The training output is a bit confusing. After a few epochs I j
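When checking a Keras implementation, it can help to compare against a plain reference of the formula: each sample's binary crossentropy is multiplied by a weight chosen by its label, and the weighted losses are averaged. The following is a minimal NumPy sketch, not the asker's code. The function name and the weights `w0`/`w1` are hypothetical. To use the same idea as a Keras loss, the arithmetic would have to be rewritten with tensor ops (for example `tf.math.log` and `tf.reduce_mean`) inside a function taking `(y_true, y_pred)`.

```python
import numpy as np


def weighted_binary_crossentropy(y_true, y_pred, w0, w1, eps=1e-7):
    """Hypothetical reference: mean binary crossentropy where each sample's
    loss is scaled by w1 if its label is 1 and by w0 if its label is 0."""
    y_true = np.asarray(y_true, dtype=float)
    # Clip predictions away from 0 and 1 so log() stays finite
    y_pred = np.clip(np.asarray(y_pred, dtype=float), eps, 1.0 - eps)
    per_sample = -(w1 * y_true * np.log(y_pred)
                   + w0 * (1.0 - y_true) * np.log(1.0 - y_pred))
    # Average the weighted per-sample losses over the batch
    return per_sample.mean()


y_true = [1, 0, 0, 0]
y_pred = [0.9, 0.2, 0.1, 0.3]
# With w0 == w1 == 1 this reduces to ordinary binary crossentropy
plain = weighted_binary_crossentropy(y_true, y_pred, w0=1.0, w1=1.0)
# Up-weighting the rare positive class increases its share of the loss
weighted = weighted_binary_crossentropy(y_true, y_pred, w0=1.0, w1=3.0)
print(plain, weighted)
```

Setting both weights to 1 is a quick sanity check: the result should then match the unweighted `binary_crossentropy` on the same data.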

6 answers
  •  半阙折子戏
    2021-01-31 09:23

    You can use scikit-learn to compute a weight for each class automatically and pass those weights to fit() through its class_weight argument, keeping the built-in binary crossentropy loss:

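    # Imports
    import numpy as np
    from sklearn.utils import class_weight
    from tensorflow.keras import Input
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense
    
    # Example data (random, for illustration only): 1000 samples, 100 features,
    # imbalanced binary labels 0/1 with roughly 20% positives
    x_train = np.random.rand(1000, 100)
    y_train = (np.random.rand(1000) < 0.2).astype(int)
    
    # Example model
    model = Sequential()
    model.add(Input(shape=(100,)))
    model.add(Dense(32, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))
    
    # Use binary crossentropy loss
    model.compile(optimizer='rmsprop',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    
    # Calculate the weights for each class so that we can balance the data.
    # Newer scikit-learn versions require classes and y as keyword arguments.
    weights = class_weight.compute_class_weight(class_weight='balanced',
                                                classes=np.unique(y_train),
                                                y=y_train)
    
    # Keras expects class_weight as a dict {class_index: weight}, not an array
    # (enumerate works here because the labels are exactly 0 and 1)
    class_weights = dict(enumerate(weights))
    
    # Add the class weights to the training
    model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=class_weights)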
    

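    Note that class_weight.compute_class_weight() returns a NumPy array with one weight per class, for example [2.57569845 0.68250928]. Keras's fit() expects class_weight as a dictionary mapping each class index to its weight, which dict(enumerate(weights)) produces when the labels are 0 and 1.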
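    For the 'balanced' mode, the scikit-learn documentation gives the weights as n_samples / (n_classes * np.bincount(y)). The sketch below recomputes that heuristic with NumPy alone and builds the dict that fit() expects. It assumes integer labels 0..k-1 with every class present. The helper name balanced_class_weights is hypothetical and not part of scikit-learn.

```python
import numpy as np


def balanced_class_weights(y):
    """Hypothetical helper reproducing scikit-learn's 'balanced' heuristic:
    n_samples / (n_classes * count_of_each_class).
    Assumes integer labels 0..k-1, each appearing at least once."""
    y = np.asarray(y)
    counts = np.bincount(y)  # counts[c] = number of samples with label c
    return len(y) / (len(counts) * counts)


y_train = np.array([0, 0, 0, 1])
weights = balanced_class_weights(y_train)
# Keras expects a dict mapping class index -> weight, not a bare array
class_weight_dict = dict(enumerate(weights))
print(class_weight_dict)
```

    With three negatives and one positive, the rare class gets a weight three times larger than the common one, so both classes contribute equally to the total weighted loss.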
