How to implement SMOTE in cross validation and GridSearchCV

独厮守ぢ 2021-02-03 11:22

I'm relatively new to Python. Can you help me turn my SMOTE implementation into a proper pipeline? What I want is to apply the over- and under-sampling on the training set o…

2 Answers
  •  半阙折子戏
    2021-02-03 11:54

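    You should use a pipeline object. imbalanced-learn provides its own Pipeline, which extends the scikit-learn Pipeline so that samplers (objects with a fit_resample() method, called fit_sample() in older releases) can be used as intermediate steps alongside the usual fit(), predict(), fit_predict() and fit_transform() methods of scikit-learn. Samplers inside the pipeline are applied only when the pipeline is fitted; predict() and score() skip them.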

    Have a look at this example here:

    • https://imbalanced-learn.org/stable/auto_examples/pipeline/plot_pipeline_classification.html

    For your code, you would want to do this:

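    from imblearn.combine import SMOTEENN
    from imblearn.over_sampling import SMOTE
    from imblearn.pipeline import make_pipeline, Pipeline
    from sklearn.ensemble import RandomForestClassifier
    
    sm = SMOTE(random_state=1)  # stands in for the sm defined in your question
    smote_enn = SMOTEENN(smote=sm)  # SMOTE over-sampling, then ENN under-sampling
    clf_rf = RandomForestClassifier(n_estimators=25, random_state=1)
    
    # Either let make_pipeline name the steps automatically...
    pipeline = make_pipeline(smote_enn, clf_rf)
    # ...or give the steps explicit names:
    pipeline = Pipeline([('smote_enn', smote_enn),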
                         ('clf_rf', clf_rf)])
    

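    Then you can pass this pipeline object to GridSearchCV, RandomizedSearchCV or any other cross-validation tool in scikit-learn, just like a regular estimator. Because the sampler is a pipeline step, resampling is applied only to the training folds; each validation fold is scored on the original, unresampled data.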

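    from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold
    
    # With the Pipeline above, param_dist keys look like 'clf_rf__n_estimators'
    kf = StratifiedKFold(n_splits=n_splits)
    random_search = RandomizedSearchCV(pipeline, param_distributions=param_dist,
                                       n_iter=1000,
                                       cv=kf)
    random_search.fit(X_train, y_train)  # X_train, y_train: your training data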
    
