Heatmap with matplotlib

风格不统一 submitted on 2020-01-14 03:14:10

Question


I have a list of tuples, each starting with (var1, var2, result); the tuples in my data also carry a fourth value, which is not needed for the heatmap. Example list of 9 tuples:

    [(4, 0.7, 0.8530612244898, 0.016579670213527985),
     (4, 0.6, 0.8730158730157779, 0.011562402525241757),
     (6, 0.8, 0.8378684807257778, 0.018175985875060037),
     (4, 0.8, 0.8605442176870833, 0.015586992159716321),
     (2, 0.8, 0.8537414965986667, 0.0034013605443334316),
     (2, 0.7, 0.843537414966, 0.006802721088333352),
     (6, 0.6, 0.8480725623582223, 0.01696896774039503),
     (2, 0.6, 0.84693877551025, 0.010204081632749995),
     (6, 0.7, 0.8577097505669444, 0.019873637350220318)]

Now I'd like to create a heatmap from this. Var1 takes the values [2, 4, 6] and var2 the values [0.6, 0.7, 0.8], giving 9 results in total.

This is the code I use to plot a heatmap:

    import numpy as np
    import matplotlib.pyplot as plt

    # list of 3-tuples to 3 lists: x, y and weights
    # x (var1) = [2,4,6]
    # y (var2) = [0.6, 0.7, 0.8]
    # weights (res) = [....] (9 values)

    x, y = np.meshgrid(x, y)
    intensity = np.array(weights)

    plt.pcolormesh(x, y, intensity)
    plt.colorbar()  # need a colorbar to show the intensity scale
    plt.show()

This produces a graph with only 4 colored sections, although I expected 9 (3x3). Can someone shed some light on what I did wrong here?


Answer 1:


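meshgrid is creating nine points, not nine patches.
Look at the output: you've got one point at (2, 0.6), one at (2, 0.7), and so on. pcolormesh uses these points as the corners of the colored cells, and a 3x3 grid of corners only encloses 2x2 = 4 cells.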
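A quick way to check this is to inspect what meshgrid returns. The sketch below, assuming the same x and y values as in the question, counts the grid points and the cells that pcolormesh draws between them when every cell is bounded by four neighbouring points (the classic "flat" shading).

```python
import numpy as np

x = [2, 4, 6]          # var1 values from the question
y = [0.6, 0.7, 0.8]    # var2 values from the question

X, Y = np.meshgrid(x, y)
print(X.shape)         # (3, 3): one entry per grid point

# each cell spans two adjacent points in each direction,
# so an (ny, nx) grid of corners bounds (ny - 1) * (nx - 1) cells
n_points = X.size
n_cells = (X.shape[0] - 1) * (X.shape[1] - 1)
print(n_points, n_cells)  # 9 4
```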




Answer 2:


To get 9 patches, set x and y to the vertices of the patches instead, so that each axis has one more value than there are patches along it:

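import numpy as np
import matplotlib.pyplot as plt

x = [1, 3, 5, 7]              # 4 edges -> 3 cells centered on var1 = 2, 4, 6
y = [0.55, 0.65, 0.75, 0.85]  # 4 edges -> 3 cells centered on var2 = 0.6, 0.7, 0.8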

x, y = np.meshgrid(x, y)
intensity = np.random.random(size=(3,3))

plt.pcolormesh(x, y, intensity)
plt.colorbar()  # need a colorbar to show the intensity scale
plt.show()
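The code above uses random values. To plot the question's results instead, the 3x3 intensity array has to be filled so that row i belongs to the i-th var2 value and column j to the j-th var1 value, which is the layout meshgrid(x, y) implies. A minimal sketch, assuming the tuple list is stored in a variable named data (a name chosen here for illustration), places each result with np.searchsorted instead of relying on the order of the tuples:

```python
import numpy as np

# tuples from the question: (var1, var2, result, extra value not used here)
data = [(4, 0.7, 0.8530612244898, 0.016579670213527985),
        (4, 0.6, 0.8730158730157779, 0.011562402525241757),
        (6, 0.8, 0.8378684807257778, 0.018175985875060037),
        (4, 0.8, 0.8605442176870833, 0.015586992159716321),
        (2, 0.8, 0.8537414965986667, 0.0034013605443334316),
        (2, 0.7, 0.843537414966, 0.006802721088333352),
        (6, 0.6, 0.8480725623582223, 0.01696896774039503),
        (2, 0.6, 0.84693877551025, 0.010204081632749995),
        (6, 0.7, 0.8577097505669444, 0.019873637350220318)]

arr = np.array(data)
xu = np.unique(arr[:, 0])   # sorted var1 values: 2, 4, 6
yu = np.unique(arr[:, 1])   # sorted var2 values: 0.6, 0.7, 0.8

# index of each tuple's var1 / var2 within the sorted unique values
cols = np.searchsorted(xu, arr[:, 0])
rows = np.searchsorted(yu, arr[:, 1])

intensity = np.full((len(yu), len(xu)), np.nan)  # rows = var2, columns = var1
intensity[rows, cols] = arr[:, 2]
```

The resulting intensity can be passed to plt.pcolormesh(x, y, intensity) together with the edge arrays x and y above. Any grid position without a matching tuple stays NaN.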




Answer 3:


You first need to sort your array so that it can later be reshaped into the correct matrix. This can be done using numpy.lexsort.
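np.lexsort uses the last key as the primary sort key, so np.lexsort((data[:,0], data[:,1])) orders rows by var2 and breaks ties by var1. That is the row-major order a reshape needs when rows correspond to var2. A small sketch with made-up (var1, var2) pairs, used only for illustration:

```python
import numpy as np

# hypothetical (var1, var2) pairs in arbitrary order
pairs = np.array([[4, 0.7], [2, 0.7], [4, 0.6], [2, 0.6]])

# last key (column 1, var2) is the primary key; column 0 (var1) breaks ties
ind = np.lexsort((pairs[:, 0], pairs[:, 1]))
print(ind)  # [3 2 1 0]

# after sorting, each row of the reshaped array shares one var2 value
var2_grid = pairs[ind, 1].reshape(2, 2)
var1_grid = pairs[ind, 0].reshape(2, 2)
```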

pcolormesh

The reason the plot has one row and one column fewer than expected is explained in this question: Can someone explain this matplotlib pcolormesh quirk?

So you need to decide whether the values in the first two columns denote the edges of the rectangles in the plot or their centers. Either way, each axis needs one more coordinate value than there are colored rectangles along it to plot the matrix with pcolormesh.
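When the values are centers, the edges lie halfway between neighbouring centers, plus one extra edge at each end. The code below assumes evenly spaced values, since it shifts everything by half of the last spacing. A minimal sketch of a hypothetical helper, centers_to_edges, that also handles unevenly spaced centers by using midpoints and mirroring the outermost half-spacing:

```python
import numpy as np

def centers_to_edges(centers):
    """Return len(centers) + 1 edges for at least two sorted 1-D cell centers (hypothetical helper)."""
    c = np.asarray(centers, dtype=float)
    mid = (c[:-1] + c[1:]) / 2.0       # edges between neighbouring centers
    first = c[0] - (mid[0] - c[0])     # mirror the first half-spacing outward
    last = c[-1] + (c[-1] - mid[-1])   # mirror the last half-spacing outward
    return np.concatenate(([first], mid, [last]))

print(centers_to_edges([2, 4, 6]))     # [1. 3. 5. 7.]
```

For evenly spaced centers this gives the same edges as the append-and-shift expression in the code below; for uneven spacing each interior edge still sits exactly between its two neighbouring centers.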

data = [ (4, 0.7, 0.8530612244898, 0.016579670213527985), 
         (4, 0.6, 0.8730158730157779, 0.011562402525241757), 
         (6, 0.8, 0.8378684807257778, 0.018175985875060037), 
         (4, 0.8, 0.8605442176870833, 0.015586992159716321), 
         (2, 0.8, 0.8537414965986667, 0.0034013605443334316), 
         (2, 0.7, 0.843537414966, 0.006802721088333352), 
         (6, 0.6, 0.8480725623582223, 0.01696896774039503), 
         (2, 0.6, 0.84693877551025, 0.010204081632749995), 
         (6, 0.7, 0.8577097505669444, 0.019873637350220318)]

import numpy as np
import matplotlib.pyplot as plt

# sort the array: by var2 (last key) first, then by var1
data = np.array(data)
ind = np.lexsort((data[:,0], data[:,1]))
data = data[ind]

#create meshgrid for x and y
xu = np.unique(data[:,0])
yu = np.unique(data[:,1])
# if values are centers of rectangles (assumes evenly spaced values):
x = np.append(xu , [xu[-1]+np.diff(xu)[-1]])-np.diff(xu)[-1]/2.
y = np.append(yu , [yu[-1]+np.diff(yu)[-1]])-np.diff(yu)[-1]/2.
# if values are edges of rectangles:
# x = np.append(xu , [xu[-1]+np.diff(xu)[-1]])
# y = np.append(yu , [yu[-1]+np.diff(yu)[-1]])
X,Y = np.meshgrid(x,y)

# reshape third column (result) to match: rows = var2, columns = var1
Z = data[:,2].reshape(len(yu), len(xu))

plt.pcolormesh(X,Y,Z, cmap="jet")
plt.colorbar()

plt.show()

imshow

The same plot can be obtained with imshow, where you don't need a grid; instead, you specify the extent of the plot.

import numpy as np
import matplotlib.pyplot as plt

# sort the array: by var2 (last key) first, then by var1
data = np.array(data)
ind = np.lexsort((data[:,0], data[:,1]))
data = data[ind]

# compute the cell edges used for the extent (no meshgrid needed here)
xu = np.unique(data[:,0])
yu = np.unique(data[:,1])
x = np.append(xu , [xu[-1]+np.diff(xu)[-1]])-np.diff(xu)[-1]/2.
y = np.append(yu , [yu[-1]+np.diff(yu)[-1]])-np.diff(yu)[-1]/2.


# reshape third column (result) to match: rows = var2, columns = var1
Z = data[:,2].reshape(len(yu), len(xu))

plt.imshow(Z, extent=[x[0],x[-1],y[0],y[-1]], cmap="jet", 
           aspect="auto", origin="lower")
plt.colorbar()

plt.show()
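As an alternative not used in the answers above: matplotlib 3.3 and later also accept shading="nearest" in pcolormesh, which centers one cell on each grid point. X, Y and C can then have the same shape and no edge arrays are needed. A sketch under that assumption, reusing the question's results laid out with rows = var2 and columns = var1:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.array([2, 4, 6])          # var1 values, used as cell centers
y = np.array([0.6, 0.7, 0.8])    # var2 values, used as cell centers

# results from the question: rows = var2 (0.6, 0.7, 0.8), columns = var1 (2, 4, 6)
Z = np.array([[0.84693877551025, 0.8730158730157779, 0.8480725623582223],
              [0.843537414966, 0.8530612244898, 0.8577097505669444],
              [0.8537414965986667, 0.8605442176870833, 0.8378684807257778]])

fig, ax = plt.subplots()
mesh = ax.pcolormesh(x, y, Z, cmap="jet", shading="nearest")  # one cell per (x, y) point
fig.colorbar(mesh, ax=ax)
# use plt.show() with an interactive backend, or fig.savefig(...), to view the result
```

With this shading, matplotlib computes the cell edges itself (halfway between neighbouring points), which matches the centers interpretation discussed above.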


Source: https://stackoverflow.com/questions/44987972/heatmap-with-matplotlib
