How to compute Receiver Operating Characteristic (ROC) and AUC in Keras?

前端 未结 8 1822
一生所求
一生所求 2020-11-28 02:07

I have a multi-output (200 outputs) binary classification model that I wrote in Keras.

In this model I want to add additional metrics such as ROC and AUC, but to my knowledge …

8条回答
  •  旧时难觅i
    2020-11-28 02:49

    Because ROC and AUC cannot be computed mini-batch by mini-batch, you can only compute them at the end of each epoch. There is a solution from jamartinh; I have pasted the code below for convenience:

    from sklearn.metrics import roc_auc_score
    from keras.callbacks import Callback
    class RocCallback(Callback):
        """Print ROC-AUC on the training and validation sets after every epoch.

        ROC-AUC cannot be averaged over mini-batches, so this callback predicts
        on the full training and validation sets at the end of each epoch and
        scores them with sklearn's ``roc_auc_score``. For 2-D label-indicator
        targets (e.g. one column per binary output), sklearn averages the
        per-column AUCs (macro average by default).

        Args:
            training_data: ``(x, y)`` pair scored as the "train" AUC.
            validation_data: ``(x, y)`` pair scored as the "val" AUC.
        """

        def __init__(self, training_data, validation_data):
            # The base Callback initialises attributes (e.g. `model`) that
            # Keras relies on, so it must be initialised too.
            super().__init__()
            self.x = training_data[0]
            self.y = training_data[1]
            self.x_val = validation_data[0]
            self.y_val = validation_data[1]

        def on_epoch_end(self, epoch, logs=None):
            # `predict` returns the model's output activations, which are the
            # scores roc_auc_score needs. `predict_proba` only ever wrapped
            # `predict` on Sequential models and is gone from later Keras.
            y_pred_train = self.model.predict(self.x)
            roc_train = roc_auc_score(self.y, y_pred_train)
            y_pred_val = self.model.predict(self.x_val)
            roc_val = roc_auc_score(self.y_val, y_pred_val)
            # Leading '\r' and trailing padding presumably overwrite the
            # progress-bar line Keras prints during the epoch.
            print('\rroc-auc_train: %s - roc-auc_val: %s' % (str(round(roc_train,4)),str(round(roc_val,4))),end=100*' '+'\n')
    
    # Score the training split and the held-out split after every epoch;
    # X_train, y_train, X_test, y_test and `model` must already exist.
    roc = RocCallback(
        training_data=(X_train, y_train),
        validation_data=(X_test, y_test),
    )

    model.fit(
        X_train,
        y_train,
        validation_data=(X_test, y_test),
        callbacks=[roc],
    )
    

    A more hackable way is to wrap tf.contrib.metrics.streaming_auc as a Keras metric (TensorFlow 1.x only, since tf.contrib was removed in TensorFlow 2):

    import numpy as np
    import tensorflow as tf
    from sklearn.metrics import roc_auc_score
    from sklearn.datasets import make_classification
    from keras.models import Sequential
    from keras.layers import Dense
    from keras.utils import np_utils
    from keras.callbacks import Callback, EarlyStopping
    
    
    # Define the auc_roc metric, inspired by https://github.com/keras-team/keras/issues/6050#issuecomment-329996505
    def auc_roc(y_true, y_pred):
        """Keras metric that reports TensorFlow's streaming ROC-AUC.

        Requires TensorFlow 1.x graph mode: ``tf.contrib``,
        ``tf.local_variables`` and ``tf.GraphKeys`` do not exist in TF 2.

        Args:
            y_true: Ground-truth labels tensor.
            y_pred: Predicted scores tensor (streaming_auc expects values in
                [0, 1]).

        Returns:
            A tensor holding the current AUC. Reading it also runs the update
            op, so the value accumulates over every batch evaluated since the
            metric's local variables were last initialised.
        """
        # streaming_auc takes (predictions, labels) in that order.
        value, update_op = tf.contrib.metrics.streaming_auc(y_pred, y_true)

        # Collect the local accumulator variables created for this metric: the
        # second name-scope component is expected to contain 'auc_roc'.
        # Unscoped variables (no '/') are skipped instead of raising IndexError.
        metric_vars = []
        for v in tf.local_variables():
            scope_parts = v.name.split('/')
            if len(scope_parts) > 1 and 'auc_roc' in scope_parts[1]:
                metric_vars.append(v)

        # Add the metric variables to GLOBAL_VARIABLES so they are initialised
        # for a new session. Skip ones already registered (identity check) so
        # repeated calls do not add duplicates.
        registered = {id(g) for g in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)}
        for v in metric_vars:
            if id(v) not in registered:
                tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, v)

        # Force the update op to run whenever the metric value is read.
        with tf.control_dependencies([update_op]):
            value = tf.identity(value)
        return value
    
    # Generate a small synthetic binary-classification dataset.
    N_all = 10000
    N_tr = int(0.7 * N_all)
    N_te = N_all - N_tr
    X, y = make_classification(n_samples=N_all, n_features=20, n_classes=2)
    # One-hot encode the labels to match the 2-unit softmax output below.
    y = np_utils.to_categorical(y, num_classes=2)
    
    # First 70% of samples for training, last 30% for validation.
    X_train, X_valid = X[:N_tr, :], X[N_tr:, :]
    y_train, y_valid = y[:N_tr, :], y[N_tr:, :]
    
    # model & train
    model = Sequential()
    model.add(Dense(2, activation="softmax", input_shape=(X.shape[1],)))
    
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy', auc_roc])
    
    # Monitors the 'auc_roc' metric as logged for the training data (the
    # validation value is logged as 'val_auc_roc').
    # NOTE(review): with patience=300 and only 5 epochs this can never stop
    # training early; lower patience or raise epochs if early stopping is wanted.
    my_callbacks = [EarlyStopping(monitor='auc_roc', patience=300, verbose=1, mode='max')]
    
    # Use the explicit split built above; this is the same last-30% hold-out
    # that validation_split=0.3 would take. `epochs` is the Keras 2 name for
    # the deprecated `nb_epoch` argument.
    model.fit(X_train, y_train,
              validation_data=(X_valid, y_valid),
              shuffle=True,
              batch_size=32, epochs=5, verbose=1,
              callbacks=my_callbacks)
    

提交回复
热议问题