How to plot scikit-learn classification report?

广开言路 2020-12-04 18:20

Is it possible to plot a scikit-learn classification report with matplotlib? Let's assume I print the classification report like this:

    print '\n*Classifica
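Plotting the report does not require parsing the printed text: the same numbers can be obtained as a nested dict. The sketch below assumes scikit-learn 0.20 or later (where classification_report accepts output_dict=True); the labels are made-up example data.

```python
import numpy as np
from sklearn.metrics import classification_report

# Made-up example labels; replace with your own y_true / y_pred.
y_true = np.array([0, 1, 2, 2, 1, 0, 2, 1])
y_pred = np.array([0, 2, 2, 2, 1, 0, 1, 1])

# output_dict=True returns the numbers of the printed report as a dict keyed
# by class label (converted to str) plus 'accuracy', 'macro avg' and 'weighted avg'.
report = classification_report(y_true, y_pred, output_dict=True)

for label, metrics in report.items():
    if label == "accuracy":
        # 'accuracy' maps to a single float rather than a dict of metrics.
        print(f"{label}: {metrics:.2f}")
    else:
        print(label, {name: round(value, 2) for name, value in metrics.items()})
```

Each per-class entry holds precision, recall, f1-score and support, which is exactly the matrix a heatmap needs.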


        
10 answers
  •  粉色の甜心
    2020-12-04 18:46

    This is my simple solution, using a seaborn heatmap:

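    import seaborn as sns
    import numpy as np
    from sklearn.metrics import precision_recall_fscore_support
    import matplotlib.pyplot as plt
    
    y = np.random.randint(low=0, high=10, size=100)
    y_p = np.random.randint(low=0, high=10, size=100)
    
    def plot_classification_report(y_tru, y_prd, figsize=(10, 10), ax=None):
        # Create a new figure only when no axes were passed in
        if ax is None:
            _, ax = plt.subplots(figsize=figsize)
    
        # Labels present in either array, sorted; passing them explicitly keeps
        # the heatmap rows aligned with the tick labels
        labels = np.unique(np.concatenate([y_tru, y_prd]))
    
        xticks = ['precision', 'recall', 'f1-score', 'support']
        yticks = list(labels) + ['avg']
    
        # One row per class: precision, recall, f1-score, support
        rep = np.array(precision_recall_fscore_support(y_tru, y_prd, labels=labels)).T
        # Unweighted mean of the scores, with the total support in the last column
        avg = np.mean(rep, axis=0)
        avg[-1] = np.sum(rep[:, -1])
        rep = np.insert(rep, rep.shape[0], avg, axis=0)
    
        sns.heatmap(rep,
                    annot=True,
                    fmt='.2f',  # default '.2g' would print a support of 100 as 1e+02
                    cbar=False,
                    xticklabels=xticks,
                    yticklabels=yticks,
                    ax=ax)
        return ax
    
    plot_classification_report(y, y_p)
    plt.show()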
    

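    This is how the plot will look: a heatmap with one row per class plus a final 'avg' row, and one column each for precision, recall, f1-score and support.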
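Since the question asks about matplotlib specifically, here is a minimal sketch of the same idea without seaborn, assuming scikit-learn 0.20 or later for output_dict=True. The helper name plot_report_matplotlib is hypothetical, and the support column is left out because its counts are on a different scale from the 0 to 1 scores.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report


def plot_report_matplotlib(y_true, y_pred, ax=None):
    """Draw the classification report as an annotated heatmap with plain matplotlib."""
    report = classification_report(y_true, y_pred, output_dict=True)
    # 'accuracy' is a single float; every other entry is a dict of metrics.
    rows = [name for name in report if name != "accuracy"]
    cols = ["precision", "recall", "f1-score"]  # support omitted: different scale
    data = np.array([[report[r][c] for c in cols] for r in rows])

    if ax is None:
        _, ax = plt.subplots(figsize=(6, 0.5 * len(rows) + 1))
    im = ax.imshow(data, cmap="viridis", vmin=0.0, vmax=1.0, aspect="auto")
    ax.set_xticks(np.arange(len(cols)))
    ax.set_xticklabels(cols)
    ax.set_yticks(np.arange(len(rows)))
    ax.set_yticklabels(rows)

    # Write each value into its cell, switching text colour for contrast.
    for i in range(len(rows)):
        for j in range(len(cols)):
            color = "w" if data[i, j] < 0.5 else "k"
            ax.text(j, i, f"{data[i, j]:.2f}", ha="center", va="center", color=color)

    ax.figure.colorbar(im, ax=ax)
    return ax


# Random example data, as in the answer above (seeded for reproducibility).
rng = np.random.default_rng(0)
y = rng.integers(0, 10, size=100)
y_p = rng.integers(0, 10, size=100)
ax = plot_report_matplotlib(y, y_p)
plt.show()
```

Using output_dict=True also gives the 'macro avg' and 'weighted avg' rows for free, instead of recomputing an average by hand.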
