How to plot a scikit-learn classification report?

广开言路 2020-12-04 18:20

Is it possible to plot a scikit-learn classification report with matplotlib? Let's assume I print the classification report like this:

print '\n*Classifica


        
10 answers
  •  孤街浪徒
    2020-12-04 18:48

    No string processing + sns.heatmap

    The following solution passes output_dict=True to classification_report to get the report as a dictionary instead of a string, builds a pandas DataFrame from that dictionary, and draws the DataFrame as a heat map with seaborn.


    import numpy as np
    import seaborn as sns
    from sklearn.metrics import classification_report
    import pandas as pd
    

    Generating data. Classes: A,B,C,D,E,F,G,H,I

    # Nine classes (0-8) so that the labels match the nine names A-I
    true = np.random.randint(0, 9, size=100)
    pred = np.random.randint(0, 9, size=100)
    labels = np.arange(9)
    target_names = list("ABCDEFGHI")
    

    Call classification_report with output_dict=True

    clf_report = classification_report(true,
                                       pred,
                                       labels=labels,
                                       target_names=target_names,
                                       output_dict=True)
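
For orientation, here is a small sketch of what the dictionary returned with output_dict=True looks like. The toy labels ("cat", "dog") and the variable names are made up for illustration and are not part of the answer above.

```python
from sklearn.metrics import classification_report

# Hypothetical toy data: one "cat" sample is misclassified as "dog"
y_true = ["cat", "cat", "dog", "dog"]
y_pred = ["cat", "dog", "dog", "dog"]

report = classification_report(y_true, y_pred, output_dict=True)

# One entry per class, plus the summary entries
print(sorted(report.keys()))

# Each class entry is itself a dict with precision, recall, f1-score, support
print(report["cat"])

# "accuracy" is a plain float rather than a nested dict
print(report["accuracy"])
```

Because every class maps to a dict with the same four keys, the result converts directly to a DataFrame, which is what the next step relies on.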
    

    Create a dataframe from the dictionary and plot a heatmap of it.

    # .iloc[:-1, :] to exclude support
    sns.heatmap(pd.DataFrame(clf_report).iloc[:-1, :].T, annot=True)
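
Since the question asks about matplotlib specifically, the same idea can probably be done without seaborn. The following is a minimal sketch under stated assumptions, not the answer author's method: it uses three made-up classes, keeps only the per-class rows, and writes to a hypothetical file name classification_report.png.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to use plt.show() interactively
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import classification_report

# Hypothetical random data with three classes named A, B, C
rng = np.random.default_rng(0)
y_true = rng.integers(0, 3, size=60)
y_pred = rng.integers(0, 3, size=60)
names = ["A", "B", "C"]

report = classification_report(y_true, y_pred, labels=[0, 1, 2],
                               target_names=names, output_dict=True,
                               zero_division=0)

# Keep only the per-class entries and the three score columns (no support)
score_cols = ["precision", "recall", "f1-score"]
scores = pd.DataFrame({n: report[n] for n in names}).T[score_cols].astype(float)

fig, ax = plt.subplots()
im = ax.imshow(scores.to_numpy(), vmin=0.0, vmax=1.0, cmap="viridis")

# Label the axes with metric names (columns) and class names (rows)
ax.set_xticks(range(len(score_cols)))
ax.set_xticklabels(score_cols)
ax.set_yticks(range(len(names)))
ax.set_yticklabels(names)

# Annotate each cell with its value, like annot=True in sns.heatmap
for i in range(scores.shape[0]):
    for j in range(scores.shape[1]):
        ax.text(j, i, f"{scores.iat[i, j]:.2f}",
                ha="center", va="center", color="white")

fig.colorbar(im, ax=ax)
fig.savefig("classification_report.png")
```

Fixing vmin and vmax to the 0-1 range keeps the color scale comparable between reports, which seemed a reasonable default for scores that are bounded that way.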
    
