How to perform feature selection with GridSearchCV in sklearn in Python

Submitted by 喜夏-厌秋 on 2019-11-28 01:16:20

Basically, you want to tune the hyperparameters of your classifier (with cross-validation) after selecting features with recursive feature elimination (also with cross-validation).

A Pipeline object is meant for exactly this purpose: chaining data transformations with a final estimator.

You may also want to use a different model (GradientBoostingClassifier, etc.) for the final classification. This is possible with the following approach:

from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import RFECV
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.pipeline import Pipeline

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

# this is the classifier used for feature selection
clf_featr_sele = RandomForestClassifier(n_estimators=30, random_state=42, class_weight="balanced")
rfecv = RFECV(estimator=clf_featr_sele, step=1, cv=5, scoring='roc_auc')

# you can use a different classifier for the final classification
clf = RandomForestClassifier(n_estimators=10, random_state=42, class_weight="balanced")
CV_rfc = GridSearchCV(clf, param_grid={'max_depth': [2, 3]}, cv=5, scoring='roc_auc')

# RFECV selects the features, then GridSearchCV tunes the final classifier on them
pipeline = Pipeline([('feature_sele', rfecv), ('clf_cv', CV_rfc)])

pipeline.fit(X_train, y_train)
pipeline.predict(X_test)

The fitted pipeline, including the feature-selection step, can now be applied to the test data.
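A variation on the approach above, sketched under two assumptions of my own: the final model is swapped for GradientBoostingClassifier (one of the alternatives mentioned earlier), and the whole pipeline is wrapped in GridSearchCV instead of placing GridSearchCV inside it. In the pipeline above, RFECV is fit once on all of X_train, so the features seen by the grid search's cross-validation were chosen with those validation folds included; wrapping the pipeline refits RFECV inside each fold, at the cost of more training. The step name "clf", step=5 and cv=3 are illustrative choices to keep the sketch quick, not values from the original answer.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.feature_selection import RFECV
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.pipeline import Pipeline

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42
)

# Feature selection step: RFECV driven by a small random forest.
# step=5 drops five features per iteration to keep the example fast.
rfecv = RFECV(
    estimator=RandomForestClassifier(n_estimators=30, random_state=42),
    step=5,
    cv=3,
    scoring="roc_auc",
)

# A different model for the final classification.
pipeline = Pipeline([
    ("feature_sele", rfecv),
    ("clf", GradientBoostingClassifier(random_state=42)),
])

# Parameters of a pipeline step are addressed as "<step name>__<parameter>".
param_grid = {"clf__max_depth": [2, 3]}

# GridSearchCV refits the whole pipeline (RFECV included) on each training
# fold, so feature selection never sees the matching validation fold.
search = GridSearchCV(pipeline, param_grid=param_grid, cv=3, scoring="roc_auc")
search.fit(X_train, y_train)

best = search.best_estimator_
print(search.best_params_)
print("features kept:", best.named_steps["feature_sele"].n_features_)
print("test accuracy:", best.score(X_test, y_test))
```

The nested design is slower (RFECV runs once per candidate and fold), but the cross-validated score in search.best_score_ is then a fairer estimate of how selection plus tuning performs on unseen data.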

Alternatively, pass the recursive feature elimination estimator (RFECV) directly to the GridSearchCV object. Something like this should work:

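from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import RFECV
from sklearn.model_selection import GridSearchCV, StratifiedKFold, train_test_split

# df and my_features are your own DataFrame and list of feature columns
X = df[my_features]  # all my features
y = df['gold_standard']  # labels
X_train, X_test, y_train, y_test = train_test_split(X, y)

clf = RandomForestClassifier(random_state=42, class_weight="balanced")
rfecv = RFECV(estimator=clf, step=1, cv=StratifiedKFold(10), scoring='roc_auc')

# RFECV has no n_estimators, max_depth, ... of its own: parameters of the
# wrapped random forest must be prefixed with 'estimator__'
param_grid = {
    'estimator__n_estimators': [200, 500],
    'estimator__max_features': ['sqrt', 'log2'],  # 'auto' was removed in scikit-learn 1.3
    'estimator__max_depth': [4, 5, 6, 7, 8],
    'estimator__criterion': ['gini', 'entropy']
}
k_fold = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)

#------------- Just pass your RFECV object as estimator here directly --------#

CV_rfc = GridSearchCV(estimator=rfecv, param_grid=param_grid, cv=k_fold, scoring='roc_auc')

CV_rfc.fit(X_train, y_train)
print(CV_rfc.best_params_)
print(CV_rfc.best_score_)
print(CV_rfc.best_estimator_)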
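Once the search is fitted, best_estimator_ is a fitted RFECV, so the features it kept can be read from its support_ mask and n_features_ attribute. The sketch below is self-contained and assumes load_breast_cancer as a stand-in for your df, with a deliberately tiny grid (step=5, 20 trees, two max_depth values) so it runs quickly; those values are illustrative, not recommendations. The grid keys use the estimator__ prefix because RFECV itself has no max_depth parameter.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import RFECV
from sklearn.model_selection import GridSearchCV, StratifiedKFold, train_test_split

data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, stratify=data.target, random_state=0
)

rfecv = RFECV(
    estimator=RandomForestClassifier(n_estimators=20, random_state=42),
    step=5,
    cv=StratifiedKFold(3),
    scoring="roc_auc",
)

# Hyperparameters of the inner random forest need the "estimator__" prefix.
param_grid = {"estimator__max_depth": [3, 5]}
k_fold = StratifiedKFold(n_splits=3, shuffle=True, random_state=0)

CV_rfc = GridSearchCV(estimator=rfecv, param_grid=param_grid, cv=k_fold, scoring="roc_auc")
CV_rfc.fit(X_train, y_train)

best_rfecv = CV_rfc.best_estimator_             # a fitted RFECV
mask = best_rfecv.support_                      # boolean mask over the 30 input columns
selected = np.array(data.feature_names)[mask]   # names of the kept features
print(best_rfecv.n_features_, "features kept:", list(selected))

# GridSearchCV.score uses the same scorer (ROC AUC); RFECV applies its
# feature mask internally before calling the inner forest.
print("test ROC AUC:", CV_rfc.score(X_test, y_test))
```

With a real DataFrame, df[my_features].columns[mask] gives the selected column names in the same way.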

You can do what you want by prefixing the names of the parameters you want to pass to the wrapped estimator with 'estimator__' (the name of RFECV's estimator parameter followed by a double underscore):

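from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import RFECV
from sklearn.model_selection import GridSearchCV, StratifiedKFold, train_test_split

# df and my_features are your own DataFrame and list of feature columns
X = df[my_features]
y = df['gold_standard']

clf = RandomForestClassifier(random_state=0, class_weight="balanced")
rfecv = RFECV(estimator=clf, step=1, cv=StratifiedKFold(3), scoring='roc_auc')

param_grid = {
    'estimator__n_estimators': [200, 500],
    # 'auto' (an alias of 'sqrt' for classifiers) was removed in scikit-learn 1.3;
    # drop it on newer versions
    'estimator__max_features': ['auto', 'sqrt', 'log2'],
    'estimator__max_depth': [4, 5, 6, 7, 8],
    'estimator__criterion': ['gini', 'entropy']
}
k_fold = StratifiedKFold(n_splits=3, shuffle=True, random_state=0)

CV_rfc = GridSearchCV(estimator=rfecv, param_grid=param_grid, cv=k_fold, scoring='roc_auc')

X_train, X_test, y_train, y_test = train_test_split(X, y)

CV_rfc.fit(X_train, y_train)
print(CV_rfc.best_params_)
print(CV_rfc.best_score_)
print(CV_rfc.best_estimator_)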

Output on fake data I made:

{'estimator__n_estimators': 200, 'estimator__max_depth': 6, 'estimator__criterion': 'entropy', 'estimator__max_features': 'auto'}
0.5653035605690997
RFECV(cv=StratifiedKFold(n_splits=3, random_state=None, shuffle=False),
   estimator=RandomForestClassifier(bootstrap=True, class_weight='balanced',
            criterion='entropy', max_depth=6, max_features='auto',
            max_leaf_nodes=None, min_impurity_decrease=0.0,
            min_impurity_split=None, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            n_estimators=200, n_jobs=None, oob_score=False, random_state=0,
            verbose=0, warm_start=False),
   min_features_to_select=1, n_jobs=None, scoring='roc_auc', step=1,
   verbose=0)
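The 'estimator__' prefix works because scikit-learn estimators expose nested parameters through get_params and set_params, and GridSearchCV applies each candidate from param_grid through set_params. A minimal sketch of that mechanism on an unfitted RFECV (no data needed); the unprefixed call at the end shows why a grid key like plain 'n_estimators' fails for RFECV.

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import RFECV

rfecv = RFECV(estimator=RandomForestClassifier(random_state=0), step=1)

# get_params() (deep=True by default) also lists the wrapped estimator's
# parameters, each prefixed with "estimator__".
nested_keys = sorted(k for k in rfecv.get_params() if k.startswith("estimator__"))
print(nested_keys)

# set_params accepts the same prefixed keys and forwards them to the
# wrapped random forest.
rfecv.set_params(estimator__n_estimators=200, estimator__max_depth=6)
print(rfecv.estimator.n_estimators, rfecv.estimator.max_depth)  # 200 6

# Without the prefix the key is looked up on RFECV itself, which has no
# n_estimators parameter, so scikit-learn raises a ValueError.
try:
    rfecv.set_params(n_estimators=200)
    unprefixed_accepted = True
except ValueError as err:
    unprefixed_accepted = False
    print("ValueError:", err)
```

The exact list printed by nested_keys depends on the installed scikit-learn version, since RandomForestClassifier has gained and lost parameters over time.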