How to apply OLS from statsmodels to groupby

Submitted by 跟風遠走 on 2019-12-04 02:03:58

Question


I am running OLS on products by month. This works fine for a single product, but my dataframe contains many products, and when I create a groupby object, OLS gives an error.

linear_regression_df:
  product_desc  period_num  TOTALS
0    product_a           1      53
3    product_a           2      52
6    product_a           3      50
1    product_b           1      44
4    product_b           2      43
7    product_b           3      41
2    product_c           1      36
5    product_c           2      35
8    product_c           3      34
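To run the code in this thread, the sample frame can be rebuilt as below. The values and the index order are copied from the printout above; storing the numeric columns as plain integers is an assumption about the original dtypes.

```python
import pandas as pd

# Rebuild the sample data shown above; the explicit index reproduces the
# row order of the printout (rows sorted by product_desc).
linear_regression_df = pd.DataFrame(
    {
        "product_desc": ["product_a"] * 3 + ["product_b"] * 3 + ["product_c"] * 3,
        "period_num": [1, 2, 3] * 3,
        "TOTALS": [53, 52, 50, 44, 43, 41, 36, 35, 34],
    },
    index=[0, 3, 6, 1, 4, 7, 2, 5, 8],
)
print(linear_regression_df)
```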


from pandas import DataFrame, Series
import statsmodels.api as sm    

linear_regression_grouped = linear_regression_df.groupby(['product_desc'])
X = linear_regression_grouped['period_num'] 
y = linear_regression_grouped['TOTALS']

model = sm.OLS(y, X)
results = model.fit()

And I get this error on the sm.OLS() line:

ValueError: unrecognized data structures: <class 'pandas.core.groupby.SeriesGroupBy'>
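The error arises because sm.OLS expects array-like inputs (NumPy arrays, Series, DataFrames), while linear_regression_grouped['period_num'] is a SeriesGroupBy, which describes the groups rather than holding a single array. A minimal sketch of the difference, using the sample data from the table (the names df and grouped are illustrative): iterating over the GroupBy yields (key, sub-DataFrame) pairs, and each column of a sub-DataFrame is an ordinary Series.

```python
import pandas as pd

df = pd.DataFrame({
    "product_desc": ["product_a"] * 3 + ["product_b"] * 3 + ["product_c"] * 3,
    "period_num": [1, 2, 3] * 3,
    "TOTALS": [53, 52, 50, 44, 43, 41, 36, 35, 34],
})
grouped = df.groupby("product_desc")

# Selecting a column from a GroupBy gives a SeriesGroupBy: a grouping
# object, not an array, so statsmodels cannot build a design matrix from it.
column_by_group = grouped["period_num"]
print(type(column_by_group).__name__)  # SeriesGroupBy

# Iterating the GroupBy yields (key, sub-DataFrame) pairs; each column of a
# sub-DataFrame is a plain Series, which sm.OLS does accept.
for name, group in grouped:
    print(name, type(group["period_num"]).__name__, len(group))
```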

So how can I go through my dataframe and apply sm.OLS() for each product_desc?


Answer 1:


You could do something like this:

import pandas as pd
import statsmodels.api as sm

# Fit a separate model for each product (no intercept term)
for product in linear_regression_df.product_desc.unique():
    tempdf = linear_regression_df[linear_regression_df.product_desc == product]
    X = tempdf['period_num']
    y = tempdf['TOTALS']

    model = sm.OLS(y, X)
    results = model.fit()

    print(results.params)  # Or whatever summary info you want



Answer 2:


Use get_group to retrieve each individual group and fit an OLS model on it:

for group in linear_regression_grouped.groups.keys():
    df = linear_regression_grouped.get_group(group)
    X = df['period_num']
    y = df['TOTALS']
    model = sm.OLS(y, X)
    results = model.fit()
    print(results.summary())

In practice, though, you usually also want an intercept term, so the model should be defined slightly differently:

for group in linear_regression_grouped.groups.keys():
    # copy() guards against SettingWithCopyWarning when adding a column
    df = linear_regression_grouped.get_group(group).copy()
    df['constant'] = 1  # column of ones = intercept term
    X = df[['period_num', 'constant']]
    y = df['TOTALS']
    model = sm.OLS(y, X)
    results = model.fit()
    print(results.summary())

The results with and without the intercept are, of course, very different.



Source: https://stackoverflow.com/questions/24088439/how-to-apply-ols-from-statsmodels-to-groupby
