matplotlib: get the subplot layout?

Submitted by 天涯浪子 on 2021-01-28 02:58:15

Question


I have a function that creates a grid of similar 2D histograms. So that I can select whether to put this new plot on a pre-existing figure, I do the following:

def make_hist2d(x, y, current_fig=False, layout=(1,1,1),*args):

    if current_fig: 

        fig = _plt.gcf()
        ax  = fig.add_subplot(*layout)  # layout=(nrows, ncols, nplot)

    else:

        fig, ax = _plt.subplots()    


    H, xedges, yedges = _np.histogram2d(...)

    # manipulate the histogram, e.g. column normalize.

    XX, YY = _np.meshgrid(xedges, yedges)
    Image  = ax.pcolormesh(XX, YY, Hplot.T, norm=norm, **pcmesh_kwargs)
    ax.autoscale(tight=True)

    grid_kargs = {'orientation': 'vertical'}
    cax, kw    = _mpl.colorbar.make_axes_gridspec(ax, **grid_kargs)
    cbar       = fig.colorbar(Image, cax=cax)
    cbar.set_label(cbar_title)

    return fig, ax, cbar



def hist2d_grid(data_dict, key_pairs, layout, *args):  # ``*args`` are things like xlog, ylog, xlabel, etc. 
                                                       # that are common to all subplots in the figure.

    fig, ax = _plt.subplots()    

    nplots = range(len(key_pairs) + 1)    # key_pairs = ((k1a, k1b), (k2a, k2b), ..., (kna, knb))

    ax_list = []

    for pair, i in zip(key_pairs, nplots):

        fig, ax, cbar = make_hist2d(data_dict[pair[0]], data_dict[pair[1]], ...)

        ax_list.append(ax)

    return fig, ax_list

Then I call something like:

hgrid = hist2d_grid(...)

However, if I want to add a new plot to the grid, I don't know of a good way to get the subplot layout. For example, is there something like:

layout = fig.get_layout()

That would give me something like (nrows, ncols, n_subplots)?

I could do this with something like:

n_plot = len(ax_list) / 2  # Each subplot generates a plot and a color bar.
n_rows = np.floor(np.sqrt(n_plot))
n_cols = np.ceil(np.sqrt(n_plot))

But I have to deal with special cases like a (2,4) subplot array for which I would get n_rows = 2 and n_cols = 3, which means that I would be passing (2,3,8) to ax.add_subplot(), which clearly doesn't work because 8 > 3*2.


Answer 1:


Since the ax returned by fig, ax = plt.subplots(4,2) is a numpy array of axes, ax.shape gives you the layout information you want, e.g.

 nrows, ncols = ax.shape
 n_subplots = nrows*ncols
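One caveat is that plt.subplots only returns a 2D array when both dimensions are larger than one. A minimal sketch, assuming you control the call that creates the figure, is to pass squeeze=False so that the shape can always be unpacked. The Agg backend line and the name axes are only there to make the example self-contained:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so the sketch runs without a display
import matplotlib.pyplot as plt

# squeeze=False keeps the result a 2D array even for a 1x1 or 1xN grid
fig, axes = plt.subplots(4, 2, squeeze=False)

nrows, ncols = axes.shape
n_subplots = axes.size            # same as nrows * ncols
layout = (nrows, ncols, n_subplots)
```

Without squeeze=False, a 1x1 grid returns a single Axes object and a 1xN grid returns a 1D array, so unpacking the shape into two values would fail.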

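You can also get the locations of the various axes by looping over the children of the figure object (this relies on the colNum and rowNum attributes of subplot axes, which exist on older Matplotlib versions but were removed in later releases),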

[[f.colNum, f.rowNum] for f in fig.get_children()[1:]]

and you can probably get the grid size from the final element, fig.get_children()[-1].
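On recent Matplotlib releases, the same information seems to be reachable through each axes' SubplotSpec instead. The following is a sketch under that assumption, not a drop-in replacement. It assumes every axes in the figure was placed on a grid (true for plt.subplots), and the names positions and spec are illustrative:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so the sketch runs without a display
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 4)

# (row, column) of the top-left grid cell occupied by each axes
positions = []
for ax in fig.axes:
    spec = ax.get_subplotspec()
    positions.append((spec.rowspan.start, spec.colspan.start))

# The grid dimensions come from the GridSpec behind any one of the axes
nrows, ncols = fig.axes[0].get_subplotspec().get_gridspec().get_geometry()
```

Note that colorbar axes created with make_axes_gridspec sit on a nested grid, so in a figure like the one in the question their geometry would differ from that of the outer layout.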

You could also use gridspec to be more explicit about the location of subplots if needed. With gridspec you set up the GridSpec object and pass an indexed cell of it to plt.subplot,

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
gs = gridspec.GridSpec(2, 2)
ax = plt.subplot(gs[0, 0])

To get the layout you can then use,

gs.get_geometry()
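get_geometry() returns the tuple (nrows, ncols). As a rough sketch of how that could feed back into add_subplot, the example below fills three cells of a 2x4 grid and then places a new subplot in the next free cell. It assumes the figure holds only grid subplots (no colorbar axes), and the names n_used and new_ax are illustrative:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so the sketch runs without a display
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt

fig = plt.figure()
gs = gridspec.GridSpec(2, 4, figure=fig)

# Fill the first three cells; an integer index into gs counts cells row by row
for i in range(3):
    fig.add_subplot(gs[i])

nrows, ncols = gs.get_geometry()
n_used = len(fig.axes)

# add_subplot numbers cells from 1, so the next free cell is n_used + 1
if n_used < nrows * ncols:
    new_ax = fig.add_subplot(nrows, ncols, n_used + 1)
```

If each plot also adds a colorbar axes, as in the question, len(fig.axes) would overcount, so the number of plots would have to be tracked separately.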


Source: https://stackoverflow.com/questions/32960832/matplotlib-get-the-subplot-layout
