How to express classes on the axis of a heatmap in Seaborn

Asked by 走了就别回头了 on 2020-12-01 02:13

I created a very simple Seaborn heatmap that displays a square similarity matrix. Here is the single line of code I used:

sns.heatmap(sim_mat, linewidth         


        
3 Answers
  •  轻奢々 (OP)
    2020-12-01 02:53

    Building on the above answer, I think it's worth noting that clustermap can draw several levels of colour labels: as the clustermap docs note for {row,col}_colors, you can pass nested lists or a DataFrame (one column per level). I couldn't find an example with multiple levels, so I thought I'd share one here.

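    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns

    # Load the example dataset; header=[0, 1, 2] gives a three-level column MultiIndex
    networks = sns.load_dataset("brain_networks", index_col=0, header=[0, 1, 2])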
    

    Network level: one colour per unique network label, plus a label-to-colour lookup table.

    network_labels = networks.columns.get_level_values("network")
    network_pal = sns.cubehelix_palette(network_labels.unique().size, light=.9, dark=.1, reverse=True, start=1, rot=-2)
    network_lut = dict(zip(map(str, network_labels.unique()), network_pal))
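The palette-plus-lookup step can be tried on its own. This is a minimal sketch, assuming four made-up string labels in place of `network_labels.unique()`; it reuses the answer's `cubehelix_palette` settings and checks that every label gets an RGB colour.

```python
import seaborn as sns

# Hypothetical labels standing in for network_labels.unique()
labels = ["1", "2", "3", "4"]

# One cubehelix colour per unique label, using the answer's settings
pal = sns.cubehelix_palette(len(labels), light=.9, dark=.1,
                            reverse=True, start=1, rot=-2)

# Lookup table: label -> RGB colour (three floats in [0, 1])
lut = dict(zip(labels, pal))

for label, rgb in lut.items():
    print(label, [round(c, 2) for c in rgb])
```

The `map(str, ...)` in the answer leaves string labels unchanged. With non-string labels it would turn the lookup keys into strings, so `.map(network_lut)` and `network_lut[label]` would no longer find the original labels.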
    

    Map each column's network label to its colour, in a Series indexed by the data's columns:

    network_colors = pd.Series(network_labels, index=networks.columns).map(network_lut)
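To see what the mapping step produces, here is a small sketch on a made-up three-level column index. Only the level names `network` and `node` come from the answer; the third level name `hemi`, the labels and the colours are hypothetical.

```python
import pandas as pd

# Made-up three-level column index (the third level name is hypothetical)
cols = pd.MultiIndex.from_tuples(
    [("1", "1", "lh"), ("1", "1", "rh"), ("2", "1", "lh"), ("2", "2", "rh")],
    names=["network", "node", "hemi"],
)
# Hypothetical lookup table; any matplotlib colour spec works as a value
network_lut = {"1": "tab:blue", "2": "tab:orange"}

network_labels = cols.get_level_values("network")
# Indexing the Series by the columns lets clustermap align colours with the matrix
network_colors = pd.Series(network_labels, index=cols).map(network_lut)
print(network_colors)
```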
    

    Node level: the same steps for the node labels.

    node_labels = networks.columns.get_level_values("node")
    node_pal = sns.cubehelix_palette(node_labels.unique().size)
    node_lut = dict(zip(map(str, node_labels.unique()), node_pal))
    

    Map each column's node label to its colour:

    node_colors = pd.Series(node_labels, index=networks.columns).map(node_lut)
    

    Combine both levels into one DataFrame with a column per colour level (used for both rows and columns):

    network_node_colors = pd.DataFrame(network_colors).join(pd.DataFrame(node_colors))
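The multi-level input is simply a DataFrame with one column per level. This sketch, assuming a made-up two-level index and hypothetical colours, builds it from a dict so the column names (which clustermap shows next to the colour strips) are set explicitly.

```python
import pandas as pd

# Made-up two-level index with the answer's level names
cols = pd.MultiIndex.from_tuples(
    [("1", "1"), ("1", "2"), ("2", "1"), ("2", "2")], names=["network", "node"]
)
# Hypothetical colours per level
network_colors = pd.Series(cols.get_level_values("network"), index=cols).map(
    {"1": "tab:blue", "2": "tab:orange"}
)
node_colors = pd.Series(cols.get_level_values("node"), index=cols).map(
    {"1": "0.8", "2": "0.3"}
)

# One column per colour level; the column names become the strip labels
level_colors = pd.DataFrame({"Network": network_colors, "Node": node_colors})
print(level_colors)
```

The answer's `pd.DataFrame(network_colors).join(pd.DataFrame(node_colors))` works because the two Series carry different names; if both had the same name, `join` would raise an error about overlapping columns. The dict form avoids relying on those names.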
    

    Create the clustermap:

    g = sns.clustermap(
        networks.corr(),
        # Turn off the clustering
        row_cluster=False, col_cluster=False,
        # Add colored class labels using the data frame built from node and network colors
        row_colors=network_node_colors,
        col_colors=network_node_colors,
        # Make the plot look better when there are many rows/cols
        linewidths=0,
        xticklabels=False, yticklabels=False,
        center=0, cmap="vlag",
    )
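If you want to try the two-level colour strips without downloading `brain_networks`, the sketch below runs the same pipeline on random data. The level names (`group`, `subgroup`, `side`), the palettes and the helper `level_colors` are all hypothetical; note that `clustermap` needs SciPy installed even with clustering turned off.

```python
import numpy as np
import pandas as pd
import seaborn as sns

# Made-up data: 8 variables tagged by a hypothetical three-level column index
cols = pd.MultiIndex.from_product(
    [["A", "B"], ["x", "y"], ["lh", "rh"]], names=["group", "subgroup", "side"]
)
rng = np.random.default_rng(0)
data = pd.DataFrame(rng.normal(size=(50, len(cols))), columns=cols)
corr = data.corr()  # 8 x 8; rows and columns share the MultiIndex

# One label -> colour lookup table per level
group_lut = dict(zip(["A", "B"], sns.color_palette("Set2", 2)))
subgroup_lut = dict(zip(["x", "y"], sns.color_palette("Dark2", 2)))


def level_colors(index, level, lut):
    """Hypothetical helper: colours for one index level, aligned with `index`."""
    return pd.Series(index.get_level_values(level), index=index).map(lut)


# One column per level; clustermap draws one colour strip per column
colors = pd.DataFrame({
    "group": level_colors(corr.index, "group", group_lut),
    "subgroup": level_colors(corr.index, "subgroup", subgroup_lut),
})

g = sns.clustermap(
    corr,
    row_cluster=False, col_cluster=False,  # keep the original order
    row_colors=colors, col_colors=colors,
    xticklabels=False, yticklabels=False,
    center=0, cmap="vlag",
)
```

Call `plt.show()` (or `g.savefig(...)`) afterwards to display or save the figure.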
    

    Create two legends, one per level, by drawing invisible (zero-size) bars on the column and row dendrogram axes, as in the answer above.

    Network legend:

    from matplotlib.pyplot import gcf
    
    for label in network_labels.unique():
        g.ax_col_dendrogram.bar(0, 0, color=network_lut[label], label=label, linewidth=0)
    
    l1 = g.ax_col_dendrogram.legend(title='Network', loc="center", ncol=5, bbox_to_anchor=(0.47, 0.8), bbox_transform=gcf().transFigure)
    

    Node legend:

    for label in node_labels.unique():
        g.ax_row_dendrogram.bar(0, 0, color=node_lut[label], label=label, linewidth=0)
    
    l2 = g.ax_row_dendrogram.legend(title='Node', loc="center", ncol=2, bbox_to_anchor=(0.8, 0.8), bbox_transform=gcf().transFigure)
    
    plt.show()
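An alternative to the zero-size bars (a different technique from the one used above) is to pass proxy `matplotlib.patches.Patch` handles straight to `legend`. This sketch assumes a hypothetical three-entry lookup table and a plain matplotlib figure; with a clustermap you could call `g.ax_col_dendrogram.legend(handles=..., ...)` the same way.

```python
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

# Hypothetical lookup table, e.g. the network_lut built earlier
network_lut = {"1": "tab:blue", "2": "tab:orange", "3": "tab:green"}

fig, ax = plt.subplots()
# Proxy artists: these patches exist only to populate the legend
handles = [Patch(facecolor=color, label=label) for label, color in network_lut.items()]
leg = fig.legend(handles=handles, title="Network", loc="upper center", ncol=len(handles))
```

Proxy handles avoid adding hidden artists to the dendrogram axes, at the cost of building the handle list yourself.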
    
