sklearn plot confusion matrix with labels


Question:

I want to plot a confusion matrix to visualize the classifier's performance, but the plot only shows numeric tick values on the axes instead of the label names themselves:

from sklearn.metrics import confusion_matrix
import numpy as np
import pylab as pl

y_test = ['business', 'business', 'business', 'business', 'business',
          'business', 'business', 'business', 'business', 'business',
          'business', 'business', 'business', 'business', 'business',
          'business', 'business', 'business', 'business', 'business']

pred = np.array(['health', 'business', 'business', 'business', 'business',
                 'business', 'health', 'health', 'business', 'business',
                 'business', 'business', 'business', 'business', 'business',
                 'business', 'health', 'health', 'business', 'health'])

cm = confusion_matrix(y_test, pred)
pl.matshow(cm)
pl.title('Confusion matrix of the classifier')
pl.colorbar()
pl.show()

How can I add the labels (health, business, etc.) to the confusion matrix?

Answer 1:

As hinted in this question, you have to drop down to the lower-level artist API by storing the figure and axis objects returned by the matplotlib functions you call (the fig, ax and cax variables below). You can then replace the default x- and y-axis ticks using set_xticklabels/set_yticklabels:

from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

labels = ['business', 'health']
cm = confusion_matrix(y_test, pred, labels=labels)
print(cm)

fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(cm)
plt.title('Confusion matrix of the classifier')
fig.colorbar(cax)
# the leading '' accounts for the extra tick matshow places before the first cell
ax.set_xticklabels([''] + labels)
ax.set_yticklabels([''] + labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()

Note that I passed the labels list to confusion_matrix so that the matrix rows and columns are ordered consistently with the tick labels.

This results in the following figure:


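If your scikit-learn is recent enough (ConfusionMatrixDisplay was added around version 0.22, so treat the version requirement as an assumption about your environment), it handles the labels for you without touching the artist API. A minimal sketch, reusing cm and labels from above:

from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# cm and labels as computed above
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
disp.plot()
plt.title('Confusion matrix of the classifier')
plt.show()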

Answer 2:

You might be interested in https://github.com/pandas-ml/pandas-ml/, which provides a pandas-based implementation of a confusion matrix.

Some features:

  • plot confusion matrix
  • plot normalized confusion matrix
  • class statistics
  • overall statistics

Here is an example:

In [1]: from pandas_ml import ConfusionMatrix

In [2]: import matplotlib.pyplot as plt

In [3]: y_test = ['business', 'business', 'business', 'business', 'business',
   ...:           'business', 'business', 'business', 'business', 'business',
   ...:           'business', 'business', 'business', 'business', 'business',
   ...:           'business', 'business', 'business', 'business', 'business']

In [4]: y_pred = ['health', 'business', 'business', 'business', 'business',
   ...:           'business', 'health', 'health', 'business', 'business',
   ...:           'business', 'business', 'business', 'business', 'business',
   ...:           'business', 'health', 'health', 'business', 'health']

In [5]: cm = ConfusionMatrix(y_test, y_pred)

In [6]: cm
Out[6]:
Predicted  business  health  __all__
Actual
business         14       6       20
health            0       0        0
__all__          14       6       20

In [7]: cm.plot()
Out[7]: <matplotlib.axes._subplots.AxesSubplot at 0x1093cf9b0>

In [8]: plt.show()

In [9]: cm.print_stats()
Confusion Matrix:

Predicted  business  health  __all__
Actual
business         14       6       20
health            0       0        0
__all__          14       6       20


Overall Statistics:

Accuracy: 0.7
95% CI: (0.45721081772371086, 0.88106840959427235)
No Information Rate: ToDo
P-Value [Acc > NIR]: 0.608009812201
Kappa: 0.0
Mcnemar's Test P-Value: ToDo


Class Statistics:

Classes                                 business health
Population                                    20     20
P: Condition positive                         20      0
N: Condition negative                          0     20
Test outcome positive                         14      6
Test outcome negative                          6     14
TP: True Positive                             14      0
TN: True Negative                              0     14
FP: False Positive                             0      6
FN: False Negative                             6      0
TPR: (Sensitivity, hit rate, recall)         0.7    NaN
TNR=SPC: (Specificity)                       NaN    0.7
PPV: Pos Pred Value (Precision)                1      0
NPV: Neg Pred Value                            0      1
FPR: False-out                               NaN    0.3
FDR: False Discovery Rate                      0      1
FNR: Miss Rate                               0.3    NaN
ACC: Accuracy                                0.7    0.7
F1 score                               0.8235294      0
MCC: Matthews correlation coefficient        NaN    NaN
Informedness                                 NaN    NaN
Markedness                                     0      0
Prevalence                                     1      0
LR+: Positive likelihood ratio               NaN    NaN
LR-: Negative likelihood ratio               NaN    NaN
DOR: Diagnostic odds ratio                   NaN    NaN
FOR: False omission rate                       1      0
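If you mainly need the per-class statistics rather than the full pandas_ml table, scikit-learn's classification_report gives a similar summary. A minimal sketch with the y_test/y_pred lists above (the zero_division argument requires scikit-learn >= 0.22, so treat that as an assumption about your environment):

from sklearn.metrics import classification_report

# zero_division=0 silences the warning caused by 'health' having no true samples
print(classification_report(y_test, y_pred,
                            labels=['business', 'health'],
                            zero_division=0))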


Answer 3:

I think it's worth mentioning the use of seaborn.heatmap here.

import seaborn as sns
import matplotlib.pyplot as plt

ax = plt.subplot()
sns.heatmap(cm, annot=True, ax=ax)  # annot=True to annotate cells

# labels, title and ticks
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
ax.set_title('Confusion Matrix')
# confusion_matrix sorts the labels, so rows and columns share the same order
ax.xaxis.set_ticklabels(['business', 'health'])
ax.yaxis.set_ticklabels(['business', 'health'])
plt.show()
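Alternatively, wrapping cm in a pandas DataFrame lets seaborn pick the tick labels up from the index and columns, so you cannot mislabel an axis by hand. A minimal sketch, assuming cm was built with labels=['business', 'health'] as in Answer 1:

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

labels = ['business', 'health']
df_cm = pd.DataFrame(cm, index=labels, columns=labels)  # index = true, columns = predicted
sns.heatmap(df_cm, annot=True, fmt='d')  # fmt='d' renders integer counts
plt.ylabel('True labels')
plt.xlabel('Predicted labels')
plt.title('Confusion Matrix')
plt.show()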



Answer 4:

I found a function that plots a confusion matrix generated by sklearn.

import numpy as np


def plot_confusion_matrix(cm,
                          target_names,
                          title='Confusion matrix',
                          cmap=None,
                          normalize=True):
    """
    given a sklearn confusion matrix (cm), make a nice plot

    Arguments
    ---------
    cm:           confusion matrix from sklearn.metrics.confusion_matrix

    target_names: given classification classes such as [0, 1, 2]
                  the class names, for example: ['high', 'medium', 'low']

    title:        the text to display at the top of the matrix

    cmap:         the gradient of the values displayed from matplotlib.pyplot.cm
                  see http://matplotlib.org/examples/color/colormaps_reference.html
                  plt.get_cmap('jet') or plt.cm.Blues

    normalize:    If False, plot the raw numbers
                  If True, plot the proportions

    Usage
    -----
    plot_confusion_matrix(cm           = cm,                  # confusion matrix created by
                                                              # sklearn.metrics.confusion_matrix
                          normalize    = True,                # show proportions
                          target_names = y_labels_vals,       # list of names of the classes
                          title        = best_estimator_name) # title of graph

    Citation
    ---------
    http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

    """
    import matplotlib.pyplot as plt
    import numpy as np
    import itertools

    accuracy = np.trace(cm) / float(np.sum(cm))
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=45)
        plt.yticks(tick_marks, target_names)

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        if normalize:
            plt.text(j, i, "{:0.4f}".format(cm[i, j]),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")
        else:
            plt.text(j, i, "{:,}".format(cm[i, j]),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
    plt.show()
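A minimal usage sketch, assuming y_test and pred from the question are in scope:

from sklearn.metrics import confusion_matrix

labels = ['business', 'health']
cm = confusion_matrix(y_test, pred, labels=labels)
plot_confusion_matrix(cm,
                      target_names=labels,
                      normalize=False,  # plot raw counts instead of proportions
                      title='Confusion matrix of the classifier')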

It will look like this


