How to save a custom transformer in sklearn?

Submitted by 柔情痞子 on 2020-12-02 05:55:46

Question


I cannot load an instance of a custom transformer that was saved with either sklearn.externals.joblib.dump or pickle.dump, because the original definition of the custom transformer is missing from the current Python session.

Suppose that in one Python session I define, create, and save a custom transformer. It can be loaded again in the same session:

from sklearn.base import BaseEstimator, TransformerMixin
# On scikit-learn < 0.23 this was `from sklearn.externals import joblib`;
# that alias has since been removed, so import joblib directly.
import joblib

class CustomTransformer(BaseEstimator, TransformerMixin):
    def __init__(self):
        pass

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        return X


custom_transformer = CustomTransformer()    
joblib.dump(custom_transformer, 'custom_transformer.pkl')

loaded_custom_transformer = joblib.load('custom_transformer.pkl')

Opening up a new python session and loading from 'custom_transformer.pkl'

import joblib  # formerly `from sklearn.externals import joblib`

joblib.load('custom_transformer.pkl')

raises the following exception:

AttributeError: module '__main__' has no attribute 'CustomTransformer'

The same thing is observed if joblib is replaced with pickle. Saving the custom transformer in one session with

import pickle
with open('custom_transformer_pickle.pkl', 'wb') as f:
    # -1 selects the highest pickle protocol available
    pickle.dump(custom_transformer, f, -1)

and loading it in another:

import pickle
with open('custom_transformer_pickle.pkl', 'rb') as f:
    loaded_custom_transformer_pickle = pickle.load(f)

raises the same exception.

In the above, if CustomTransformer is replaced with, say, sklearn.preprocessing.StandardScaler, the saved instance can be loaded in a new Python session.

Is it possible to save a custom transformer and load it later somewhere else?


Answer 1:


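sklearn.preprocessing.StandardScaler works because a pickle stores only a reference to the class (its module path and name), not the class definition itself. When you load the file, joblib (through pickle) imports the class's module from the sklearn package installation and finds StandardScaler there. CustomTransformer was defined in __main__, and the __main__ of a new session contains no such class, hence the AttributeError.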
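
To see why the class definition has to exist at load time, it may help to look at what a pickle of such an object actually contains. The following is only a minimal sketch using the standard pickle module and a plain stand-in class (no scikit-learn involved); it checks that the module and class names appear in the serialized bytes while the method names do not.

```python
import pickle


class CustomTransformer:
    """Plain stand-in class; scikit-learn estimators are pickled the same way."""

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        return X


payload = pickle.dumps(CustomTransformer())

# The module the class lives in ("__main__" when run as a script).
module_name = CustomTransformer.__module__.encode()

# The module and class names are stored as text in the pickle...
has_module_name = module_name in payload
has_class_name = b"CustomTransformer" in payload
# ...but nothing from the class body, such as a method name, is stored.
has_method_name = b"fit" in payload

print(has_module_name, has_class_name, has_method_name)
```

On loading, pickle imports the named module and looks the class up by name, which is exactly the step that fails when the class only ever existed in another session's __main__.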

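You'll have to make your CustomTransformer class available in the new session, either by redefining it there or by importing it from a module. With the import route, the object should be saved while the class lives in that module, so that the pickle records the module's name rather than __main__.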
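
One way to follow this advice is to keep the class in its own module. The sketch below assumes a hypothetical module name, custom_transformers, and simulates the two Python sessions with subprocesses so that it can run as a single script; in practice you would simply put the class in a .py file that both sessions can import.

```python
import subprocess
import sys
import tempfile
import textwrap
from pathlib import Path

# Contents of the hypothetical module custom_transformers.py.
MODULE_SOURCE = textwrap.dedent('''
    from sklearn.base import BaseEstimator, TransformerMixin

    class CustomTransformer(BaseEstimator, TransformerMixin):
        def fit(self, X, y=None):
            return self

        def transform(self, X, y=None):
            return X
''')

# Session 1: import the class from the module, then save an instance.
SAVE_SCRIPT = textwrap.dedent('''
    import joblib
    from custom_transformers import CustomTransformer
    joblib.dump(CustomTransformer(), "custom_transformer.pkl")
''')

# Session 2: no explicit import of the class is needed here; pickle imports
# custom_transformers itself, provided the module is on sys.path.
LOAD_SCRIPT = textwrap.dedent('''
    import joblib
    loaded = joblib.load("custom_transformer.pkl")
    print(type(loaded).__module__, type(loaded).__name__)
''')


def run_in_fresh_session(script, workdir):
    """Run a script in a new interpreter whose working directory is workdir."""
    result = subprocess.run(
        [sys.executable, "-c", script],
        cwd=workdir, capture_output=True, text=True, check=True,
    )
    return result.stdout.strip()


with tempfile.TemporaryDirectory() as workdir:
    Path(workdir, "custom_transformers.py").write_text(MODULE_SOURCE)
    run_in_fresh_session(SAVE_SCRIPT, workdir)
    loaded_description = run_in_fresh_session(LOAD_SCRIPT, workdir)

print(loaded_description)
```

The load succeeds because the pickle refers to custom_transformers.CustomTransformer instead of __main__.CustomTransformer. The module still has to be shipped alongside the .pkl file wherever the model is loaded.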




Answer 2:


It works for me if I pass my transform function to sklearn.preprocessing.FunctionTransformer() and save the model to a ".pk" file with dill.dump(), then load it with dill.load().

Note: I have included the transform function in a sklearn pipeline with my classifier.
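
A rough sketch of what this could look like is shown below. The transform function (clip_negatives), the toy data, the LogisticRegression classifier, and the file name are all made up for illustration, since the answer does not show its code. dill is a third-party package (pip install dill); as far as I understand, it serializes functions defined in __main__ by value rather than by reference, which is why the function does not need to be redefined before loading.

```python
import os
import tempfile

import dill  # third-party: pip install dill
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer


def clip_negatives(X):
    """Hypothetical transform function: replace negative values with zero."""
    return np.clip(X, 0, None)


# Wrap the function in a FunctionTransformer and chain it with a classifier.
pipeline = Pipeline([
    ("clip", FunctionTransformer(clip_negatives)),
    ("clf", LogisticRegression()),
])

# Toy data, only there so the pipeline can be fitted.
X = np.array([[-1.0, 2.0], [3.0, -4.0], [5.0, 6.0], [-7.0, 8.0]])
y = np.array([0, 1, 1, 0])
pipeline.fit(X, y)

with tempfile.TemporaryDirectory() as workdir:
    model_path = os.path.join(workdir, "pipeline.pk")

    # Save the fitted pipeline, including the transform function, with dill.
    with open(model_path, "wb") as f:
        dill.dump(pipeline, f)

    # Load it back with dill.
    with open(model_path, "rb") as f:
        loaded_pipeline = dill.load(f)

predictions_match = bool(
    (loaded_pipeline.predict(X) == pipeline.predict(X)).all()
)
print(predictions_match)
```

The session that loads the file still needs dill, numpy, and scikit-learn installed, and anything the function imports must be available there too.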



Source: https://stackoverflow.com/questions/46077793/how-to-save-a-custom-transformer-in-sklearn
