Impute categorical missing values in scikit-learn

清歌不尽 2020-11-30 16:55

I've got pandas data with some columns of text type. There are some NaN values along with these text columns. What I'm trying to do is to impute those NaN's by skle

10 answers
  •  春和景丽
    2020-11-30 17:27

    Inspired by the answers here, and for want of a go-to imputer for all use cases, I ended up writing this. It supports four imputation strategies (mean, mode, median, fill) and works on both pd.DataFrame and pd.Series.

    mean and median work only for numeric data; mode and fill work for both numeric and categorical data.

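    import pandas as pd
    from pandas.api.types import is_numeric_dtype
    from sklearn.base import BaseEstimator, TransformerMixin

    class CustomImputer(BaseEstimator, TransformerMixin):
        def __init__(self, strategy='mean', filler='NA'):
            # Store parameters under their own names so get_params()/clone() work
            self.strategy = strategy
            self.filler = filler

        def fit(self, X, y=None):
            if self.strategy in ['mean', 'median']:
                # mean/median are only defined for numeric columns
                dtypes = X.dtypes if isinstance(X, pd.DataFrame) else [X.dtype]
                if not all(is_numeric_dtype(dt) for dt in dtypes):
                    raise ValueError('dtypes mismatch: numeric dtype is '
                                     'required for ' + self.strategy)
            if self.strategy == 'mean':
                self.fill_ = X.mean()
            elif self.strategy == 'median':
                self.fill_ = X.median()
            elif self.strategy == 'mode':
                # mode() returns one row per tied mode; keep the first
                self.fill_ = X.mode().iloc[0]
            elif self.strategy == 'fill':
                if isinstance(self.filler, list) and isinstance(X, pd.DataFrame):
                    # Pair each column with its fill value, in column order
                    self.fill_ = dict(zip(X.columns, self.filler))
                else:
                    self.fill_ = self.filler
            else:
                raise ValueError('unknown strategy: ' + str(self.strategy))
            return self

        def transform(self, X, y=None):
            # The fitted value(s) are stored in fill_ (learned during fit)
            return X.fillna(self.fill_)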
    

    Usage:

    >>> df
         MasVnrArea FireplaceQu
    Id
    1         196.0         NaN
    974       196.0         NaN
    21        380.0          Gd
    5         350.0          TA
    651         NaN          Gd

    >>> CustomImputer(strategy='mode').fit_transform(df)
         MasVnrArea FireplaceQu
    Id
    1         196.0          Gd
    974       196.0          Gd
    21        380.0          Gd
    5         350.0          TA
    651       196.0          Gd

    >>> CustomImputer(strategy='fill', filler=[0, 'NA']).fit_transform(df)
         MasVnrArea FireplaceQu
    Id
    1         196.0          NA
    974       196.0          NA
    21        380.0          Gd
    5         350.0          TA
    651         0.0          Gd
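
For anyone who only needs the mode or fill behaviour without a transformer class, the same two results can probably be had with plain pandas calls. The following is a minimal sketch, not part of the class above: the toy frame copies the values from the usage example, and the variable names (`mode_fill`, `const_fill`, `by_mode`, `by_const`) are made up for illustration.

```python
import numpy as np
import pandas as pd

# Toy frame with the same values as the usage example.
df = pd.DataFrame(
    {
        "MasVnrArea": [196.0, 196.0, 380.0, 350.0, np.nan],
        "FireplaceQu": [np.nan, np.nan, "Gd", "TA", "Gd"],
    },
    index=pd.Index([1, 974, 21, 5, 651], name="Id"),
)

# 'mode' strategy: DataFrame.mode() returns one row per tied mode,
# so .iloc[0] picks the first mode of every column as a Series.
mode_fill = df.mode().iloc[0]
by_mode = df.fillna(mode_fill)

# 'fill' strategy with a list: pair each column with its constant.
const_fill = dict(zip(df.columns, [0, "NA"]))
by_const = df.fillna(const_fill)

print(by_mode)
print(by_const)
```

Note that in a real train/test setting the fill values should be computed on the training data only and then reused, which is what wrapping this in `fit`/`transform` buys you.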
    
