Question
I am using sklearn's Pipeline and FunctionTransformer with a custom function.
import joblib  # sklearn.externals.joblib was removed in scikit-learn 0.23
from sklearn.preprocessing import FunctionTransformer
from sklearn.pipeline import Pipeline
This is my code:
def f(x):
    return x*2

pipe = Pipeline([("times_2", FunctionTransformer(f))])
joblib.dump(pipe, "pipe.joblib")
del pipe
del f
pipe = joblib.load("pipe.joblib") # Causes an exception
And I get this error:
AttributeError: module '__main__' has no attribute 'f'
How can this be resolved?
Note that this issue also occurs with plain pickle.
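For what it's worth, the same failure can be reproduced with the standard library alone. The sketch below assumes it is run as a plain script; it suggests that pickle records a function only by its module and name, so loading fails once that name no longer exists:

```python
import pickle


def f(x):
    return x * 2


# pickle stores a function as a reference (module + qualified name),
# not as code, so the name itself shows up in the payload.
payload = pickle.dumps(f)
stored_by_name = f.__qualname__.encode() in payload

# Once the name is removed from the module, the reference cannot be resolved.
del f
try:
    pickle.loads(payload)
    load_error = None
except AttributeError as exc:
    load_error = exc

print(stored_by_name, type(load_error).__name__)
```

The same thing would happen when loading the file in a fresh process where f was never defined.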
Answer 1:
I was able to hack together a solution using the marshal module (in addition to pickle) by overriding the magic methods __getstate__ and __setstate__ that pickle uses.
import marshal
from types import FunctionType
from sklearn.base import BaseEstimator, TransformerMixin

class MyFunctionTransformer(BaseEstimator, TransformerMixin):
    # The parameter is named like the attribute so that get_params() can find it.
    def __init__(self, func):
        self.func = func

    def __call__(self, X):
        return self.func(X)

    def __getstate__(self):
        # Work on a copy so that pickling does not strip func from the live object.
        state = self.__dict__.copy()
        func = state.pop("func")
        state["func_name"] = func.__name__
        state["func_code"] = marshal.dumps(func.__code__)
        return state

    def __setstate__(self, d):
        # Rebuild the function from its marshaled code object and its name.
        d["func"] = FunctionType(marshal.loads(d["func_code"]), globals(), d["func_name"])
        del d["func_name"]
        del d["func_code"]
        self.__dict__ = d

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return self.func(X)
Now, if we use MyFunctionTransformer instead of FunctionTransformer, the code works as expected:
import joblib  # sklearn.externals.joblib was removed in scikit-learn 0.23
from sklearn.pipeline import Pipeline

@MyFunctionTransformer
def my_transform(x):
    return x*2

pipe = Pipeline([("times_2", my_transform)])
joblib.dump(pipe, "pipe.joblib")
del pipe
del my_transform
pipe = joblib.load("pipe.joblib")
This works by leaving the function object itself out of the pickle and storing its marshaled code object and its name instead; the function is rebuilt from those two pieces on load.
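A minimal, sklearn-free sketch of that round trip, using a made-up function called double. It carries only the code object and the name, so default arguments and closures would not survive, and marshal output is tied to the Python version that produced it:

```python
import marshal
from types import FunctionType


def double(x):  # hypothetical example function
    return x * 2


# What gets stored: the function's name and its serialized code object.
name = double.__name__
code_bytes = marshal.dumps(double.__code__)

# What happens on load: a new function is built from the code object,
# bound to the current module's globals.
rebuilt = FunctionType(marshal.loads(code_bytes), globals(), name)
result = rebuilt(21)
print(rebuilt.__name__, result)
```

Because the rebuilt function is bound to globals() of the module doing the loading, anything the original function looked up globally (imports, helpers) has to exist there as well.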
dill also looks like a good alternative to marshaling.
Source: https://stackoverflow.com/questions/54012769/saving-an-sklearn-functiontransformer-with-the-function-it-wraps