Single row (or column) heat map in python

Submitted by 戏子无情 on 2019-12-08 05:15:53

Question


I can create an n-by-n heat map using the following code; for example, let n be 10:

import numpy as np
import matplotlib.pyplot as plt

number = 10
random_matrix = np.random.rand(number, number)
x = list(range(1, number + 1))
plt.pcolormesh(x, x, random_matrix)
plt.colorbar()
plt.xlim(1, number)
plt.xlabel('Number 1')
plt.ylim(1, number)
plt.ylabel('Number 2')
# Hide all ticks and tick labels (booleans; the old 'off' strings are no longer accepted)
plt.tick_params(
    axis='both',
    which='both',
    bottom=False,
    top=False,
    labelbottom=False,
    right=False,
    left=False,
    labelleft=False)
plt.show()

I would like to add two single-row heat maps, one next to each of the x and y axes, built from, say, row1 = np.random.rand(1,10) and col1 = np.random.rand(1,10). Here is an example image of what I would like to produce: [image not available]

Thanks in advance.


Answer 1:


You would create a subplot grid in which the width and height ratios between the subplots correspond to the number of pixels in the respective dimension. You can then add the respective plots to those subplots. In the code below I used an imshow plot, because I find it more intuitive to have one pixel per item in the array (instead of one less, which is what pcolormesh with flat shading draws when the coordinate arrays have the same length as the data).

To make the colorbar represent the colors across the different subplots, one can use a single matplotlib.colors.Normalize instance, which is passed to each of the subplots as well as to the manually created ScalarMappable for the colorbar.

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

m = np.random.rand(10,10)              # main heat map
x = np.random.rand(1,m.shape[1])       # single-row strip above the main map
y = np.random.rand(m.shape[0],1)       # single-column strip left of the main map

# One shared norm so that all three images and the colorbar use the same color scale
norm = matplotlib.colors.Normalize(vmin=0, vmax=1)
# Heights follow the number of rows, widths the number of columns; last column holds the colorbar
grid = dict(height_ratios=[1, m.shape[0]], width_ratios=[1, m.shape[1], 0.5])
fig, axes = plt.subplots(ncols=3, nrows=2, gridspec_kw=grid)

axes[1,1].imshow(m, aspect="auto", cmap="viridis", norm=norm)
axes[0,1].imshow(x, aspect="auto", cmap="viridis", norm=norm)
axes[1,0].imshow(y, aspect="auto", cmap="viridis", norm=norm)

# The two corner cells stay empty
axes[0,0].axis("off")
axes[0,2].axis("off")

axes[1,1].set_xlabel('Number 1')
axes[1,1].set_ylabel('Number 2')
for ax in [axes[1,1], axes[0,1], axes[1,0]]:
    ax.set_xticks([])
    ax.set_yticks([])

# The colorbar is built from a standalone mappable that shares the cmap and norm
sm = matplotlib.cm.ScalarMappable(cmap="viridis", norm=norm)
sm.set_array([])

fig.colorbar(sm, cax=axes[1,2])

plt.show()
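
If the data are not confined to [0, 1] or the matrix is not square, the same layout can be wrapped in a small function. The following is only a sketch under stated assumptions: the name heatmap_with_margins and its parameters are hypothetical (they do not appear in the answer), and the shared Normalize range is taken from the minimum and maximum of all three arrays instead of being fixed to 0 and 1.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize


def heatmap_with_margins(m, row, col, cmap="viridis"):
    """Hypothetical helper: main heat map with a one-row strip on top
    and a one-column strip on the left, all sharing one color scale."""
    m = np.asarray(m)
    row = np.asarray(row).reshape(1, -1)   # force shape (1, n_cols)
    col = np.asarray(col).reshape(-1, 1)   # force shape (n_rows, 1)
    n_rows, n_cols = m.shape
    if row.shape[1] != n_cols or col.shape[0] != n_rows:
        raise ValueError("row/col lengths must match the matrix shape")

    # Assumption: the color range spans all three arrays, so equal values
    # get equal colors in every panel and in the colorbar.
    vmin = min(a.min() for a in (m, row, col))
    vmax = max(a.max() for a in (m, row, col))
    norm = Normalize(vmin=vmin, vmax=vmax)

    # Heights follow the row count, widths the column count;
    # the narrow third column is reserved for the colorbar.
    grid = dict(height_ratios=[1, n_rows], width_ratios=[1, n_cols, 0.5])
    fig, axes = plt.subplots(nrows=2, ncols=3, gridspec_kw=grid)

    axes[1, 1].imshow(m, aspect="auto", cmap=cmap, norm=norm)
    axes[0, 1].imshow(row, aspect="auto", cmap=cmap, norm=norm)
    axes[1, 0].imshow(col, aspect="auto", cmap=cmap, norm=norm)
    for ax in (axes[1, 1], axes[0, 1], axes[1, 0]):
        ax.set_xticks([])
        ax.set_yticks([])
    axes[0, 0].axis("off")
    axes[0, 2].axis("off")

    # Standalone mappable carrying the shared cmap and norm for the colorbar
    sm = ScalarMappable(norm=norm, cmap=cmap)
    sm.set_array([])
    fig.colorbar(sm, cax=axes[1, 2])
    return fig, axes, norm


# Example with a non-square matrix and values outside [0, 1]
rng = np.random.default_rng(0)
fig, axes, norm = heatmap_with_margins(
    rng.uniform(-5, 5, size=(6, 12)),
    rng.uniform(-5, 5, size=12),
    rng.uniform(-5, 5, size=6),
)
```

Call plt.show() or fig.savefig(...) afterwards to display or store the figure. Whether a data-driven range or a fixed one (as in the answer) is preferable depends on whether several figures need to be comparable with each other.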


Source: https://stackoverflow.com/questions/43076488/single-row-or-column-heat-map-in-python
