How can I label the clusters in sns clustermap

旧时模样 posted on 2021-02-11 06:01:17

Question


I am creating a clustermap with the following code.

import numpy as np
import pandas as pd
import seaborn as sns

all_net_names  = ['early_vis', 'face', 'motion', 'scene', 'scene', 'scene', 
                  'dmn', 'dmn', 'dmn', 'dmn', 'dmn', 'dmn', 'reward', 'reward',
                  'reward', 'reward', 'reward', 'ofc', 'ofc', 'ofc', 'ofc']

roi_names = ['E', 'F', 'M', 'S1', 'S2', 'S3', 'D1', 'D2', 'D3', 'D4', 'D5',
             'D6', 'R1', 'R2', 'R3', 'R4', 'R5','O1', 'O2', 'O3', 'O4']

n_roi = len(roi_names)
M = np.random.rand(n_roi, n_roi) # array to plot

# Index of the first occurrence of each network, in order of appearance
net_ind = sorted(np.unique(all_net_names, return_index=True)[1])
net_names = [all_net_names[index] for index in net_ind]
network_pal = sns.husl_palette(len(net_names), s=.45)
network_lut = dict(zip(map(str, np.unique(all_net_names)), network_pal))
network_colors = pd.Series(all_net_names).map(network_lut)
network_colors = np.asarray(network_colors)

g = sns.clustermap(M, center=0, cmap="vlag",
                   row_cluster=False, 
                   col_cluster=False,
                   row_colors=network_colors, 
                   col_colors=network_colors,
                   linewidths=0, figsize=(10, 10))

g.ax_heatmap.set_xticklabels(roi_names, rotation=90)
g.ax_heatmap.set_yticklabels(roi_names, rotation=0)

It works and gives this output:

I could add labels corresponding to each cell, but I also want to label each cluster with its unique network name, as in here:

Any ideas how to achieve that?


Answer 1:


Maybe adding this at the end of the code?

g.ax_row_colors.set_yticks(0.5 * (np.array(net_ind) + np.array(net_ind[1:] + [len(all_net_names)])))
g.ax_row_colors.set_yticklabels(net_names)
g.ax_row_colors.yaxis.set_tick_params(size=0) # make tick marks invisible

The start of each group is given by net_ind. To center each label, place it midway between the start of its group and the start of the next group. As the last group has no next group, we take the length of all_net_names as its end.
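To check the positions without drawing anything, the midpoint calculation can be run on its own. This is a minimal sketch that reuses all_net_names from the question; the names starts, ends and centers are introduced here only for illustration and are not part of the original code. It assumes each network's ROIs are contiguous in the list, as they are in the question.

```python
import numpy as np

all_net_names = ['early_vis', 'face', 'motion', 'scene', 'scene', 'scene',
                 'dmn', 'dmn', 'dmn', 'dmn', 'dmn', 'dmn', 'reward', 'reward',
                 'reward', 'reward', 'reward', 'ofc', 'ofc', 'ofc', 'ofc']

# First index of each network, sorted into order of appearance
# (illustrative name; equivalent to net_ind in the question)
starts = np.sort(np.unique(all_net_names, return_index=True)[1])

# Each group ends where the next one starts; the last ends at the list length
ends = np.append(starts[1:], len(all_net_names))

# Midpoint of each group in heatmap coordinates (cell i spans i to i + 1)
centers = 0.5 * (starts + ends)
net_names = [all_net_names[i] for i in starts]

print(net_names)         # ['early_vis', 'face', 'motion', 'scene', 'dmn', 'reward', 'ofc']
print(centers.tolist())  # [0.5, 1.5, 2.5, 4.5, 9.0, 14.5, 19.0]
```

For example, 'scene' covers cells 3 to 5, which span 3 to 6 on the axis, so its label lands at 4.5.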

The same could be done for the columns:

g.ax_col_colors.set_xticks(0.5 * (np.array(net_ind) + np.array(net_ind[1:] + [len(all_net_names)])))
g.ax_col_colors.set_xticklabels(net_names, rotation=90)
g.ax_col_colors.xaxis.set_tick_params(size=0) # make tick marks invisible
g.ax_col_colors.xaxis.tick_top()


Source: https://stackoverflow.com/questions/59792534/how-can-i-label-the-clusters-in-sns-clustermap
