pandas - multi index plotting

Posted by 99封情书 on 2019-11-27 03:54:17

I would use seaborn's categorical plot function, `catplot` (named `factorplot` in seaborn versions before 0.9).

Say you have data like this:

import numpy as np
import pandas

import seaborn
# Apply seaborn's "ticks" style globally. `seaborn.set` is only a legacy alias
# for `set_theme`, which is the preferred interface (seaborn >= 0.11).
seaborn.set_theme(style='ticks')
# Seed NumPy's legacy global RNG so the table printed below is reproducible
# (seed 0 produces the values 64, 67, 84, 87, 87, 29, 41, 56 shown there).
np.random.seed(0)

# Every Group x Sex x Mean combination becomes one row: 2 * 2 * 2 = 8 rows.
groups = ('Group 1', 'Group 2')
sexes = ('Male', 'Female')
means = ('Low', 'High')
index = pandas.MultiIndex.from_product(
    (groups, sexes, means), names=('Group', 'Sex', 'Mean'))

# Draw one random integer in [20, 100) per row, then turn the three index
# levels into ordinary columns so that each row is a single observation
# that can be passed straight to seaborn below.
values = np.random.randint(low=20, high=100, size=index.size)
data = pandas.Series(values, index=index, name='val').reset_index()
print(data)

     Group     Sex  Mean  val
0  Group 1    Male   Low   64
1  Group 1    Male  High   67
2  Group 1  Female   Low   84
3  Group 1  Female  High   87
4  Group 2    Male   Low   87
5  Group 2    Male  High   29
6  Group 2  Female   Low   41
7  Group 2  Female  High   56

You can then create the plot with one command, plus an extra line to remove x-labels that are redundant for your data:

# `factorplot` was renamed `catplot` in seaborn 0.9 and has since been removed,
# so call `catplot` directly. kind='bar' stays explicit because the default
# plot kind differs between the two functions ('point' vs 'strip').
fg = seaborn.catplot(x='Group', y='val', hue='Mean',
                     col='Sex', data=data, kind='bar')
# The tick labels already name each group, so blank the repeated 'Group'
# x-axis title on every facet.
fg.set_xlabels('')

Which gives me the following figure (the image is not included in this copy):

Answer by Ramon Crehuet:

In a related question I found an alternative solution by @Stein that draws each MultiIndex level as its own row of labels beneath the x-axis. Here is what it looks like for your example:

import pandas as pd
import matplotlib.pyplot as plt
from itertools import groupby
import numpy as np 
%matplotlib inline

groups = ('Group 1', 'Group 2')
sexes = ('Male', 'Female')
means = ('Low', 'High')
index = pd.MultiIndex.from_product(
    [groups, sexes, means],
    names=['Group', 'Sex', 'Mean'],
)

# Seed NumPy's legacy global RNG so this example is reproducible and draws the
# same values as the seaborn example above (which also uses seed 0).
np.random.seed(0)
values = np.random.randint(low=20, high=100, size=len(index))
data = pd.DataFrame(data={'val': values}, index=index)
# Move the innermost level ('Mean') into the columns, so that each
# (Group, Sex) row holds one 'val' column per Mean value and pandas draws
# them as side-by-side bars.
data = data.unstack(level=-1)

def add_line(ax, xpos, ypos):
    """Draw a short grey vertical separator on ``ax`` in axes coordinates.

    The segment sits at horizontal position ``xpos`` and runs from
    ``ypos + .1`` down to ``ypos``. Clipping is turned off so the line stays
    visible when it lies outside the axes area (the caller in this file uses
    negative ``ypos`` values, i.e. below the x-axis).
    """
    ax.add_line(plt.Line2D([xpos, xpos], [ypos + .1, ypos],
                           transform=ax.transAxes, color='gray',
                           clip_on=False))

def label_len(my_index, level):
    """Return the runs of identical adjacent labels in one index level.

    Args:
        my_index: an index object exposing ``get_level_values`` (a pandas
            MultiIndex in this file).
        level: level number or level name accepted by ``get_level_values``.

    Returns:
        A list of ``(label, run_length)`` tuples in index order. Only adjacent
        equal labels are merged, so a label that reappears after a different
        one starts a new run.
    """
    level_labels = my_index.get_level_values(level)
    return [(key, len(list(run))) for key, run in groupby(level_labels)]

def label_group_bar_table(ax, df):
    """Draw hierarchical group labels underneath a bar chart of ``df``.

    One row of text is drawn per index level. The innermost level goes just
    below the axes (y = -0.1 in axes coordinates), and each outer level is
    drawn a further 0.1 lower. Each label is centred under its run of
    consecutive rows and bracketed by vertical separators from ``add_line``.

    Every row of ``df`` gets an equal ``1 / len(df.index)`` share of the axes
    width. This presumes the bars fill the x-range evenly; TODO confirm for
    plots other than the pandas bar chart used here.
    """
    scale = 1. / df.index.size
    ypos = -.1
    for level in reversed(range(df.index.nlevels)):
        start = 0
        for label, run in label_len(df.index, level):
            ax.text((start + .5 * run) * scale, ypos, label,
                    ha='center', transform=ax.transAxes)
            add_line(ax, start * scale, ypos)
            start += run
        # Closing separator after the last group on this row.
        add_line(ax, start * scale, ypos)
        # Decrement cumulatively to move the next (outer) level further down.
        ypos -= .1

# Plot the 'val' columns (one per Mean level) as grouped bars, one group per
# (Group, Sex) row of the index.
ax = data['val'].plot(kind='bar')
# Blank out the default tick labels and x-axis title; the hierarchical labels
# drawn by label_group_bar_table take their place.
ax.set_xticklabels([])
ax.set_xlabel('')
label_group_bar_table(ax, data)

This gives the following figure (the image is not included in this copy):

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!