Using custom estimator with cross_val_score fails

Submitted by 冷眼眸甩不掉的悲伤 on 2021-02-10 12:29:09

Question


I am trying to use cross_val_score with a custom estimator. It is important that this estimator receives a member variable that is used later inside the fit function. However, inside cross_val_score the member variable seems to be lost (or a new instance of the estimator is being created). Here is a minimal example that reproduces the error:

import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.base import BaseEstimator

class MyEstimator(BaseEstimator):
    def __init__(self, member):
        self._member = member

    def fit(self, X, y):
        if self._member is None:
            raise Exception('member is None.')

X = np.array([[1, 1, 1], [2, 2, 2]])
y = np.array([1, 2])

score_values = cross_val_score(
    MyEstimator('some value'),
    X,
    y,
    cv=2,
    scoring='r2'
)

In the above code the Exception is always raised. Is there a way to solve this?


Answer 1:


scikit-learn clones the estimator internally: cross_val_score uses the sklearn.base.clone function to create a fresh, unfitted copy of the estimator for each split. You can see the effect by calling clone yourself:

from sklearn.base import clone

t = MyEstimator('some value')
t1 = clone(t)
t._member, t1._member
# ('some value', None)

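clone only copies the constructor parameters. It reads them with get_params(), which looks up the attribute that has the same name as each __init__ argument (here, member). Because the constructor stores the value as self._member instead, get_params() does not find it, and the copy is built with member=None. (Newer scikit-learn releases raise an AttributeError in this situation instead of silently returning None.)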
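To make the mechanism concrete, here is a simplified, hypothetical naive_clone (an illustration, not scikit-learn's actual implementation) that rebuilds an object the same way: it inspects the __init__ signature, reads the attribute with each parameter's name, and falls back to None when that attribute is missing, as older scikit-learn releases did. BrokenEstimator and FixedEstimator are illustrative names standing in for the two versions of MyEstimator.

```python
import inspect


def naive_clone(estimator):
    """Hypothetical, simplified stand-in for sklearn.base.clone.

    Rebuilds the object from its constructor parameters only: for each
    parameter name in __init__'s signature it reads the attribute with the
    same name, using None if that attribute does not exist.
    """
    cls = type(estimator)
    names = [p for p in inspect.signature(cls.__init__).parameters if p != "self"]
    params = {name: getattr(estimator, name, None) for name in names}
    return cls(**params)


class BrokenEstimator:
    def __init__(self, member):
        self._member = member  # stored under a different name than the parameter


class FixedEstimator:
    def __init__(self, member):
        self.member = member  # stored under the parameter's own name


broken = naive_clone(BrokenEstimator("some value"))
fixed = naive_clone(FixedEstimator("some value"))
print(broken._member)  # None: the value was lost during the rebuild
print(fixed.member)    # some value
```

Only the names in the constructor signature are consulted, so anything stored under a different attribute name is invisible to the copy.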

Solution:

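Make the constructor parameter and the attribute it is stored in use the same name: either prefix both with an underscore or drop the underscore everywhere. Dropping it follows the usual scikit-learn convention that __init__ stores each argument unchanged under its own name.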

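import numpy as np
from sklearn.base import BaseEstimator
from sklearn.model_selection import cross_val_score

class MyEstimator(BaseEstimator):
    def __init__(self, member):
        # Store the parameter under its own name so get_params()/clone() can find it
        self.member = member

    def fit(self, X, y):
        if self.member is None:
            raise Exception('member is None.')
        return self  # scikit-learn convention: fit returns the estimator

    def predict(self, X):
        # Dummy model: predict 1 for every sample
        return np.ones(len(X))

X = np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
y = np.array([1, 2, 3])

# error_score='raise' re-raises any exception from fit instead of recording a score.
# With cv=3 on 3 samples each test fold holds a single sample, so R^2 is not
# well defined here; the point is only that fit no longer raises.
score_values = cross_val_score(
    MyEstimator('some value'),
    X,
    y,
    cv=3,
    scoring='r2',
    error_score='raise',
)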
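A quick sanity check before running cross_val_score is to confirm that get_params() reports the value you passed and that clone() preserves it. The sketch below assumes a minimal MyEstimator with only a constructor (fit and predict are omitted because get_params and clone do not need them) and a default of member=None, which is an illustrative choice not in the original answer.

```python
from sklearn.base import BaseEstimator, clone


class MyEstimator(BaseEstimator):
    def __init__(self, member=None):
        self.member = member  # same name as the constructor parameter


est = MyEstimator("some value")

# get_params() is what clone() relies on; the value should appear here.
print(est.get_params())  # {'member': 'some value'}

# The clone is a new, unfitted instance carrying the same parameter value.
est_copy = clone(est)
print(est_copy is est, est_copy.member)  # False some value
```

If the value is missing from get_params(), cross-validation will not see it either.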


Source: https://stackoverflow.com/questions/54636101/using-custom-estimator-with-cross-val-score-fails
