How to plot scikit learn classification report?

广开言路 2020-12-04 18:20

Is it possible to plot a scikit-learn classification report with matplotlib? Let's assume I print the classification report like this:

print '\n*Classifica


        
10 answers
  •  清歌不尽
    2020-12-04 18:56

    I just wrote a function, plot_classification_report(), for this purpose; I hope it helps. It takes the output of the classification_report function as its argument and plots the scores. Here is the function:

    import matplotlib.pyplot as plt
    import numpy as np
    
    
    def plot_classification_report(cr, title='Classification report', with_avg_total=False, cmap=plt.cm.Blues):
        """Plot precision, recall and f1-score parsed from classification_report() text."""
        lines = cr.split('\n')
    
        classes = []
        plotMat = []
        for line in lines:
            t = line.split()
            # A score row is "<label> <precision> <recall> <f1-score> <support>".
            # Header, blank lines and the short "accuracy" row have fewer tokens.
            if len(t) < 5:
                continue
            label = t[:-4]  # labels may span several tokens, e.g. "avg / total"
            # Average rows: "avg / total" (older scikit-learn), "macro avg"/"weighted avg" (newer)
            if 'avg' in label and not with_avg_total:
                continue
            classes.append(' '.join(label))
            # Keep precision, recall and f1-score; drop the support column
            plotMat.append([float(x) for x in t[-4:-1]])
    
        plt.imshow(plotMat, interpolation='nearest', cmap=cmap)
        plt.title(title)
        plt.colorbar()
        x_tick_marks = np.arange(3)
        y_tick_marks = np.arange(len(classes))
        plt.xticks(x_tick_marks, ['precision', 'recall', 'f1-score'], rotation=45)
        plt.yticks(y_tick_marks, classes)
        plt.ylabel('Classes')
        plt.xlabel('Measures')
        plt.tight_layout()
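The fragile part of this approach is parsing the text report. The sketch below isolates that step in a standalone, dependency-free parser so it can be checked without plotting. `parse_report_text` is a hypothetical helper name, not a scikit-learn function. It assumes whitespace-separated columns in which every score row ends with precision, recall, f1-score and support. It treats any label containing the word "avg" as an average row and skips shorter lines such as the header, blank lines and the "accuracy" row that newer scikit-learn versions print.

```python
def parse_report_text(cr, with_avg=False):
    """Split classification_report() text into row labels and score rows.

    Assumption: every score row ends with four numbers (precision, recall,
    f1-score, support) and everything before them is the row label.
    """
    labels, rows = [], []
    for line in cr.splitlines():
        tokens = line.split()
        if len(tokens) < 5:
            # Header, blank lines and the short "accuracy" row have fewer tokens.
            continue
        label_tokens = tokens[:-4]
        if 'avg' in label_tokens and not with_avg:
            continue  # "avg / total", "macro avg", "weighted avg", ...
        labels.append(' '.join(label_tokens))
        rows.append([float(x) for x in tokens[-4:-1]])  # drop the support column
    return labels, rows


sample = """             precision    recall  f1-score   support

          1       0.62      1.00      0.76        66
          2       0.93      0.93      0.93        40
          3       0.59      0.97      0.73        67
          4       0.47      0.92      0.62       272
          5       1.00      0.16      0.28       413

avg / total       0.77      0.57      0.49       858"""

labels, rows = parse_report_text(sample, with_avg=True)
print(labels)    # ['1', '2', '3', '4', '5', 'avg / total']
print(rows[-1])  # [0.77, 0.57, 0.49]
```

Joining the leading tokens back together is what lets multi-word labels such as "avg / total" or "weighted avg" survive. Class names that themselves contain the separate word "avg" would be misread as average rows, so treat this as a sketch rather than a general parser.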
    

    Here is the code for the example classification report from the question:

    sampleClassificationReport = """             precision    recall  f1-score   support
    
              1       0.62      1.00      0.76        66
              2       0.93      0.93      0.93        40
              3       0.59      0.97      0.73        67
              4       0.47      0.92      0.62       272
              5       1.00      0.16      0.28       413
    
    avg / total       0.77      0.57      0.49       858"""
    
    
    plot_classification_report(sampleClassificationReport)
    plt.show()
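With colour alone it is hard to read exact scores off the plot. As an optional variation, the sketch below draws the five class rows of the sample report with each value written into its cell using `Axes.text`. This variation is not part of the original answer. `annotated_heatmap` is a hypothetical helper name, and the scores are the sample values above, typed in by hand. The Agg backend is selected only so that the code also runs without a display.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend; remove this line to use plt.show()
import matplotlib.pyplot as plt
import numpy as np


def annotated_heatmap(scores, labels, title='Classification report', cmap='Blues'):
    """Heatmap of precision/recall/f1-score rows with each value printed in its cell."""
    scores = np.asarray(scores, dtype=float)
    fig, ax = plt.subplots()
    # Fix the colour scale to [0, 1] so plots of different reports are comparable.
    im = ax.imshow(scores, interpolation='nearest', cmap=cmap, vmin=0.0, vmax=1.0)
    fig.colorbar(im, ax=ax)
    ax.set_xticks(np.arange(scores.shape[1]))
    ax.set_xticklabels(['precision', 'recall', 'f1-score'], rotation=45)
    ax.set_yticks(np.arange(len(labels)))
    ax.set_yticklabels(labels)
    for i in range(scores.shape[0]):
        for j in range(scores.shape[1]):
            # White text on dark (high-score) cells, black text on light ones.
            colour = 'white' if scores[i, j] > 0.5 else 'black'
            ax.text(j, i, f'{scores[i, j]:.2f}', ha='center', va='center', color=colour)
    ax.set_title(title)
    ax.set_xlabel('Measures')
    ax.set_ylabel('Classes')
    fig.tight_layout()
    return fig, ax


# Class rows of the sample report (support column omitted).
sample_labels = ['1', '2', '3', '4', '5']
sample_scores = [
    [0.62, 1.00, 0.76],
    [0.93, 0.93, 0.93],
    [0.59, 0.97, 0.73],
    [0.47, 0.92, 0.62],
    [1.00, 0.16, 0.28],
]
fig, ax = annotated_heatmap(sample_scores, sample_labels)
# fig.savefig('classification_report.png'), or plt.show() with an interactive backend
```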
    

    Here is how to use it with sklearn classification_report output:

    from sklearn.metrics import classification_report
    
    # y_true, y_pred and target_names come from your own data and model
    classificationReport = classification_report(y_true, y_pred, target_names=target_names)
    
    plot_classification_report(classificationReport)
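Since scikit-learn 0.20, classification_report also accepts `output_dict=True` and returns the scores as a nested dictionary, which avoids parsing text altogether. Below is a minimal sketch under that assumption. `report_matrix` and `plot_report_matrix` are hypothetical helper names. The toy `y_true`, `y_pred` and `target_names` stand in for real data, and `include_avg` plays the same role as `with_avg_total`. The exact set of average rows varies with the scikit-learn version and with multilabel input, which is why they are filtered by name.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend so the sketch runs anywhere; remove for plt.show()
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import classification_report

METRICS = ['precision', 'recall', 'f1-score']
AVG_ROWS = {'micro avg', 'macro avg', 'weighted avg', 'samples avg'}


def report_matrix(y_true, y_pred, target_names=None, include_avg=False):
    """Return (row labels, scores array) built from classification_report(output_dict=True)."""
    report = classification_report(y_true, y_pred, target_names=target_names,
                                   output_dict=True)
    labels, scores = [], []
    for label, metrics in report.items():
        if not isinstance(metrics, dict):
            continue  # 'accuracy' is stored as a bare float, not a per-metric dict
        if label in AVG_ROWS and not include_avg:
            continue
        labels.append(label)
        scores.append([metrics[m] for m in METRICS])  # support is left out
    return labels, np.array(scores)


def plot_report_matrix(labels, scores, title='Classification report', cmap='Blues'):
    """Draw the score matrix as a heatmap, one row per class (or average)."""
    fig, ax = plt.subplots()
    im = ax.imshow(scores, interpolation='nearest', cmap=cmap, vmin=0.0, vmax=1.0)
    fig.colorbar(im, ax=ax)
    ax.set_xticks(np.arange(len(METRICS)))
    ax.set_xticklabels(METRICS, rotation=45)
    ax.set_yticks(np.arange(len(labels)))
    ax.set_yticklabels(labels)
    ax.set_title(title)
    ax.set_xlabel('Measures')
    ax.set_ylabel('Classes')
    fig.tight_layout()
    return fig, ax


# Toy data standing in for real predictions.
y_true = [0, 1, 2, 2, 2]
y_pred = [0, 0, 2, 2, 1]
target_names = ['class 0', 'class 1', 'class 2']

labels, scores = report_matrix(y_true, y_pred, target_names, include_avg=True)
fig, ax = plot_report_matrix(labels, scores)
```

Keeping the data extraction separate from the plotting makes the score matrix easy to inspect or test on its own, and the plotting helper works the same way for text-parsed or dictionary-based input.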
    

    The function can also add the "avg / total" row to the plot. To include it, pass with_avg_total=True:

    plot_classification_report(classificationReport, with_avg_total=True)
    
