How to check if colorbar exists on figure

遥遥无期 2020-12-18 03:33

Question: Is there a way to check if a color bar already exists?

I am making many plots with a loop. The issue is that the color bar is drawn every

4 answers
  •  悲&欢浪女
    2020-12-18 03:47

    It is actually not easy to remove a colorbar from a plot and later draw a new one in its place. The best solution I can come up with at the moment is the following, which assumes that the plot contains only one axes. If a second axes is present, it must be the colorbar. So by counting the axes in the figure, we can judge whether or not a colorbar exists.

    Here we also respect the asker's wish not to reference any named objects from outside the loop. (That does not make much sense, as we need to use plt anyway, but hey, that is what the question asked for.)

    import matplotlib.pyplot as plt
    import numpy as np
    
    fig, ax = plt.subplots()
    im = ax.pcolormesh(np.array(np.random.rand(2,2) ))
    ax.plot(np.cos(np.linspace(0.2,1.8))+0.9, np.sin(np.linspace(0.2,1.8))+0.9, c="k", lw=6)
    ax.set_title("Title")
    cbar = plt.colorbar(im)
    cbar.ax.set_ylabel("Label")
    
    
    for i in range(10):
        # inside this loop we should not access any variables defined outside
        #   why? no real reason, but questioner asked for it.
        #draw new colormesh
        im = plt.gcf().gca().pcolormesh(np.random.rand(2,2))
        #check if there is more than one axes
        if len(plt.gcf().axes) > 1: 
            # if so, then the last axes must be the colorbar.
            # we get its extent as (x0, y0, width, height)
            bounds = plt.gcf().axes[-1].get_position().bounds
            # and its label
            label = plt.gcf().axes[-1].get_ylabel()
            # and then remove the axes
            plt.gcf().axes[-1].remove()
            # then we draw a new axes at the extents of the old one
            cax = plt.gcf().add_axes(bounds)
            # and add a colorbar to it
            cbar = plt.colorbar(im, cax=cax)
            cbar.ax.set_ylabel(label)
            # unfortunately the aspect is different between the initial call to colorbar 
            #   without cax argument. Try to reset it (but still it's somehow different)
            cbar.ax.set_aspect(20)
        else:
            plt.colorbar(im)
    
    plt.show()
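
    The axes-counting check used in the loop above can be wrapped in a small helper. As a hedged alternative, the sketch below also inspects the colorbar attribute that Matplotlib sets on a mappable once a colorbar has been created for it. The helper names has_extra_axes and mappables_with_colorbar are made up for this illustration, and the Agg backend is selected only so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np


def has_extra_axes(fig):
    """Heuristic from the answer: more than one axes implies a colorbar.

    Only valid if the figure was created with a single plotting axes.
    """
    return len(fig.axes) > 1


def mappables_with_colorbar(fig):
    """Return the collections/images in fig that have a colorbar attached."""
    found = []
    for ax in fig.axes:
        for artist in list(ax.collections) + list(ax.images):
            # Matplotlib stores the attached colorbar on the mappable itself
            if getattr(artist, "colorbar", None) is not None:
                found.append(artist)
    return found


fig, ax = plt.subplots()
im = ax.pcolormesh(np.random.rand(2, 2))

# state before any colorbar exists
before = (has_extra_axes(fig), len(mappables_with_colorbar(fig)))

fig.colorbar(im, ax=ax)

# state after the colorbar has been added
after = (has_extra_axes(fig), len(mappables_with_colorbar(fig)))
print(before, after)
```

    The attribute check does not depend on how many axes the figure has, so it should keep working when the figure contains several subplots, whereas the axes count only works under the single-axes assumption stated above.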
    

    In general, a much better solution is to operate on the objects already present in the plot and only update them with the new data. This removes the need to delete and re-add axes and gives a much cleaner and faster solution.

    import matplotlib.pyplot as plt
    import numpy as np
    
    fig, ax = plt.subplots()
    im = ax.pcolormesh(np.array(np.random.rand(2,2) ))
    ax.plot(np.cos(np.linspace(0.2,1.8))+0.9, np.sin(np.linspace(0.2,1.8))+0.9, c="k", lw=6)
    ax.set_title("Title")
    cbar = plt.colorbar(im)
    cbar.ax.set_ylabel("Label")
    
    
    for i in range(10):
        data = np.array(np.random.rand(2,2) )
        im.set_array(data.flatten())
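        # update the color limits on the mappable; the colorbar follows automatically
        # (newer Matplotlib releases no longer provide Colorbar.set_clim / Colorbar.draw_all)
        im.set_clim(vmin=data.min(), vmax=data.max())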
        plt.draw()
    
    plt.show()
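
    The update-in-place idea can also be combined with an existence check, so that a single function is safe to call on every loop iteration. This is only a sketch under stated assumptions: update_mesh is a hypothetical helper, the data is assumed to keep the same shape between calls, and the Agg backend is used so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import QuadMesh


def update_mesh(ax, data):
    """Show data on ax, reusing the existing mesh and colorbar if present."""
    meshes = [c for c in ax.collections if isinstance(c, QuadMesh)]
    if meshes:
        # a mesh already exists: only swap in the new values
        mesh = meshes[0]
        mesh.set_array(data.ravel())
    else:
        # first call: create the mesh
        mesh = ax.pcolormesh(data)
    # rescale the colors to the new data; an attached colorbar follows along
    mesh.set_clim(vmin=data.min(), vmax=data.max())
    # create the colorbar only if this mesh does not have one yet
    if mesh.colorbar is None:
        ax.figure.colorbar(mesh, ax=ax)
    return mesh


fig, ax = plt.subplots()
rng = np.random.default_rng(0)
for i in range(5):
    data = rng.random((2, 2)) * (i + 1)
    mesh = update_mesh(ax, data)

# one plotting axes plus exactly one colorbar axes, however often we looped
n_axes = len(fig.axes)
print(n_axes)
```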
    


    Update:

    Actually, the latter approach of referencing objects from outside even works together with the multiprocess approach desired by the questioner.

    So here is code that updates the figure without the need to delete the colorbar.

    import matplotlib.pyplot as plt
    import numpy as np
    import multiprocessing
    import time
    
    fig, ax = plt.subplots()
    im = ax.pcolormesh(np.array(np.random.rand(2,2) ))
    ax.plot(np.cos(np.linspace(0.2,1.8))+0.9, np.sin(np.linspace(0.2,1.8))+0.9, c="w", lw=6)
    ax.set_title("Title")
    cbar = plt.colorbar(im)
    cbar.ax.set_ylabel("Label")
    tx = ax.text(0.2,0.8, "", fontsize=30, color="w")
    tx2 = ax.text(0.2,0.2, "", fontsize=30, color="w")
    
    def do(number):
        start = time.time()
        tx.set_text(str(number))
        data = np.array(np.random.rand(2,2)*(number+1) )
        im.set_array(data.flatten())
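        # update the color limits on the mappable; the colorbar follows automatically
        # (newer Matplotlib releases no longer provide Colorbar.set_clim / Colorbar.draw_all)
        im.set_clim(vmin=data.min(), vmax=data.max())
        tx2.set_text("{m:.2f} < {ma:.2f}".format(m=data.min(), ma=data.max()))
        plt.draw()
        # note: the "multiproc" directory must already exist
        plt.savefig("multiproc/{n}.png".format(n=number))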
        stop = time.time()
    
        return np.array([number, start, stop])
    
    
    if __name__ == "__main__":
        multiprocessing.freeze_support()
    
        some_list = range(0,50)
        num_proc = 5
        p = multiprocessing.Pool(num_proc)
        nu = p.map(do, some_list)
        nu = np.array(nu)
    
        plt.close("all")
        fig, ax = plt.subplots(figsize=(16,9))
        ax.barh(nu[:,0], nu[:,2]-nu[:,1], height=np.ones(len(some_list)), left=nu[:,1],  align="center")
        plt.show()
    

    (The code at the end draws a timeline of each task's start and stop times, which shows that the tasks did indeed run in parallel.)
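
    The same conclusion can be read off numerically instead of from the bar chart. The sketch below assumes rows shaped like the return value of do, that is (number, start, stop), and computes the largest number of tasks running at the same moment. The function name peak_concurrency and the sample timings are invented for illustration and are not measured results.

```python
import numpy as np


def peak_concurrency(rows):
    """Return the maximum number of tasks running simultaneously.

    rows: array-like of (number, start, stop) triples.
    """
    rows = np.asarray(rows, dtype=float)
    starts, stops = rows[:, 1], rows[:, 2]
    # one event per start (+1) and per stop (-1)
    times = np.concatenate([starts, stops])
    steps = np.concatenate([np.ones_like(starts), -np.ones_like(stops)])
    # sort by time; at equal times handle stops (-1) before starts (+1)
    order = np.lexsort((steps, times))
    running = np.cumsum(steps[order])
    return int(running.max())


# made-up timings: tasks strictly one after another
serial = [(0, 0.0, 1.0), (1, 1.0, 2.0), (2, 2.0, 3.0)]
# made-up timings: the first two tasks overlap
parallel = [(0, 0.0, 2.0), (1, 0.5, 2.5), (2, 3.0, 4.0)]

serial_peak = peak_concurrency(serial)
parallel_peak = peak_concurrency(parallel)
print(serial_peak, parallel_peak)
```

    Applied to the nu array from the script above, a result greater than 1 would indicate that several workers were active at once.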
