Seaborn Confusion Matrix (heatmap) 2 color schemes (correct diagonal vs wrong rest)

Submitted by 谁说我不能喝 on 2021-02-04 16:43:12

Question


Background

In a confusion matrix, the diagonal represents the cases in which the predicted label matches the correct label. So the diagonal is good, while all other cells are bad. To make clear to non-experts what is good and what is bad in a CM, I want to give the diagonal a different color than the rest. I want to achieve this with Python & Seaborn.
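A small NumPy sketch of why the diagonal holds the correct predictions, assuming rows are true labels and columns are predicted labels (the convention scikit-learn's `confusion_matrix` uses). The labels below are made-up illustrative data, not from the question:

```python
import numpy as np

# Made-up labels for a 3-class problem (illustrative only).
y_true = np.array([0, 0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 2, 1, 1, 2, 0, 2])

n_classes = 3
cm = np.zeros((n_classes, n_classes), dtype=int)
# Row = true label, column = predicted label; np.add.at accumulates repeated pairs.
np.add.at(cm, (y_true, y_pred), 1)

correct = np.trace(cm)        # diagonal: predicted label == true label
wrong = cm.sum() - correct    # every off-diagonal cell is a misclassification
print(cm)
print(correct, wrong)         # 5 2
```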

Basically I'm trying to achieve what this question does in R (ggplot2 Heatmap 2 Different Color Schemes - Confusion Matrix: Matches in Different Color Scheme than Missclassifications)

Normal Seaborn Confusion Matrix with heatmap

import numpy as np
import seaborn as sns

cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [9,  4, 76]])

sns.heatmap(cf_matrix, annot=True, cmap='Blues')  # cmap='OrRd'

Which results in this image: [image not available]

Goal

I would like to color the non-diagonal cells with e.g. cmap='OrRd'. So I imagine there would be 2 colorbars: 1 blue for the diagonal and 1 for the other cells. Preferably both colorbars span the same range (e.g. both 0-70, not 0-70 and 0-40). How would I approach this?

The following mock-up was made with photo editing software, not with code: [image not available]


Answer 1:


You can use mask= in the call to heatmap() to choose which cells to show: cells where the mask is True are not drawn. Using two different masks for the diagonal and the off-diagonal cells, you can get the desired output:

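import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [9,  4, 76]])

# Shared color range so both colorbars use the same scale
vmin = np.min(cf_matrix)
vmax = np.max(cf_matrix)
# True on the diagonal; passed as mask= it hides the diagonal cells
off_diag_mask = np.eye(*cf_matrix.shape, dtype=bool)

fig = plt.figure()
# Diagonal only (off-diagonal cells hidden), in blue
sns.heatmap(cf_matrix, annot=True, mask=~off_diag_mask, cmap='Blues', vmin=vmin, vmax=vmax)
# Off-diagonal only (diagonal hidden), in orange/red; this colorbar gets no ticks
sns.heatmap(cf_matrix, annot=True, mask=off_diag_mask, cmap='OrRd', vmin=vmin, vmax=vmax, cbar_kws=dict(ticks=[]))
plt.show()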
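To make the mask semantics concrete, here is a NumPy-only check that the two calls draw complementary sets of cells, reusing the answer's `off_diag_mask` name (which, despite the name, is True on the diagonal). It also shows why the shared vmin/vmax matters: by default seaborn derives the color limits from the cells that are actually shown, so each layer would otherwise get its own range.

```python
import numpy as np

cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [9,  4, 76]])
off_diag_mask = np.eye(*cf_matrix.shape, dtype=bool)  # True on the diagonal

# seaborn hides cells where mask is True, so each call shows the complement of its mask.
blues_shown = off_diag_mask    # mask=~off_diag_mask hides the off-diagonal cells
orrd_shown = ~off_diag_mask    # mask=off_diag_mask hides the diagonal

# Every cell is drawn by exactly one of the two calls.
print(np.all(blues_shown ^ orrd_shown))                            # True

# Ranges each layer would get if left to autoscale on its own cells:
print(cf_matrix[blues_shown].min(), cf_matrix[blues_shown].max())  # 43 76
print(cf_matrix[orrd_shown].min(), cf_matrix[orrd_shown].max())    # 2 38
# The shared range passed as vmin/vmax covers both:
print(cf_matrix.min(), cf_matrix.max())                            # 2 76
```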

If you want to get fancy, you can create the axes using GridSpec to have a better layout:

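import matplotlib.gridspec
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Same data, color range and mask as in the first snippet
cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [9,  4, 76]])
vmin, vmax = np.min(cf_matrix), np.max(cf_matrix)
off_diag_mask = np.eye(*cf_matrix.shape, dtype=bool)

fig = plt.figure()
# Wide column for the heatmap, narrow column for the two colorbars.
# The columns sit side by side, so wspace (not hspace) controls the gap.
gs0 = matplotlib.gridspec.GridSpec(1, 2, width_ratios=[20, 2], wspace=0.05)
gs00 = matplotlib.gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=gs0[1], wspace=0)

ax = fig.add_subplot(gs0[0])
cax1 = fig.add_subplot(gs00[0])  # off-diagonal (OrRd) colorbar, no ticks
cax2 = fig.add_subplot(gs00[1])  # diagonal (Blues) colorbar, with ticks

sns.heatmap(cf_matrix, annot=True, mask=~off_diag_mask, cmap='Blues', vmin=vmin, vmax=vmax, ax=ax, cbar_ax=cax2)
sns.heatmap(cf_matrix, annot=True, mask=off_diag_mask, cmap='OrRd', vmin=vmin, vmax=vmax, ax=ax, cbar_ax=cax1, cbar_kws=dict(ticks=[]))
plt.show()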
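The same two-mask idea does not depend on seaborn. A minimal sketch in plain matplotlib (a swapped-in technique, not part of the answer), assuming `imshow` with NumPy masked arrays, whose masked cells are drawn with the colormap's "bad" color (transparent by default), plus one shared `Normalize` so both colorbars span the same range. `two_tone_confusion_matrix` is a hypothetical helper name:

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import Normalize


def two_tone_confusion_matrix(matrix, diag_cmap="Blues", off_cmap="OrRd"):
    """Hypothetical helper: diagonal and off-diagonal cells get separate colormaps."""
    matrix = np.asarray(matrix)
    diag = np.eye(*matrix.shape, dtype=bool)
    # One Normalize shared by both layers, so both colorbars span the same range.
    norm = Normalize(vmin=matrix.min(), vmax=matrix.max())

    fig, ax = plt.subplots()
    # Masked cells stay transparent, so each layer paints only its own cells.
    off_im = ax.imshow(np.ma.masked_array(matrix, mask=diag), cmap=off_cmap, norm=norm)
    diag_im = ax.imshow(np.ma.masked_array(matrix, mask=~diag), cmap=diag_cmap, norm=norm)

    # Annotate each cell once; white text on the darker half of the scale.
    for (i, j), value in np.ndenumerate(matrix):
        color = "white" if norm(value) > 0.5 else "black"
        ax.text(j, i, str(value), ha="center", va="center", color=color)

    fig.colorbar(diag_im, ax=ax)
    off_cbar = fig.colorbar(off_im, ax=ax)
    off_cbar.set_ticks([])  # tick labels on one colorbar only, as in the answer
    return fig, ax


cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [9,  4, 76]])
fig, ax = two_tone_confusion_matrix(cf_matrix)
# In a script, call plt.show() to display the figure.
```

Annotating in a single loop, rather than once per layer, guarantees every number is written exactly once.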




Answer 2:


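You could first plot the heatmap with the 'OrRd' colormap and then overlay a heatmap with the 'Blues' colormap in which all off-diagonal values (the upper and lower triangles) are replaced with NaNs. seaborn masks NaN cells automatically, so only the diagonal is drawn on top. See the following example: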

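import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def diagonal_heatmap(m):
    # Shared color range for both layers
    vmin = np.min(m)
    vmax = np.max(m)

    # Bottom layer: every cell in OrRd (uses m, not the global cf_matrix)
    sns.heatmap(m, annot=True, cmap='OrRd', vmin=vmin, vmax=vmax)

    # Top layer: NaN everywhere except the diagonal; seaborn masks the NaN cells
    diag_nan = np.full_like(m, np.nan, dtype=float)
    np.fill_diagonal(diag_nan, np.diag(m))

    sns.heatmap(diag_nan, annot=True, cmap='Blues', vmin=vmin, vmax=vmax, cbar_kws={'ticks': []})


cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [9,  4, 76]])

diagonal_heatmap(cf_matrix)
plt.show()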
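Answer 2 describes the top layer as the matrix with its upper and lower triangles replaced by NaNs. A NumPy-only sketch that builds that array literally from the strict triangles (via `np.triu`/`np.tril`, an alternative construction, not the answer's) and checks it against the `np.full_like` + `np.fill_diagonal` construction used in `diagonal_heatmap`:

```python
import numpy as np

cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [9,  4, 76]])
n = cf_matrix.shape[0]

# Strict upper (k=1) and strict lower (k=-1) triangles: every off-diagonal cell.
upper = np.triu(np.ones((n, n), dtype=bool), k=1)
lower = np.tril(np.ones((n, n), dtype=bool), k=-1)

diag_only = cf_matrix.astype(float)  # float copy so it can hold NaN
diag_only[upper | lower] = np.nan

# The construction used inside diagonal_heatmap.
diag_nan = np.full_like(cf_matrix, np.nan, dtype=float)
np.fill_diagonal(diag_nan, np.diag(cf_matrix))
print(np.array_equal(diag_only, diag_nan, equal_nan=True))  # True
```

seaborn skips annotations for masked (including NaN) cells, so the top layer only writes numbers on the diagonal. The bottom OrRd layer, however, annotates every cell, so the diagonal values are written twice; passing `mask=np.eye(*m.shape, dtype=bool)` to the first `sns.heatmap` call inside `diagonal_heatmap` should avoid that.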


Source: https://stackoverflow.com/questions/64800003/seaborn-confusion-matrix-heatmap-2-color-schemes-correct-diagonal-vs-wrong-re
