Nesting or combining matplotlib figures and plots?

天大地大妈咪最大 submitted on 2021-01-29 16:11:29

Question


I have a function that takes an arbitrary-length 3D data set of dates, prices (float), and some resulting value (float), and makes a set of seaborn heatmaps split by year. The pseudocode is as follows (note that the number of years varies by dataset, so the function has to scale to an arbitrary number of years):

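import matplotlib.pyplot as plt
import seaborn as sns

def makePlots(data):
    years = ...  # split data by year
    fig, axs = plt.subplots(1, len(years))
    # Python has no ++ operator; enumerate supplies the column index
    for x, year in enumerate(years):
        sns.heatmap(data[year], ax=axs[x])

    return axs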

This outputs a single matplotlib figure with one heatmap per year, placed side by side in a single row, as shown in this example: [image: single plotted dataset]

Now I have a higher-level function to which I pass two datasets (each covering an arbitrary number of years), and I want it to show the heatmap rows for the two datasets one above the other for comparison. Ideally it would take the figures produced by makePlots and simply stack them vertically, as in this example: [image: two plotted datasets]

def compareData(data1,data2):
   fig1 = makePlots(data1)
   fig2 = makePlots(data2)
   fig, (ax1,ax2) = plt.subplots(2,1)
   ax1 = fig1
   ax2 = fig2
   plt.show()

This code runs, but not as intended. It opens three plot windows: one with data1 plotted correctly, one with data2 plotted correctly, and one with an empty two-row subplot grid. Is there any way to nest the makePlots plots in a new figure, one on top of the other? I have also tried returning plt.gcf(). All the other answers on Stack Overflow rely on passing the axes into the plotting function, but since each dataset has an arbitrary number of axes (years), and I eventually want to compare an arbitrary number of datasets, that does not seem ideal. I also cannot work out how I would implement it, since each row can contain a different number of years.
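The three windows have a simple cause: every call to plt.subplots creates a brand-new Figure, and assigning ax1 = fig1 only rebinds a local variable name; it does not move any plotted Axes into the new figure. A minimal sketch reproducing this, assuming a headless Agg backend and a hypothetical stand-in make_row for makePlots (it only creates the axes and plots no data):

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def make_row(n_years):
    """Hypothetical stand-in for makePlots: plt.subplots always makes a NEW Figure."""
    fig, axs = plt.subplots(1, n_years)
    return axs


plt.close("all")
axs1 = make_row(2)                    # opens Figure 1
axs2 = make_row(3)                    # opens Figure 2
fig, (ax1, ax2) = plt.subplots(2, 1)  # opens Figure 3 with two empty Axes
ax1 = axs1  # only rebinds the name ax1; nothing is moved into `fig`
ax2 = axs2
print(len(plt.get_fignums()))         # 3 open figures
print(len(fig.axes))                  # 2: fig still holds only its own empty Axes
```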


Answer 1:


I wouldn't recommend it, but you can add subplots incrementally with fig.add_subplot(nrows, ncols, index).

So your two functions would look something like this:

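import matplotlib.pyplot as plt
import seaborn as sns

def compareData(data1, data2):
    fig = plt.figure()  # one shared figure for both datasets
    makePlots(data1, row=0, fig=fig)  # top row
    makePlots(data2, row=1, fig=fig)  # bottom row
    plt.show()

def makePlots(data, row, fig):
    years = ... # parse data here
    for ii, year in enumerate(years):
        # 2-row grid with len(years) columns; add_subplot indices are 1-based
        # and count left to right, then top to bottom
        ax = fig.add_subplot(2, len(years), row * len(years) + ii + 1)
        sns.heatmap(data[year], ax=ax)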

This hopefully addresses your question.
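The answer's code hardcodes two rows. A minimal sketch generalizing the same add_subplot approach to any number of datasets, assuming each dataset has already been split into a dict mapping year to a 2D array (compare_many is a hypothetical name). Each row passes its own ncols to add_subplot, so rows with different numbers of years still span the full figure width:

```python
import matplotlib.pyplot as plt
import seaborn as sns


def compare_many(datasets):
    """Stack one row of per-year heatmaps per dataset in a single Figure.

    Assumption: each dataset is a dict {year: 2D array-like}, i.e. it has
    already been split by year.
    """
    fig = plt.figure()
    nrows = len(datasets)  # one grid row per dataset instead of a fixed 2
    rows = []
    for row, data in enumerate(datasets):
        years = sorted(data)
        ncols = len(years)  # each row may have a different number of years
        axes = []
        for col, year in enumerate(years):
            # add_subplot indices are 1-based and count left to right, top to bottom
            ax = fig.add_subplot(nrows, ncols, row * ncols + col + 1)
            sns.heatmap(data[year], ax=ax)
            ax.set_title(str(year))
            axes.append(ax)
        rows.append(axes)
    return fig, rows
```

For example, fig, rows = compare_many([data1, data2]) followed by plt.show() would display both datasets in one window, and a third dataset is just one more list element.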


However, you only have this problem because you are mixing data parsing and plotting in the same function. My advice is to parse the data first, then pass the resulting data structure into separate plotting functions.
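Following that advice, one possible split is a parsing function that returns plain per-year tables and a plotting function that only draws them. This sketch swaps in Figure.subfigures (Matplotlib 3.4 and later), which gives each dataset its own nested sub-figure, i.e. an actual figure within a figure. The column names date, price and value, the price-by-month pivot, and the names split_by_year and plot_datasets are all assumptions for illustration:

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def split_by_year(df):
    """Parsing step. Assumed input: a DataFrame with columns 'date'
    (datetime64), 'price' and 'value'. Returns {year: 2D table} with
    prices as rows and months as columns (an assumed heatmap layout)."""
    df = df.assign(year=df["date"].dt.year, month=df["date"].dt.month)
    return {
        year: group.pivot_table(index="price", columns="month",
                                values="value", aggfunc="mean")
        for year, group in df.groupby("year")
    }


def plot_datasets(parsed):
    """Plotting step: one subfigure (row) per dataset, one heatmap per year.
    Requires Matplotlib >= 3.4 for Figure.subfigures."""
    fig = plt.figure(figsize=(12, 3 * len(parsed)))
    subfigs = fig.subfigures(len(parsed), 1, squeeze=False)  # shape (n, 1)
    rows = []
    for subfig, by_year in zip(subfigs[:, 0], parsed):
        years = sorted(by_year)
        # each subfigure gets its own 1 x len(years) grid of Axes
        axs = subfig.subplots(1, len(years), squeeze=False)[0]
        for ax, year in zip(axs, years):
            sns.heatmap(by_year[year], ax=ax)
            ax.set_title(str(year))
        rows.append(list(axs))
    return fig, rows
```

Calling fig, rows = plot_datasets([split_by_year(df1), split_by_year(df2)]) and then plt.show() would produce a single window with one row per dataset. Because the plotting function never touches the raw data, comparing more datasets only means appending more parsed dicts to the list.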



Source: https://stackoverflow.com/questions/61839595/nesting-or-combining-matplotlib-figures-and-plots
