How to implement a meta-estimator with the scikit-learn API?

Submitted by Deadly on 2021-02-19 00:23:15

Question


I would like to implement a simple wrapper / meta-estimator which is compatible with all of scikit-learn. It is hard to find a full description of exactly what I need.

The goal is to have a regressor which also learns a threshold to become a classifier. So I came up with:

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone

class Thresholder(BaseEstimator, ClassifierMixin):
    def __init__(self, regressor):
        self.regressor = regressor
        # threshold_ does not get initialized in __init__ ??

    def fit(self, X, y, optimal_threshold):
        self.regressor = clone(self.regressor)    # is this required by sklearn??
        self.regressor.fit(X, y)

        y_raw = self.regressor.predict(X)
        self.threshold_ = optimal_threshold(y_raw)
        return self

    def predict(self, X):
        y_raw = self.regressor.predict(X)

        y = np.digitize(y_raw, [self.threshold_])

        return y

Does this implement the full API I need?

My main question is where to put the threshold. I want it to be learned only once and then re-used in subsequent .fit calls with new data without being readjusted. But with the current version it has to be re-tuned on every .fit call, which I do not want.

On the other hand, if I make it a fixed parameter self.threshold and pass it to __init__, then by sklearn's conventions I'm not supposed to change it based on the data.

How can I make a threshold parameter which can be tuned in one call of .fit and be fixed for subsequent .fit calls?
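One pattern that follows sklearn conventions is sketched below; this is an assumption about a workable design, not a confirmed answer. The idea is to accept an optional `threshold` in `__init__` (fixed if given) and otherwise learn the fitted attribute `threshold_` only when no previous `fit` call has set it. The `np.median` tuning rule here is a placeholder standing in for the `optimal_threshold` logic:

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.linear_model import LinearRegression

class Thresholder(BaseEstimator, ClassifierMixin):
    """Hypothetical sketch: learn the threshold on the first fit only."""

    def __init__(self, regressor, threshold=None):
        self.regressor = regressor
        self.threshold = threshold  # fixed if given; learned on first fit if None

    def fit(self, X, y):
        self.regressor.fit(X, y)
        y_raw = self.regressor.predict(X)
        if self.threshold is not None:
            # A user-supplied threshold is treated as fixed.
            self.threshold_ = self.threshold
        elif not hasattr(self, "threshold_"):
            # Learn the threshold only if no earlier fit call has set it.
            self.threshold_ = np.median(y_raw)  # placeholder tuning rule
        return self

    def predict(self, X):
        return np.digitize(self.regressor.predict(X), [self.threshold_])

# First fit learns threshold_; a second fit with new data keeps it.
X = np.array([[0.], [1.], [2.], [3.]])
est = Thresholder(LinearRegression()).fit(X, np.array([0., 1., 2., 3.]))
est.fit(X, np.array([10., 11., 12., 13.]))  # threshold_ is unchanged
```

One caveat with this sketch: `clone()` (which GridSearchCV and friends call internally) discards fitted attributes, so `threshold_` would not survive cloning. To make the learned value persist across clones, read it out after fitting and pass it back explicitly as the `threshold` constructor parameter.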

Source: https://stackoverflow.com/questions/58804308/how-to-implement-a-meta-estimator-with-the-scikit-learn-api
