How to plot scikit learn classification report?

广开言路 2020-12-04 18:20

Is it possible to plot a scikit-learn classification report with matplotlib? Let's assume I print the classification report like this:

print '\n*Classifica


10 answers
  •  日久生厌
    2020-12-04 18:38

    If you just want to plot the classification report as a bar chart in a Jupyter notebook, you can do the following.

    # Assuming that y_test and predictions are in scope...
    import pandas as pd
    from sklearn.metrics import classification_report
    
    # Build a DataFrame from the classification_report output_dict.
    report_data = []
    for label, metrics in classification_report(y_test, predictions, output_dict=True).items():
        # Skip scalar entries such as 'accuracy', which maps to a float, not a dict.
        if not isinstance(metrics, dict):
            continue
        metrics['label'] = label
        report_data.append(metrics)
    
    report_df = pd.DataFrame(
        report_data, 
        columns=['label', 'precision', 'recall', 'f1-score', 'support']
    )
    
    # Plot as a bar chart.
    report_df.plot(y=['precision', 'recall', 'f1-score'], x='label', kind='bar')
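    For reference, here is a minimal self-contained sketch of the same approach that can also run as a plain script outside a notebook. The toy `y_test`/`predictions` labels are made-up illustrative data, `report_to_dataframe` is a hypothetical helper name, and the non-interactive `Agg` backend and the output file `classification_report.png` are assumptions for script use. It skips the scalar `accuracy` entry that recent scikit-learn versions include in the `output_dict` result.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs as a plain script
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import classification_report

# Hypothetical toy data, for illustration only.
y_test = ["cat", "dog", "dog", "bird", "cat", "dog", "bird", "dog"]
predictions = ["cat", "dog", "cat", "bird", "cat", "dog", "dog", "dog"]


def report_to_dataframe(y_true, y_pred):
    """Hypothetical helper: one row per label from classification_report(output_dict=True)."""
    rows = []
    for label, metrics in classification_report(y_true, y_pred, output_dict=True).items():
        # 'accuracy' maps to a single float rather than a dict of metrics, so skip it.
        if not isinstance(metrics, dict):
            continue
        rows.append({"label": label, **metrics})
    return pd.DataFrame(rows, columns=["label", "precision", "recall", "f1-score", "support"])


report_df = report_to_dataframe(y_test, predictions)
ax = report_df.plot(y=["precision", "recall", "f1-score"], x="label", kind="bar")
ax.set_ylim(0, 1.05)  # precision, recall and f1-score all lie in [0, 1]
plt.tight_layout()
plt.savefig("classification_report.png")  # assumed output filename
print(report_df)
```

    Fixing the y-axis to [0, 1.05] keeps charts from different models visually comparable, since pandas would otherwise autoscale each plot.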
    

    One issue with this visualisation is that class imbalance is not visible, even though it matters when interpreting the results. One way to show it is to add a version of the label that includes the number of samples (i.e. the support):

    # Add a column to the DataFrame.
    report_df['labelsupport'] = [f'{label} (n={support})' 
                                 for label, support in zip(report_df.label, report_df.support)]
    
    # Plot the chart the same way, but use `labelsupport` as the x-axis.
    report_df.plot(y=['precision', 'recall', 'f1-score'], x='labelsupport', kind='bar')
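    The same `labelsupport` column can also be built with vectorised pandas string operations instead of a list comprehension. This is a minimal sketch: the small hand-written `report_df` below is a hypothetical stand-in for the real one, and its numbers are made up for illustration.

```python
import pandas as pd

# Hypothetical stand-in for report_df; the values are illustrative only.
report_df = pd.DataFrame({
    "label": ["bird", "cat", "dog"],
    "precision": [1.0, 0.67, 0.75],
    "recall": [0.5, 1.0, 0.75],
    "f1-score": [0.67, 0.8, 0.75],
    "support": [2, 2, 4],
})

# Cast support to int so the label reads "n=2" even if the column is stored as float.
report_df["labelsupport"] = (
    report_df["label"] + " (n=" + report_df["support"].astype(int).astype(str) + ")"
)
print(report_df["labelsupport"].tolist())  # ['bird (n=2)', 'cat (n=2)', 'dog (n=4)']
```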
    
