How to plot a scikit-learn classification report?

广开言路 2020-12-04 18:20

Is it possible to plot a scikit-learn classification report with matplotlib? Let's assume I print the classification report like this:

print '\n*Classifica

10 answers
  •  时光取名叫无心
    2020-12-04 18:42

    This works for me. I pieced it together from the top answer above. I can't comment yet, but thanks to everyone for this thread, it helped a lot!

    import matplotlib.pyplot as plt
    import seaborn as sns

    def plot_classification_report(cr, title='Classification report', with_avg_total=False, cmap=plt.cm.Blues):
        lines = cr.split('\n')
        classes = []
        plotMat = []
        # Per-class rows start after the header and the blank line below it,
        # and end at the next blank line (before the summary rows).
        for line in lines[2:]:
            t = line.split()
            if not t:
                break
            classes.append(t[0])
            # Keep precision, recall and f1-score; drop the trailing support count.
            v = [float(x) for x in t[1: len(t) - 1]]
            plotMat.append(v)

        if with_avg_total:
            # The last non-empty line holds the average row (e.g. 'avg / total' or 'weighted avg').
            aveTotal = [ln for ln in lines if ln.strip()][-1].split()
            classes.append('avg/total')
            # The row label may contain spaces, so count the columns from the right.
            vAveTotal = [float(x) for x in aveTotal[-4:-1]]
            plotMat.append(vAveTotal)

        plt.figure(figsize=(12, 48))  # tall figure, sized for many classes (about 200)
        # plt.imshow(plotMat, interpolation='nearest', cmap=cmap) also works, but the scale and the colors are poor with many classes (200)
        # plt.colorbar()

        # Pass the labels to seaborn so the ticks sit at the cell centers.
        sns.heatmap(plotMat, annot=True, cmap=cmap,
                    xticklabels=['precision', 'recall', 'f1-score'], yticklabels=classes)
        plt.title(title)
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.ylabel('Classes')
        plt.xlabel('Measures')
        plt.tight_layout()
    
    After this, make sure the class labels don't contain any spaces, because each report line is parsed by splitting on whitespace:
    reportstr = classification_report(true_classes, y_pred, target_names=class_labels_no_spaces)
    
    plot_classification_report(reportstr)
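
If renaming the labels is inconvenient, one possible workaround is to split each report line from the right, so that spaces inside a label survive. The sketch below is not part of the answer above: `parse_report_rows` is a hypothetical helper, the sample report text is typed by hand for illustration, and it assumes every per-class line ends with the four columns precision, recall, f1-score and support.

```python
def parse_report_rows(report_text):
    """Return (labels, rows) for the per-class lines of a text report.

    Assumes each data line ends with four numeric columns
    (precision, recall, f1-score, support).
    """
    labels, rows = [], []
    # Drop the header line; the class rows follow a blank line.
    body = report_text.strip('\n').split('\n')[1:]
    started = False
    for line in body:
        if not line.strip():
            if started:
                break  # blank line after the class rows: summary rows follow
            continue
        started = True
        # Split from the right so spaces inside the label are preserved.
        label, precision, recall, f1, _support = line.rsplit(None, 4)
        labels.append(label.strip())
        rows.append([float(precision), float(recall), float(f1)])
    return labels, rows


# Hypothetical report text, written by hand for illustration only.
sample = (
    "               precision    recall  f1-score   support\n"
    "\n"
    "     big cat       0.50      1.00      0.67         1\n"
    "   small dog       1.00      0.67      0.80         3\n"
    "\n"
    "    accuracy                           0.75         4\n"
)
labels, rows = parse_report_rows(sample)
print(labels)
print(rows)
```

The returned `labels` and `rows` could then be handed to a heatmap call in place of `classes` and `plotMat`.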
    
