In matplotlib.pyplot, how to plot two datasets using interleaved bar graphs?

一向 2021-01-02 02:07

I am going through Think Stats and I would like to compare multiple data sets visually. I can see from the book examples that it is possible to generate an interleaved bar g

4 answers
  •  我在风中等你
    2021-01-02 02:37

    I came across this problem a while ago and created a wrapper function that takes a 2D array and automatically creates a multi-barchart from it:

    [Figure: multi-category bar chart]

    The code:

    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    import operator as o
    
    import numpy as np
    
    dpoints = np.array([['rosetta', '1mfq', 9.97],
               ['rosetta', '1gid', 27.31],
               ['rosetta', '1y26', 5.77],
               ['rnacomposer', '1mfq', 5.55],
               ['rnacomposer', '1gid', 37.74],
               ['rnacomposer', '1y26', 5.77],
               ['random', '1mfq', 10.32],
               ['random', '1gid', 31.46],
               ['random', '1y26', 18.16]])
    
    fig = plt.figure()
    ax = fig.add_subplot(111)
    
    def barplot(ax, dpoints):
        '''
        Create a barchart for data across different categories with
        multiple conditions for each category.
    
        @param ax: The plotting axes from matplotlib.
        @param dpoints: The data set as an (n, 3) numpy array
        '''
    
        # Aggregate the conditions and the categories according to their
        # mean values
        conditions = [(c, np.mean(dpoints[dpoints[:,0] == c][:,2].astype(float))) 
                      for c in np.unique(dpoints[:,0])]
        categories = [(c, np.mean(dpoints[dpoints[:,1] == c][:,2].astype(float))) 
                      for c in np.unique(dpoints[:,1])]
    
        # sort the conditions, categories and data so that the bars in
        # the plot will be ordered by category and condition
        conditions = [c[0] for c in sorted(conditions, key=o.itemgetter(1))]
        categories = [c[0] for c in sorted(categories, key=o.itemgetter(1))]
    
        dpoints = np.array(sorted(dpoints, key=lambda x: categories.index(x[1])))
    
        # the space between each set of bars
        space = 0.3
        n = len(conditions)
        width = (1 - space) / (len(conditions))
    
        # Tick positions, one per category
        indices = list(range(1, len(categories) + 1))
    
        # Create a set of bars at each position
        for i, cond in enumerate(conditions):
            # np.float was removed in NumPy 1.24; the builtin float is equivalent
            vals = dpoints[dpoints[:,0] == cond][:,2].astype(float)
            pos = [j - (1 - space) / 2. + i * width for j in indices]
            # align='edge' treats pos as left edges, so each group is centred on its tick
            ax.bar(pos, vals, width=width, align='edge', label=cond,
                   color=cm.Accent(float(i) / n))
    
        # Set the x-axis tick labels to be equal to the categories
        ax.set_xticks(indices)
        ax.set_xticklabels(categories)
        plt.setp(ax.get_xticklabels(), rotation=90)
    
        # Add the axis labels
        ax.set_ylabel("RMSD")
        ax.set_xlabel("Structure")
    
        # Add a legend
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles[::-1], labels[::-1], loc='upper left')
    
    barplot(ax, dpoints)
    plt.show()
    

    If you're interested in what this function does and the logic behind it, here's a (shamelessly self-promoting) link to the blog post describing it.
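For the two-dataset case the question asks about, the position arithmetic can be pulled out into a small helper. This is only a minimal sketch, not the wrapper above: `bar_offsets` and `plot_interleaved` are hypothetical names introduced here, and the code assumes every series has exactly one value per label.

```python
def bar_offsets(n_series, space=0.3):
    """Return (width, left_offsets) for n_series bars sharing one tick.

    The group occupies (1 - space) x-units centred on the tick; each
    offset is the left edge of one bar relative to the tick position.
    """
    width = (1.0 - space) / n_series
    start = -(1.0 - space) / 2.0
    return width, [start + i * width for i in range(n_series)]


def plot_interleaved(ax, labels, series, space=0.3):
    """Draw one bar per series at every label, side by side.

    ax     -- a matplotlib Axes
    labels -- one tick label per group
    series -- dict mapping a legend name to a list of values
              (assumed to be the same length as labels)
    """
    width, offsets = bar_offsets(len(series), space)
    ticks = list(range(1, len(labels) + 1))
    for offset, (name, values) in zip(offsets, series.items()):
        lefts = [t + offset for t in ticks]
        # align='edge' makes each x value the bar's left edge
        ax.bar(lefts, values, width=width, align='edge', label=name)
    ax.set_xticks(ticks)
    ax.set_xticklabels(labels)
    ax.legend()
```

To try it with two of the conditions from the data above, something like `fig, ax = plt.subplots()` followed by `plot_interleaved(ax, ['1mfq', '1gid', '1y26'], {'rosetta': [9.97, 27.31, 5.77], 'rnacomposer': [5.55, 37.74, 5.77]})` and `plt.show()` should give two bars per structure, centred on each tick.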
