Question
I can create an n by n heatmap using the following code, for example with n = 10:
import numpy as np
import matplotlib.pyplot as plt
random_matrix = np.random.rand(10,10)
number = 10
incrmnt = 1.0
x = list(range(1,number +1))
plt.pcolormesh(x, x, random_matrix)
plt.colorbar()
plt.xlim(1, number)
plt.xlabel('Number 1')
plt.ylim(1, number)
plt.ylabel('Number 2')
# Hide all ticks and tick labels (booleans; the string 'off' is not accepted by current matplotlib)
plt.tick_params(
    axis = 'both',
    which = 'both',
    bottom = False,
    top = False,
    labelbottom = False,
    right = False,
    left = False,
    labelleft = False)
I would like to add two single-row heatmaps, one next to each of the x and y axes, from, say, row1 = np.random.rand(1,10) and col1 = np.random.rand(1,10).
Here is an example image of what I would like to produce:
Thanks in advance.
Answer 1:
You would create a subplot grid where the width and height ratios between the subplots correspond to the number of pixels in the respective dimension. You can then add the respective plots to those subplots. In the code below I used an imshow plot, because I find it more intuitive to have one pixel per item in the array (instead of one less).
In order to have the colorbar represent the colors across the different subplots, one can use a matplotlib.colors.Normalize instance, which is provided to each of the subplots, as well as to the manually created ScalarMappable for the colorbar.
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
m = np.random.rand(10,10)
x = np.random.rand(1,m.shape[1])  # one-row strip, as wide as m
y = np.random.rand(m.shape[0],1)  # one-column strip, as tall as m
# A single normalization shared by all images and the colorbar
norm = matplotlib.colors.Normalize(vmin=0, vmax=1)
# Row heights follow the number of pixel rows, column widths the number of pixel columns
grid = dict(height_ratios=[1, m.shape[0]], width_ratios=[1, m.shape[1], 0.5])
fig, axes = plt.subplots(ncols=3, nrows=2, gridspec_kw = grid)
axes[1,1].imshow(m, aspect="auto", cmap="viridis", norm=norm)
axes[0,1].imshow(x, aspect="auto", cmap="viridis", norm=norm)
axes[1,0].imshow(y, aspect="auto", cmap="viridis", norm=norm)
# The corner cells hold no data
axes[0,0].axis("off")
axes[0,2].axis("off")
axes[1,1].set_xlabel('Number 1')
axes[1,1].set_ylabel('Number 2')
for ax in [axes[1,1], axes[0,1], axes[1,0]]:
    ax.set_xticks([])
    ax.set_yticks([])
# Colorbar built from a standalone ScalarMappable that uses the shared norm
sm = matplotlib.cm.ScalarMappable(cmap="viridis", norm=norm)
sm.set_array([])
fig.colorbar(sm, cax=axes[1,2])
plt.show()
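The code above fixes the color range at 0 to 1, which matches np.random.rand. For data with another range, the shared normalization could instead be derived from the arrays themselves. The following is only a sketch of that idea. The helper names shared_norm and marginal_heatmaps are hypothetical and not part of matplotlib, and the sketch assumes that all three arrays should share one color scale.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize


def shared_norm(*arrays):
    """Hypothetical helper: a Normalize spanning the min and max of all arrays."""
    vmin = min(float(np.min(a)) for a in arrays)
    vmax = max(float(np.max(a)) for a in arrays)
    return Normalize(vmin=vmin, vmax=vmax)


def marginal_heatmaps(m, row, col, cmap="viridis"):
    """Hypothetical helper: draw m with a one-row strip above and a one-column strip to the left."""
    norm = shared_norm(m, row, col)
    # Same grid layout as in the answer: strip, main heatmap, colorbar column
    grid = dict(height_ratios=[1, m.shape[0]],
                width_ratios=[1, m.shape[1], 0.5])
    fig, axes = plt.subplots(nrows=2, ncols=3, gridspec_kw=grid)
    axes[1, 1].imshow(m, aspect="auto", cmap=cmap, norm=norm)
    axes[0, 1].imshow(row, aspect="auto", cmap=cmap, norm=norm)
    axes[1, 0].imshow(col, aspect="auto", cmap=cmap, norm=norm)
    for ax in (axes[1, 1], axes[0, 1], axes[1, 0]):
        ax.set_xticks([])
        ax.set_yticks([])
    # The corner cells hold no data
    axes[0, 0].axis("off")
    axes[0, 2].axis("off")
    # Colorbar driven by the same norm as the three images
    sm = ScalarMappable(norm=norm, cmap=cmap)
    sm.set_array([])
    fig.colorbar(sm, cax=axes[1, 2])
    return fig, axes
```

For example, marginal_heatmaps(5 * np.random.rand(10, 10), np.random.rand(1, 10), np.random.rand(10, 1)) followed by plt.show() would draw all three images on a scale covering the combined data range.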
Source: https://stackoverflow.com/questions/43076488/single-row-or-column-heat-map-in-python