Using Pandas crosstab with seaborn stacked barplots

被撕碎了的回忆 2020-12-15 14:08

I am trying to create a stacked barplot in seaborn with my dataframe.

I have first generated a crosstab table in pandas like so:

pd.crosstab(df['Pe


        
2 Answers
  • 2020-12-15 14:19

    The guy who created Seaborn doesn't like stacked bar charts (though there is a known hack that combines Seaborn and Matplotlib to make them anyway).

    If you're willing to accept a grouped bar chart instead of a stacked one, here's one approach:

     # first some sample data
     import numpy as np 
     import pandas as pd
     import seaborn as sns
    
     N = 1000
     mark = np.random.choice([True,False], N)
     periods = np.random.choice(['BASELINE','WEEK 12', 'WEEK 24', 'WEEK 4'], N)
    
     df = pd.DataFrame({'mark':mark,'period':periods})
     ct = pd.crosstab(df.period, df.mark)
    
     # print(ct) gives something like this (counts vary, the data is random):
     # mark      False  True
     # period
     # BASELINE    118    111
     # WEEK 12     117    149
     # WEEK 24     117    130
     # WEEK 4      127    131
    
     # now stack and reset
     stacked = ct.stack().reset_index().rename(columns={0:'value'})
    
     # plot grouped bar chart
     sns.barplot(x=stacked.period, y=stacked.value, hue=stacked.mark)
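
If you do want the stacked version from Seaborn itself, one common workaround is to overlay bars: draw the cumulative totals first, then paint the shorter series on top. This is only a sketch of that idea, and I can't promise it matches the linked hack line for line. The seeded generator, the string column names and the manual legend are my own choices for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.patches import Patch

# sample data, same shape as above (seeded so the run is repeatable)
rng = np.random.default_rng(0)
N = 1000
df = pd.DataFrame({
    "mark": rng.choice([True, False], N),
    "period": rng.choice(["BASELINE", "WEEK 12", "WEEK 24", "WEEK 4"], N),
})

ct = pd.crosstab(df["period"], df["mark"])
ct.columns = ct.columns.astype(str)  # 'False' / 'True' as plain string names

# running totals across the columns: the last column holds the full bar height
cum = ct.cumsum(axis=1).reset_index()
colors = sns.color_palette(n_colors=ct.shape[1])

fig, ax = plt.subplots()
# draw the tallest (cumulative) bars first so the shorter ones cover their base
for col, color in reversed(list(zip(ct.columns, colors))):
    sns.barplot(data=cum, x="period", y=col, color=color, ax=ax)

ax.set_ylabel("count")
# build the legend by hand, since no hue variable is involved
handles = [Patch(color=c, label=l) for l, c in zip(ct.columns, colors)]
ax.legend(handles=handles, title="mark")
```

Call `plt.show()` (or `fig.savefig(...)`) afterwards to see the figure. The visible part of each upper bar is the difference between two cumulative heights, which is exactly the count for that category.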
    

  • 2020-12-15 14:36

    As you said, you can use pandas to create the stacked bar plot. Wanting a "seaborn plot" specifically is beside the point: every seaborn plot and every pandas plot is ultimately a matplotlib object, because the plotting tools of both libraries are wrappers around matplotlib.

    So here is a complete solution (taking the data creation from @andrew_reece's answer).

    import numpy as np 
    import pandas as pd
    import seaborn as sns
    import matplotlib.pyplot as plt
    
    n = 500
    mark = np.random.choice([True,False], n)
    periods = np.random.choice(['BASELINE','WEEK 12', 'WEEK 24', 'WEEK 4'], n)
    
    df = pd.DataFrame({'mark':mark,'period':periods})
    ct = pd.crosstab(df.period, df.mark)
    
    ct.plot.bar(stacked=True)
    plt.legend(title='mark')
    
    plt.show()
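
To make the "it's all matplotlib" point concrete, here is a small sketch, assuming seaborn 0.11 or newer for `sns.set_theme`. It checks that the pandas plot returns a matplotlib `Axes`, then styles that same object with seaborn helpers. The theme, the palette name and the seed are arbitrary choices of mine.

```python
import matplotlib.axes
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

sns.set_theme(style="whitegrid")  # seaborn look applied to all matplotlib figures

rng = np.random.default_rng(1)
n = 500
df = pd.DataFrame({
    "mark": rng.choice([True, False], n),
    "period": rng.choice(["BASELINE", "WEEK 12", "WEEK 24", "WEEK 4"], n),
})
ct = pd.crosstab(df["period"], df["mark"])

# pandas draws the stacked bars and hands back the matplotlib Axes it used
palette = sns.color_palette("deep", n_colors=ct.shape[1]).as_hex()
ax = ct.plot.bar(stacked=True, color=palette, rot=0)
print(isinstance(ax, matplotlib.axes.Axes))

# from here on, any matplotlib or seaborn call that takes an Axes works
ax.set_ylabel("count")
ax.legend(title="mark")
sns.despine(ax=ax)
```

Add `plt.show()` at the end to display it. The result is styled like a seaborn chart even though pandas did the drawing.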
    
