How to plot a confusion matrix with string axis labels rather than integers in Python

星月不相逢 2020-12-02 08:18

I am following a previous thread on how to plot a confusion matrix in Matplotlib. The script is as follows:

from numpy import *
import matplotlib.pyplot as plt         


        
4 Answers
  •  忘掉有多难
    2020-12-02 09:05

    Here is what you want:

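    from string import ascii_uppercase

    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sn
    from pandas import DataFrame
    from sklearn.metrics import confusion_matrix

    # True and predicted labels, encoded as integer class IDs 1..5
    y_test = np.array([1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5])
    predic = np.array([1,2,4,3,5, 1,2,4,3,5, 1,2,3,4,4])

    # One string name per class: 'class A', 'class B', ...
    columns = ['class %s' % i for i in ascii_uppercase[:len(np.unique(y_test))]]

    # Rows = true class, columns = predicted class
    confm = confusion_matrix(y_test, predic)
    # seaborn uses the DataFrame's index and columns as the tick labels
    df_cm = DataFrame(confm, index=columns, columns=columns)

    ax = sn.heatmap(df_cm, cmap='Oranges', annot=True)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    plt.show()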
    

    [Image: example heatmap output of the code above]
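If you would rather stay with plain Matplotlib (as in the question) and skip seaborn and pandas, the same string axis can be produced by placing one tick per class and replacing the integer tick text. The sketch below assumes only NumPy and Matplotlib; `confusion_counts` is a hypothetical helper written for this example (sklearn's `confusion_matrix` would give the same counts), and the Agg backend is chosen only so it runs without a display.

```python
from string import ascii_uppercase

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def confusion_counts(y_true, y_pred):
    """Hypothetical helper: count (true, predicted) pairs.
    Rows = true class, columns = predicted class, classes in sorted order."""
    classes = np.unique(np.concatenate([y_true, y_pred]))
    position = {c: k for k, c in enumerate(classes)}
    counts = np.zeros((len(classes), len(classes)), dtype=int)
    for t, p in zip(y_true, y_pred):
        counts[position[t], position[p]] += 1
    return classes, counts


y_test = np.array([1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5])
predic = np.array([1, 2, 4, 3, 5, 1, 2, 4, 3, 5, 1, 2, 3, 4, 4])

classes, cm = confusion_counts(y_test, predic)
labels = ['class %s' % ch for ch in ascii_uppercase[:len(classes)]]

fig, ax = plt.subplots()
ax.imshow(cm, cmap='Oranges')
# One tick per class, then swap the integer tick text for the string labels
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels, rotation=45, ha='right')
ax.set_yticks(range(len(labels)))
ax.set_yticklabels(labels)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
# Write each count in the middle of its cell
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        ax.text(j, i, str(cm[i, j]), ha='center', va='center')
fig.tight_layout()
# Save with fig.savefig('confusion_matrix.png'), or drop the Agg line and call plt.show()
```

The key point is that `imshow` indexes cells by integer position, so the strings have to be attached through the tick labels rather than through the data itself.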


    If you want a more complete confusion matrix like MATLAB's default one, with totals (last row and last column) and percentages in each cell, see the module below.

    I scoured the internet and couldn't find a confusion matrix like this one for Python, so I developed one with these improvements and shared it on GitHub.


    REF:

    https://github.com/wcipriano/pretty-print-confusion-matrix

    [Image: example output of the pretty-print-confusion-matrix module]
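For a rough idea of what the "totals plus percentages" layout involves, here is a minimal sketch using `pandas.crosstab`. It is not the pretty-print-confusion-matrix module's code or API; it only assumes pandas and NumPy, reuses the example labels from the first snippet, and the variable names are illustrative.

```python
from string import ascii_uppercase

import numpy as np
import pandas as pd

y_test = np.array([1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5])
predic = np.array([1, 2, 4, 3, 5, 1, 2, 4, 3, 5, 1, 2, 3, 4, 4])

# Map integer class IDs to string names: 1 -> 'class A', 2 -> 'class B', ...
classes = np.unique(np.concatenate([y_test, predic]))
names = {c: 'class %s' % ch for c, ch in zip(classes, ascii_uppercase)}
true_s = pd.Series(y_test).map(names)
pred_s = pd.Series(predic).map(names)

# Count table with an extra 'Total' row and column (rows = true, columns = predicted)
counts = pd.crosstab(true_s, pred_s, rownames=['True'], colnames=['Predicted'],
                     margins=True, margins_name='Total')

# Every cell (totals included) as a percentage of all samples
percents = counts / counts.loc['Total', 'Total'] * 100

# "count\npercent%" strings, e.g. to pass as annot=cells, fmt='' to seaborn.heatmap
cells = counts.astype(str) + '\n' + percents.round(1).astype(str) + '%'
print(cells)
```

Because the class names are mapped to strings before counting, the row and column labels of the resulting table are already the string names, with `Total` as the last row and column.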
