Parameters are not going to custom estimator in scikit-learn GridSearchCV

甜味超标 2021-01-15 02:21

I'm trying and failing to pass parameters to a custom estimator in scikit-learn. I'd like the parameter lr to change during the grid search. Problem is that th

1 answer
  • 2021-01-15 03:00

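    You don't see lr change because you are printing inside the constructor (__init__), which never receives the values from the grid.

    If we print inside fit() instead, we can see lr change. This follows from how GridSearchCV creates its copies of the estimator: each copy is built with sklearn.base.clone, which calls the constructor with the parameters of the estimator you passed in (here the default lr=0), and the candidate value from the grid is applied afterwards through set_params().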

    from sklearn.model_selection import GridSearchCV
    from sklearn.base import BaseEstimator, ClassifierMixin
    import numpy as np
    
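    class MyClassifier(BaseEstimator, ClassifierMixin):

        def __init__(self, lr=0):
            # Runs for every new instance, including each clone, so it only
            # ever sees the lr of the estimator passed to GridSearchCV
            print('lr:', lr)
            self.lr = lr

        def fit(self, X, y):
            # By this point set_params() has applied the candidate value
            print('lr:', self.lr)
            return self

        def predict(self, X):
            # Placeholder prediction so the example runs end to end
            return X % 3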
    
    params = {
        'lr': [0.1, 0.5, 0.7]
    }
    gs = GridSearchCV(MyClassifier(), param_grid=params, cv=4)
    
    x = np.arange(30)
    y = np.concatenate((np.zeros(10), np.ones(10), np.ones(10) * 2))
    gs.fit(x, y)
    gs.predict(x)
    

    Output:

    lr: 0
    lr: 0
    lr: 0
    lr: 0.1
    lr: 0
    lr: 0.1
    lr: 0
    lr: 0.1
    lr: 0
    lr: 0.1
    lr: 0
    lr: 0.5
    lr: 0
    lr: 0.5
    lr: 0
    lr: 0.5
    lr: 0
    lr: 0.5
    lr: 0
    lr: 0.7
    lr: 0
    lr: 0.7
    lr: 0
    lr: 0.7
    lr: 0
    lr: 0.7
    lr: 0
    lr: 0.1
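
The alternating pattern in the output above (lr: 0 from the constructor, then the grid value from fit) can be reproduced by hand without running a search. The following is a minimal sketch of the clone-then-set_params sequence, not GridSearchCV's actual internals. The seen_in_init list is a hypothetical helper added only for this illustration, and ClassifierMixin is left out because only parameter handling is involved.

```python
from sklearn.base import BaseEstimator, clone

# Hypothetical helper (not in the original answer): records every lr value
# that __init__ receives, instead of printing it.
seen_in_init = []


class MyClassifier(BaseEstimator):
    # ClassifierMixin is omitted here; only parameter handling matters.

    def __init__(self, lr=0):
        seen_in_init.append(lr)
        self.lr = lr


base = MyClassifier()          # __init__ sees the default lr=0
candidate = clone(base)        # new instance built from base's parameters: lr=0 again
candidate.set_params(lr=0.5)   # candidate value applied without calling __init__

print(seen_in_init)   # [0, 0]
print(candidate.lr)   # 0.5
print(base.lr)        # 0
```

So any logic that depends on lr should live in fit(), reading self.lr, rather than in the constructor.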
    