Matplotlib - Dynamic (bar) chart height based on data?

Asked by 走了就别回头了 on 2020-12-21 08:03

After struggling with matplotlib for longer than I'd like to admit by trying to do something that's a breeze in pretty much any other plotting library I've ever used, I've d…

1 answer
  • 2020-12-21 08:49

    I think the only way to get both equal bar widths (the width in the vertical direction) and different subplot sizes at the same time is to position the axes in the figure manually.

    To do this, specify how large a bar should be in inches (the bar width) and how much space you want between the subplots, measured in units of that bar width. From those two numbers and the amount of data to plot, you can calculate the total figure height in inches. Each subplot is then positioned (via fig.add_axes) according to its own number of bars and the space taken up by the subplots above it, so the subplots fill the figure exactly. Adding another set of data simply makes the figure taller.
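
To make the height calculation concrete, here is a small sketch that repeats the figy formula from the code below in plain Python, assuming the same defaults (barwidth=0.2 inch, spacing=3). figure_height is a hypothetical helper, not part of matplotlib. Each chart takes one bar width per entry plus one bar width of padding, and there are tc + 1 gaps of spacing bar widths (above the first chart, between the charts, and below the last one):

```python
def figure_height(bar_counts, barwidth=0.2, spacing=3):
    """Total figure height in inches (hypothetical helper mirroring figy)."""
    tc = len(bar_counts)               # number of categories, i.e. subplots
    bar_units = sum(bar_counts) + tc   # one bar width per entry + one of padding per chart
    gap_units = (tc + 1) * spacing     # gaps above, between and below the charts
    return (bar_units + gap_units) * barwidth


# Entry counts of the four example categories: 2, 1, 3 and 1.
print(round(figure_height([2, 1, 3, 1]), 6))     # 5.2
# Appending a fifth category with 4 entries adds (4 + 1 + 3) * 0.2 = 1.6 inches.
print(round(figure_height([2, 1, 3, 1, 4]), 6))  # 6.8
```

Only the number of entries per category matters for the height; the bar values themselves only affect the horizontal axis.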

    data = [
        {"name": "Category 1", "entries": [
            {"name": "Entry 1", "value": 5},
            {"name": "Entry 2", "value": 2},
        ]},
        {"name": "Category 2", "entries": [
            {"name": "Entry 1", "value": 1},
        ]},
        {"name": "Category 3", "entries": [
            {"name": "Entry 1", "value": 1},
            {"name": "Entry 2", "value": 10},
            {"name": "Entry 3", "value": 4},
        ]}, 
        {"name": "Category 4", "entries": [
            {"name": "Entry 1", "value": 6},
        ]},
    ]
    
    import matplotlib.pyplot as plt
    import numpy as np
    
    def plot_data(data,
                  barwidth=0.2,  # bar width (thickness) in inches
                  spacing=3,     # spacing between subplots, in units of barwidth
                  figx=5,        # figure width in inches
                  left=4,        # left margin, in units of barwidth
                  right=2):      # right margin, in units of barwidth
    
        tc = len(data)  # "total_categories": how many charts to create
        # number of bars (entries) in each category, used to size each subplot
        max_values = np.array([len(category["entries"]) for category in data])
    
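        # total figure height in inches: one bar width per entry plus one of
        # padding per chart, and (tc + 1) gaps of `spacing` bar widths
        figy = ((np.sum(max_values) + tc) + (tc + 1) * spacing) * barwidth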
    
        fig = plt.figure(figsize=(figx,figy))
        ax = None
        for index, category in enumerate(data):
            entries = []
            values = []
            for entry in category["entries"]:
                entries.append(entry["name"])
                values.append(entry["value"])
            if not entries:
                continue  # do not create empty charts
            y_ticks = range(1, len(entries) + 1)
            # coordinates of new axes [left, bottom, width, height]
            coord = [left*barwidth/figx, 
                     1-barwidth*((index+1)*spacing+np.sum(max_values[:index+1])+index+1)/figy,  
                     1-(left+right)*barwidth/figx,  
                     (max_values[index]+1)*barwidth/figy ] 
    
            ax = fig.add_axes(coord, sharex=ax)
            ax.barh(y_ticks, values)
            ax.set_ylim(0, max_values[index] + 1)  # limit the y axis for fixed height
            ax.set_yticks(y_ticks)
            ax.set_yticklabels(entries)
            ax.invert_yaxis()
            ax.set_title(category["name"], loc="left")
    
    
    plot_data(data)
    plt.savefig(__file__+".png")
    plt.show()
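
Why this gives equal bar widths: each axes is (number of bars + 1) * barwidth inches tall and its y limits span number of bars + 1 data units, so one data unit is always barwidth inches, whatever the category. The sketch below, assuming the default arguments of plot_data, repeats the coord arithmetic in plain Python so the layout can be checked in inches without drawing anything. axes_rects is a hypothetical helper, not a matplotlib function:

```python
def axes_rects(bar_counts, barwidth=0.2, spacing=3, figx=5, left=4, right=2):
    """Return figy and one [left, bottom, width, height] rectangle per subplot.

    Hypothetical helper: same arithmetic as `coord` in plot_data, with the
    rectangles given as fractions of the figure, as fig.add_axes expects.
    """
    tc = len(bar_counts)
    figy = (sum(bar_counts) + tc + (tc + 1) * spacing) * barwidth  # inches
    rects = []
    used = 0  # bar widths taken by this chart and all charts above it
    for index, n in enumerate(bar_counts):
        used += n + 1
        bottom = 1 - barwidth * ((index + 1) * spacing + used) / figy
        rects.append([left * barwidth / figx,                # left edge
                      bottom,                                # bottom edge
                      1 - (left + right) * barwidth / figx,  # width
                      (n + 1) * barwidth / figy])            # height
    return figy, rects


figy, rects = axes_rects([2, 1, 3, 1])
for _, bottom, _, height in rects:
    # convert figure fractions back to inches
    print(f"bottom = {bottom * figy:.2f} in, height = {height * figy:.2f} in")
```

For the example data this reports bottoms of 4.00, 3.00, 1.60 and 0.60 inches and heights of 0.60, 0.40, 0.80 and 0.40 inches: every chart's top edge sits spacing * barwidth = 0.6 inch below the chart above it, and the last chart ends 0.6 inch above the bottom of the figure, so nothing is left over.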
    
