How to implement a meta-estimator with the scikit-learn API?

Submitted by Deadly on 2021-02-19 00:23:15


I would like to implement a simple wrapper (meta-estimator) that is fully compatible with scikit-learn. It is hard to find a complete description of exactly what I need to implement.

The goal is a regressor that also learns a threshold, so that it becomes a classifier. So far I have come up with:

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone

class Thresholder(ClassifierMixin, BaseEstimator):
    def __init__(self, regressor):
        self.regressor = regressor
        # threshold_ does not get initialized in __init__ ??

    def fit(self, X, y, optimal_threshold):
        self.regressor = clone(self.regressor)    # is this required by sklearn??
        self.regressor.fit(X, y)

        y_raw = self.regressor.predict(X)
        self.threshold_ = optimal_threshold(y_raw)
        return self

    def predict(self, X):
        y_raw = self.regressor.predict(X)

        y = np.digitize(y_raw, [self.threshold_])

        return y

Does this implement the full API I need?
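For comparison, here is a minimal sketch of a version that plugs into scikit-learn tooling, assuming scikit-learn's documented estimator conventions: constructor arguments stored unchanged, learned state in trailing-underscore attributes, and `fit(X, y)` returning `self`. It fits a clone into `regressor_` instead of overwriting `self.regressor`, and it moves `optimal_threshold` into `__init__` (a design assumption) because utilities such as `cross_val_score` and `GridSearchCV` by default call `fit(X, y)` with no extra arguments. `median_threshold` is a hypothetical stand-in for your tuning rule, and the mixin is listed before `BaseEstimator`, as scikit-learn's developer guide recommends. This is not guaranteed to pass every `check_estimator` test.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score


def median_threshold(y_raw):
    """Hypothetical tuning rule: put the threshold at the median raw prediction."""
    return float(np.median(y_raw))


class Thresholder(ClassifierMixin, BaseEstimator):
    def __init__(self, regressor, optimal_threshold=median_threshold):
        # Store constructor arguments unchanged; get_params() and clone() rely on it.
        self.regressor = regressor
        self.optimal_threshold = optimal_threshold

    def fit(self, X, y):
        # Fit a clone so the `regressor` parameter itself stays unfitted.
        self.regressor_ = clone(self.regressor).fit(X, y)
        self.threshold_ = self.optimal_threshold(self.regressor_.predict(X))
        self.classes_ = np.array([0, 1])  # predict() only ever returns 0 or 1
        return self  # allows Thresholder(...).fit(X, y).predict(X)

    def predict(self, X):
        # np.digitize gives 0 below the threshold and 1 at or above it.
        return np.digitize(self.regressor_.predict(X), [self.threshold_])


rng = np.random.RandomState(0)
X = rng.normal(size=(200, 2))
y = (X[:, 0] > 0).astype(int)

clf = Thresholder(LinearRegression())
scores = cross_val_score(clf, X, y, cv=5)  # clones and refits clf on each fold
```

Because `cross_val_score` clones the estimator for every fold, it only works if `__init__` stores its arguments verbatim and all fitted state lives in attributes such as `regressor_` and `threshold_`.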

My main question is where to put the threshold. I want it to be learned only once and then reused in subsequent .fit calls on new data without being readjusted. With the current version, however, it is retuned on every .fit call, which I do not want.

On the other hand, if I make it a fixed parameter self.threshold and pass it to __init__, then I'm not supposed to change it based on the data, am I?
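The convention behind that concern can be illustrated with a toy sketch (the class `Convention` is hypothetical): a constructor parameter holds exactly what the user passed, while anything learned from data goes into a trailing-underscore attribute. `clone()` rebuilds an estimator from its parameters only, so learned attributes are dropped; if `fit` overwrote `self.threshold`, the learned value would silently become a "user setting" for every clone.

```python
import numpy as np
from sklearn.base import BaseEstimator, clone


class Convention(BaseEstimator):
    """Toy estimator (hypothetical) illustrating parameters vs. fitted attributes."""

    def __init__(self, threshold=None):
        self.threshold = threshold  # hyperparameter: stored exactly as given

    def fit(self, X, y=None):
        # Learned values go into attributes with a trailing underscore.
        self.threshold_ = float(np.median(X))
        return self


est = Convention().fit(np.array([[1.0], [2.0], [3.0]]))
print(est.get_params())                   # {'threshold': None}: fit did not change it
print(est.threshold_)                     # 2.0, the value learned from the data
print(hasattr(clone(est), "threshold_"))  # False: clone keeps parameters only
```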

How can I make a threshold that is tuned in one call to .fit and then kept fixed for subsequent .fit calls?
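One common way to get "tune once, then keep fixed" is sketched below; the `threshold=None` convention and `median_threshold` are illustrative assumptions, not an official scikit-learn API. `threshold` becomes a constructor parameter where `None` means "tune it in `fit`", and the value actually used is stored in `threshold_`. After a first fit, you freeze it with `set_params(threshold=clf.threshold_)`; later fits then retrain only the regressor.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.linear_model import LinearRegression


def median_threshold(y_raw):
    """Hypothetical tuning rule: threshold at the median raw prediction."""
    return float(np.median(y_raw))


class Thresholder(ClassifierMixin, BaseEstimator):
    def __init__(self, regressor, threshold=None, optimal_threshold=median_threshold):
        self.regressor = regressor
        self.threshold = threshold  # None means "tune it in fit"
        self.optimal_threshold = optimal_threshold

    def fit(self, X, y):
        self.regressor_ = clone(self.regressor).fit(X, y)
        if self.threshold is None:
            # Tune the threshold from this call's data.
            self.threshold_ = self.optimal_threshold(self.regressor_.predict(X))
        else:
            # Use the fixed threshold supplied by the user; do not retune.
            self.threshold_ = self.threshold
        self.classes_ = np.array([0, 1])
        return self

    def predict(self, X):
        return np.digitize(self.regressor_.predict(X), [self.threshold_])


rng = np.random.RandomState(0)
X1, X2 = rng.normal(size=(100, 2)), rng.normal(loc=1.0, size=(100, 2))
y1, y2 = (X1[:, 0] > 0).astype(int), (X2[:, 0] > 1).astype(int)

clf = Thresholder(LinearRegression()).fit(X1, y1)  # first fit tunes threshold_
clf.set_params(threshold=clf.threshold_)           # freeze the learned value
clf.fit(X2, y2)                                    # refits the regressor only
```

Because the frozen value is an ordinary constructor parameter, `clone()`, cross-validation and `GridSearchCV` carry it along, and `set_params(threshold=None)` switches tuning back on.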

