Plot gradient arrows over heatmap with plt

悲哀的现实 2020-12-18 16:19

I am trying to plot arrows to visualize the gradient over a heatmap. This is the code I have so far:

import matplotlib.pyplot as plt
import numpy as np
fu         


        
1 answer
  • 2020-12-18 17:13
    1. np.gradient() returns the derivative along axis 0 (the rows, i.e. y) before the derivative along axis 1 (the columns, i.e. x), so its result has to be unpacked as yd, xd.
    2. The colors also looked wrong because imshow() draws row 0 of the array at the top by default, which reverses the y-axis relative to the mesh. I therefore flip the matrix vertically with np.flip(result_matrix, 0) when plotting.
    3. Finally, the arrows were drawn incorrectly when the step size did not divide the plotted range evenly, and the mesh points were not aligned with the centers of the heatmap cells. Both problems are fixed in the code below.
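Points 1 and 3 can be checked in isolation with NumPy alone. This is a sketch under stated assumptions, not the answer's code: `cell_centres` is a hypothetical helper, it uses `np.linspace` rather than `np.arange` to put one sample at the centre of each cell, and, unlike the answer, it passes the grid spacings to `np.gradient` so the derivatives come out per unit of x and y instead of per array index.

```python
import math

import numpy as np


def cell_centres(lo, hi, step):
    """Hypothetical helper: shrink `step` so it divides [lo, hi] evenly and
    return the centre of each cell together with the adjusted step."""
    n = math.ceil((hi - lo) / step)  # number of cells, rounded up
    step = (hi - lo) / n             # adjusted step that divides the range exactly
    # linspace always returns exactly n points, unlike arange with a float step
    centres = np.linspace(lo + step / 2, hi - step / 2, n)
    return centres, step


xs, dx = cell_centres(-2, 3, 0.3)
ys, dy = cell_centres(-1, 4, 0.5)
xv, yv = np.meshgrid(xs, ys)         # shape (len(ys), len(xs)): rows follow y
f = xv**2 + yv**2

# np.gradient returns one array per axis, in axis order: axis 0 (rows, y) first,
# then axis 1 (columns, x). Passing the spacings gives derivatives per unit of
# x and y; edge_order=2 keeps the border values exact for this quadratic.
df_dy, df_dx = np.gradient(f, dy, dx, edge_order=2)

print(xs.size, ys.size)  # 17 10
print(np.allclose(df_dx, 2 * xv), np.allclose(df_dy, 2 * yv))  # True True
```

Since the analytic gradient of x**2 + y**2 is (2x, 2y), the two `allclose` checks confirm that the first array returned is the y-derivative and the second the x-derivative.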

    Here is the code I used to generate the graph:

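    import math

    import matplotlib.pyplot as plt
    import numpy as np

    function_to_plot = lambda x, y: x**2 + y**2
    horizontal_min, horizontal_max, horizontal_stepsize = -2, 3, 0.3
    vertical_min, vertical_max, vertical_stepsize = -1, 4, 0.5

    horizontal_dist = horizontal_max - horizontal_min
    vertical_dist = vertical_max - vertical_min

    # Round the number of cells up and shrink the step so it divides each range evenly
    n_cols = math.ceil(horizontal_dist / horizontal_stepsize)
    n_rows = math.ceil(vertical_dist / vertical_stepsize)
    horizontal_stepsize = horizontal_dist / n_cols
    vertical_stepsize = vertical_dist / n_rows

    # One sample at the centre of each heatmap cell; np.linspace always returns exactly
    # n points, whereas np.arange with a float step can return an unexpected number
    xv, yv = np.meshgrid(
        np.linspace(horizontal_min + horizontal_stepsize / 2, horizontal_max - horizontal_stepsize / 2, n_cols),
        np.linspace(vertical_min + vertical_stepsize / 2, vertical_max - vertical_stepsize / 2, n_rows),
    )

    result_matrix = function_to_plot(xv, yv)
    # np.gradient returns the derivative along axis 0 (rows, y) first, then axis 1 (columns, x);
    # without a spacing argument it differentiates per array index
    yd, xd = np.gradient(result_matrix)

    def draw_arrow(x, y, dx, dy, scaling=0.01):
        plt.arrow(x, y, dx * scaling, dy * scaling, fc="k", ec="k", head_width=0.06, head_length=0.1)

    # imshow draws row 0 at the top by default, so flip the rows to match the y-axis
    plt.imshow(np.flip(result_matrix, 0), extent=[horizontal_min, horizontal_max, vertical_min, vertical_max])
    # A plain loop instead of np.vectorize, which by default calls the function an
    # extra time on the first element to infer the output type (drawing that arrow twice)
    for x, y, dx, dy in zip(xv.ravel(), yv.ravel(), xd.ravel(), yd.ravel()):
        draw_arrow(x, y, dx, dy, 0.1)
    plt.colorbar()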
    plt.show()
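A common alternative to flipping the array and drawing one `plt.arrow` patch per grid point is to pass `origin="lower"` to `imshow()` and draw the whole field with a single `quiver()` call. The sketch below assumes that approach; `plot_gradient_field` is a hypothetical helper name, and like the answer it calls `np.gradient()` with the default unit spacing, so `scaling` plays the same role as in the answer's code.

```python
import math

import matplotlib.pyplot as plt
import numpy as np


def plot_gradient_field(func, x_range, y_range, scaling=0.1):
    """Hypothetical helper: heatmap of func with its gradient drawn by quiver.

    x_range and y_range are (min, max, approximate_step) tuples.
    """
    (x_min, x_max, x_step), (y_min, y_max, y_step) = x_range, y_range
    # Same idea as the answer: round the cell count up so the step divides the range
    n_cols = math.ceil((x_max - x_min) / x_step)
    n_rows = math.ceil((y_max - y_min) / y_step)
    x_step = (x_max - x_min) / n_cols
    y_step = (y_max - y_min) / n_rows
    # One sample at the centre of each heatmap cell
    xv, yv = np.meshgrid(
        np.linspace(x_min + x_step / 2, x_max - x_step / 2, n_cols),
        np.linspace(y_min + y_step / 2, y_max - y_step / 2, n_rows),
    )
    values = func(xv, yv)
    # Axis 0 (rows, y) comes first; default spacing of 1, as in the answer
    grad_y, grad_x = np.gradient(values)

    fig, ax = plt.subplots()
    # origin="lower" puts row 0 at the bottom, so no np.flip is needed
    image = ax.imshow(values, origin="lower", extent=[x_min, x_max, y_min, y_max])
    # angles="xy", scale_units="xy" and scale=1 make each arrow span
    # (grad_x, grad_y) * scaling in data coordinates
    ax.quiver(xv, yv, grad_x * scaling, grad_y * scaling,
              angles="xy", scale_units="xy", scale=1)
    fig.colorbar(image, ax=ax)
    return fig, ax, values, grad_x, grad_y
```

For example, `plot_gradient_field(lambda x, y: x**2 + y**2, (-2, 3, 0.3), (-1, 4, 0.5))` followed by `plt.show()` should give a plot comparable to the answer's. One `quiver` call is usually much cheaper than one patch per grid point; the trade-off is that arrow style is controlled through quiver's own `width`, `headwidth` and `headlength` parameters rather than `head_width`/`head_length`.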
    