More than one estimator in GridSearchCV (sklearn)

佛祖请我去吃肉 2021-01-02 06:52

I was checking the sklearn documentation page for GridSearchCV. One of the attributes of a GridSearchCV object is best_estimator_. So here

1 Answer
  • 2021-01-02 07:28

    GridSearchCV searches over parameters. It trains multiple estimators of the same class (one of SVC, DecisionTreeClassifier, or another classifier), each with a different combination of the parameter values specified in param_grid. best_estimator_ is the one that achieved the best mean cross-validated score, refit on the whole dataset (with the default refit=True).

    So essentially, best_estimator_ is an instance of that same class, initialized with the best parameters found (these are also available as best_params_).
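    A minimal sketch of this basic, single-estimator setup, reusing the SVC grid values from the pipeline example in this answer and assuming the iris dataset purely for illustration. The 4 x 2 grid produces 8 candidates, all of them SVC instances, and best_estimator_ comes back as an SVC carrying the winning C and gamma:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

# One estimator class, several parameter combinations (4 x 2 = 8 candidates)
param_grid = {"C": [1, 10, 100, 1000], "gamma": [0.001, 0.0001]}
search = GridSearchCV(SVC(), param_grid)
search.fit(X, y)

# best_estimator_ is still an SVC, refit on all the data with the best combination
best = search.best_estimator_
print(type(best).__name__, search.best_params_)
```

    Every candidate in search.cv_results_["params"] is just a different parameter setting for SVC. Nothing in this setup lets the search switch to a different class.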

    So in the basic setup, you cannot search over multiple estimator classes in a single grid search.

    As a workaround, you can search over multiple estimators by using a Pipeline in which the estimator step is itself a "parameter" that GridSearchCV can set.

    Something like this:

    from sklearn.pipeline import Pipeline
    from sklearn.svm import SVC
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.model_selection import GridSearchCV
    from sklearn.datasets import load_iris

    iris_data = load_iris()
    X, y = iris_data.data, iris_data.target

    # Initialize the pipeline with any estimator; the grid below swaps it out
    pipe = Pipeline(steps=[('estimator', SVC())])

    # One dict per estimator: the 'estimator' key replaces the pipeline step,
    # and the 'estimator__<param>' keys are that estimator's own parameters
    params_grid = [
        {
            'estimator': [SVC()],
            'estimator__C': [1, 10, 100, 1000],
            'estimator__gamma': [0.001, 0.0001],
        },
        {
            'estimator': [DecisionTreeClassifier()],
            'estimator__max_depth': [1, 2, 3, 4, 5],
            # "auto" was deprecated for trees in scikit-learn 1.1 and removed in 1.3
            'estimator__max_features': [None, "sqrt", "log2"],
        },
        # {'estimator': [Any_other_estimator_you_want],
        #  'estimator__valid_param_of_your_estimator': [valid_values]},
    ]

    grid = GridSearchCV(pipe, params_grid)
    grid.fit(X, y)

    # best_estimator_ is the refit Pipeline; its 'estimator' step is the winning model
    print(grid.best_params_)
    print(grid.best_estimator_.named_steps['estimator'])
    

    You can add as many dicts to the params_grid list as you like, but make sure each dict contains only parameters that are valid for the estimator it specifies under 'estimator'.
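    When the list grows, writing the 'estimator__' prefixes by hand gets error-prone. Below is a sketch that uses a hypothetical helper, make_param_grid (not part of scikit-learn), to build the list from (estimator, params) pairs so that each estimator is paired only with its own parameters. It then reports the best mean CV score per estimator class from cv_results_, not just the overall winner. The iris data and grid values are reused from the example above as assumptions for illustration:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier


def make_param_grid(candidates, step="estimator"):
    """Hypothetical helper: turn (estimator, params) pairs into a list of
    grid dicts, prefixing each parameter name with '<step>__'."""
    grid = []
    for estimator, params in candidates:
        entry = {step: [estimator]}
        for name, values in params.items():
            entry[f"{step}__{name}"] = values
        grid.append(entry)
    return grid


X, y = load_iris(return_X_y=True)
pipe = Pipeline(steps=[("estimator", SVC())])

# Each estimator is paired only with parameters it actually accepts
params_grid = make_param_grid([
    (SVC(), {"C": [1, 10, 100, 1000], "gamma": [0.001, 0.0001]}),
    (DecisionTreeClassifier(), {"max_depth": [1, 2, 3, 4, 5],
                                "max_features": [None, "sqrt", "log2"]}),
])
search = GridSearchCV(pipe, params_grid)
search.fit(X, y)

# Best mean cross-validated score for each estimator class
best_per_class = {}
for params, score in zip(search.cv_results_["params"],
                         search.cv_results_["mean_test_score"]):
    name = type(params["estimator"]).__name__
    best_per_class[name] = max(score, best_per_class.get(name, float("-inf")))
print(best_per_class)
```

    The grid yields 8 SVC candidates plus 15 tree candidates. The overall best_score_ is simply the largest of the per-class values.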
