Saving an sklearn `FunctionTransformer` with the function it wraps

Submitted by 柔情痞子 on 2019-12-08 19:13:46

Question


I am using sklearn's Pipeline and FunctionTransformer with a custom function:

import joblib  # sklearn.externals.joblib has been removed from recent scikit-learn releases
from sklearn.preprocessing import FunctionTransformer
from sklearn.pipeline import Pipeline

This is my code:

def f(x):
    return x*2
pipe = Pipeline([("times_2", FunctionTransformer(f))])
joblib.dump(pipe, "pipe.joblib")
del pipe
del f
pipe = joblib.load("pipe.joblib") # Causes an exception

And I get this error:

AttributeError: module '__main__' has no attribute 'f'

How can this be resolved?

Note that this issue also occurs with pickle.
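
For context, the error comes from how pickle (which joblib builds on) handles functions: it stores a reference made of the module name and the function name, not the function body. The sketch below illustrates this with json.dumps, which is only a stand-in for any importable Python function and is not part of the original question.

```python
import json
import pickle

# json.dumps is an ordinary Python function, used here purely as a stand-in.
payload = pickle.dumps(json.dumps)

# The pickle records where to find the function...
has_module_name = b"json" in payload
has_function_name = b"dumps" in payload
# ...but it does not contain the function's bytecode.
has_bytecode = json.dumps.__code__.co_code in payload

print(payload)
print(has_module_name, has_function_name, has_bytecode)
```

On load, pickle imports the named module and looks the function up by name, which is why deleting f from __main__ makes the lookup fail.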


Answer 1:


I was able to hack together a solution using the marshal module (in addition to pickle) by overriding the __getstate__ and __setstate__ magic methods that pickle uses.

import marshal
from types import FunctionType
from sklearn.base import BaseEstimator, TransformerMixin

class MyFunctionTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, func):
        # sklearn expects the attribute name to match the constructor argument
        self.func = func
    def __call__(self, X):
        return self.func(X)
    def __getstate__(self):
        # Work on a copy so that pickling does not strip func from the live object
        state = self.__dict__.copy()
        func = state.pop("func")
        state["func_name"] = func.__name__
        state["func_code"] = marshal.dumps(func.__code__)
        return state
    def __setstate__(self, d):
        # Rebuild the function from its marshaled code object and its name
        d = dict(d)
        code = marshal.loads(d.pop("func_code"))
        d["func"] = FunctionType(code, globals(), d.pop("func_name"))
        self.__dict__ = d
    def fit(self, X, y=None):
        return self
    def transform(self, X):
        return self.func(X)

Now, if we use MyFunctionTransformer instead of FunctionTransformer, the code works as expected:

import joblib  # sklearn.externals.joblib has been removed from recent scikit-learn releases
from sklearn.pipeline import Pipeline

@MyFunctionTransformer
def my_transform(x):
    return x*2
pipe = Pipeline([("times_2", my_transform)])
joblib.dump(pipe, "pipe.joblib")
del pipe
del my_transform
pipe = joblib.load("pipe.joblib")

This works by leaving the wrapped function itself out of the pickle and storing its marshaled code object and its name instead; the function is rebuilt from those two pieces when the pickle is loaded.
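
The round trip that __getstate__ and __setstate__ perform can be tried in isolation. This is a minimal sketch, not part of the original answer; the names double, code_bytes and rebuilt are made up for illustration.

```python
import marshal
from types import FunctionType

def double(x):
    return x * 2

# Serialize only the code object; marshal cannot serialize the function itself.
code_bytes = marshal.dumps(double.__code__)

# Rebuild a new function from the bytes, binding it to the current globals.
rebuilt = FunctionType(marshal.loads(code_bytes), globals(), double.__name__)

print(rebuilt(21))
```

Two caveats follow from this mechanism. The marshal format is not guaranteed to be stable across Python versions, so the file should be loaded with the same Python version that wrote it. The rebuilt function also looks up global names in whatever globals() is passed in, so any modules or helpers it uses must exist there at load time.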

dill also looks like a good alternative to marshaling.



Source: https://stackoverflow.com/questions/54012769/saving-an-sklearn-functiontransformer-with-the-function-it-wraps
