Variance Inflation Factor in Python

星月不相逢 2020-12-22 23:04

I'm trying to calculate the variance inflation factor (VIF) for each column of a simple dataset in Python:

a b c d
1 2 4 4
1 2 6 3
2 3 7 4
3 2 8 5
4 1 9 4
         
8 answers
  •  抹茶落季
    2020-12-22 23:28

    I wrote this function based on some other posts I saw on Stack Overflow and Cross Validated. It prints the features whose VIF exceeds the threshold and returns a new DataFrame with those features removed.

    import pandas as pd
    from statsmodels.stats.outliers_influence import variance_inflation_factor
    from statsmodels.tools.tools import add_constant
    
    def calculate_vif_(df, thresh=5):
        '''
        Calculates the VIF of each feature in a pandas DataFrame.
        A constant must be added before calling variance_inflation_factor,
        otherwise the results will be incorrect.
    
        :param df: the pandas DataFrame containing only the predictor features, not the response variable
        :param thresh: the max VIF value before the feature is removed from the DataFrame
        :return: DataFrame with the high-VIF features removed
        '''
        # add_constant prepends an intercept column named 'const'
        const = add_constant(df)
        # One VIF per column of the design matrix (including 'const')
        vif_df = pd.Series([variance_inflation_factor(const.values, i)
                            for i in range(const.shape[1])],
                           index=const.columns).to_frame(name='VIF')
    
        # Highest VIF first; the intercept's own VIF is not of interest, so drop it
        vif_df = vif_df.sort_values(by='VIF', ascending=False)
        vif_df = vif_df.drop('const')
        vif_df = vif_df[vif_df['VIF'] > thresh]
    
        print('Features above VIF threshold:\n')
        print(vif_df)
    
        # Every feature above the threshold is dropped in a single pass
        col_to_drop = list(vif_df.index)
        for col in col_to_drop:
            print('Dropping: {}'.format(col))
    
        return df.drop(columns=col_to_drop)
    
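For reference, here is a minimal NumPy-only sketch of what `variance_inflation_factor` computes, applied to the dataset from the question. It assumes the standard definition VIF_j = 1 / (1 - R_j^2), where R_j^2 is the R² from regressing feature j on an intercept plus all the other features; the helper name `vif` is hypothetical and not part of statsmodels. Because the intercept is built in, the values should match (up to floating-point error) what `variance_inflation_factor` returns on `add_constant(df).values` for the non-constant columns.

```python
import numpy as np


def vif(X):
    """Hypothetical helper: VIF of each column of X (n_samples x n_features).

    For column j, regress X[:, j] on an intercept plus every other column
    with ordinary least squares, then VIF_j = 1 / (1 - R^2_j).
    Assumes every column varies (a constant column would give 0/0).
    """
    X = np.asarray(X, dtype=float)
    n, p = X.shape
    result = np.empty(p)
    for j in range(p):
        y = X[:, j]
        # Design matrix: a column of ones (the intercept) plus the other features
        A = np.column_stack([np.ones(n), np.delete(X, j, axis=1)])
        beta, *_ = np.linalg.lstsq(A, y, rcond=None)
        resid = y - A @ beta
        # Centered R^2, appropriate because the model includes an intercept
        r2 = 1.0 - (resid @ resid) / np.sum((y - y.mean()) ** 2)
        result[j] = 1.0 / (1.0 - r2)
    return result


# The dataset from the question: columns a, b, c, d
data = np.array([
    [1, 2, 4, 4],
    [1, 2, 6, 3],
    [2, 3, 7, 4],
    [3, 2, 8, 5],
    [4, 1, 9, 4],
])

for name, value in zip("abcd", vif(data)):
    print(f"{name}: {value:.3f}")
```

Note that `calculate_vif_` drops every feature above the threshold in one pass. Removing a feature can never increase the VIF of the remaining ones, and often lowers it sharply when that feature was part of a collinear pair, so a common alternative is to drop only the highest-VIF feature, recompute, and repeat until every value is below the threshold.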
