How to output per-class accuracy in Keras?

有刺的猬 2020-12-25 13:12

Caffe can print not only the overall accuracy but also the per-class accuracy.

In the Keras log, there's only the overall accuracy, so it's hard for me to calculate the separate class accuracies.

3 Answers
  •  鱼传尺愫
    2020-12-25 14:10

    What you are probably looking for is a callback, which you can easily pass to model.fit() through its callbacks argument.

    For example, you can define your own class by subclassing keras.callbacks.Callback. I recommend overriding on_epoch_end(), since Keras calls it once at the end of every epoch, so its output appears right after that epoch's line in the training log when you train with verbose=1. Please note that this particular code block is set up for 3 classes, but you can change the classes list to match your own number of classes. It also assumes that the labels in the test data are one-hot encoded, because the true class is recovered with np.argmax.

    import numpy as np
    import tensorflow as tf
    
    # your class labels
    classes = ["class_1", "class_2", "class_3"]
    
    class AccuracyCallback(tf.keras.callbacks.Callback):
    
        def __init__(self, test_data):
            super().__init__()
            # test_data is an (x, y) tuple; y is assumed to be one-hot encoded
            self.test_data = test_data
    
        def on_epoch_end(self, epoch, logs=None):
            x_data, y_data = self.test_data
    
            correct = 0
            incorrect = 0
    
            # predict class probabilities for all samples in one call
            x_result = self.model.predict(x_data, verbose=0)
    
            class_correct = [0] * len(classes)
            class_incorrect = [0] * len(classes)
    
            for i in range(len(x_data)):
                y = y_data[i]
                res = x_result[i]
    
                # index of the true class (one-hot label) and of the predicted class
                actual_label = np.argmax(y)
                pred_label = np.argmax(res)
    
                if pred_label == actual_label:
                    class_correct[actual_label] += 1
                    correct += 1
                else:
                    class_incorrect[actual_label] += 1
                    incorrect += 1
    
            print("\tCorrect: %d" % correct)
            print("\tIncorrect: %d" % incorrect)
    
            for i in range(len(classes)):
                tot = class_correct[i] + class_incorrect[i]
                # -1 marks a class that has no samples in test_data
                class_acc = -1
                if tot > 0:
                    class_acc = class_correct[i] / tot
    
                print("\t%s: %.3f" % (classes[i], class_acc))
    
            acc = correct / (correct + incorrect)
    
            print("\tCurrent Network Accuracy: %.3f" % acc)
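    Because the callback goes over the whole validation set every epoch, the per-sample Python loop can get slow on large datasets. Below is a minimal vectorized sketch with NumPy, assuming one-hot labels as above; per_class_accuracy is a hypothetical helper (not part of Keras) and it keeps the same -1 convention for classes that have no samples.

```python
import numpy as np


def per_class_accuracy(y_true_onehot, y_pred_probs, num_classes):
    """Hypothetical helper: fraction of correctly predicted samples per true class.

    Returns -1.0 for classes with no samples, matching the callback's convention.
    """
    actual = np.argmax(y_true_onehot, axis=1)  # true class index per sample
    pred = np.argmax(y_pred_probs, axis=1)     # predicted class index per sample

    # samples per true class, and how many of those were predicted correctly
    totals = np.bincount(actual, minlength=num_classes)
    correct = np.bincount(actual[pred == actual], minlength=num_classes)

    acc = np.full(num_classes, -1.0)
    has_samples = totals > 0
    acc[has_samples] = correct[has_samples] / totals[has_samples]
    return acc
```

    Inside on_epoch_end it could replace the counting loop, e.g. class_acc = per_class_accuracy(y_data, x_result, len(classes)). Note that this quantity (correct predictions of a class divided by the number of samples of that class) is the same thing as per-class recall.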
    

    Then pass an instance of the new callback to model.fit(). Assuming your validation data val_data is an (x, y) tuple, you can use the following:

    accuracy_callback = AccuracyCallback(val_data)
    
    # you can use the history if desired
    history = model.fit(x=_, y=_, verbose=1,
                        epochs=_, shuffle=_, validation_data=val_data,
                        callbacks=[accuracy_callback], batch_size=_)
    

    Please note that each _ stands for a value that depends on your own configuration.
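    AccuracyCallback only prints its numbers, so they never reach the returned history. A hedged sketch of a variant, assuming the same one-hot (x, y) validation tuple: PerClassAccuracyLogger and the val_acc_<name> key format are hypothetical names chosen for illustration. It writes each class's accuracy into the logs dict passed to on_epoch_end; the built-in History callback records the keys of that dict, and in the Keras versions I know of it runs after user callbacks, so the values should then show up in history.history.

```python
import numpy as np
import tensorflow as tf


class PerClassAccuracyLogger(tf.keras.callbacks.Callback):
    """Hypothetical variant: stores per-class accuracy in `logs` instead of printing it."""

    def __init__(self, val_data, class_names):
        super().__init__()
        # val_data is an (x, y) tuple with one-hot encoded y
        self.x_val, self.y_val = val_data
        self.class_names = class_names

    def on_epoch_end(self, epoch, logs=None):
        if logs is None:
            return  # no dict to write into
        y_true = np.argmax(self.y_val, axis=1)
        y_pred = np.argmax(self.model.predict(self.x_val, verbose=0), axis=1)
        for i, name in enumerate(self.class_names):
            mask = y_true == i
            # NaN marks a class with no samples in val_data
            acc = float(np.mean(y_pred[mask] == i)) if mask.any() else float("nan")
            # keys added here are meant to be picked up by the History callback
            logs["val_acc_" + name] = acc
```

    Usage would mirror the original: pass PerClassAccuracyLogger(val_data, classes) in the callbacks list, then read, for example, history.history["val_acc_class_1"] after training.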
