scikit-learn: applying an arbitrary function as part of a pipeline

ε祈祈猫儿з submitted on 2019-12-04 13:21:03

For a general solution (one that works for many other use cases, not just transformers but also simple models), you can write your own wrapper class and use it as a decorator, provided your functions are stateless (i.e. they have nothing to learn in fit). For example:

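from sklearn.base import BaseEstimator, TransformerMixin

class TransformerWrapper(TransformerMixin, BaseEstimator):

    def __init__(self, func):
        # store under the __init__ argument's name so get_params()/clone() work
        self.func = func

    def fit(self, X, y=None):
        # stateless: there is nothing to learn
        return self

    def transform(self, X, *args, **kwargs):
        # apply the wrapped function to the input
        return self.func(X, *args, **kwargs)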

Now you can write

@TransformerWrapper
def foo(x):
    return x * 2

which is equivalent to

def foo(x):
    return x * 2

foo = TransformerWrapper(foo)

This is essentially what sklearn.preprocessing.FunctionTransformer does under the hood (it additionally offers optional input validation and an optional inverse function).

Personally, I find decorating simpler because it keeps your preprocessors nicely separated from the rest of the code, but which approach you take is up to you.
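A minimal sketch of the decorator inside a Pipeline, assuming scikit-learn and NumPy are installed; `double` and the toy data are made up for illustration. It uses a self-contained variant of the wrapper that inherits `TransformerMixin` (which supplies `fit_transform`) and stores the function as `self.func`, because `BaseEstimator.get_params()` looks for attributes named after the `__init__` arguments and `clone()` relies on that:

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline


class TransformerWrapper(TransformerMixin, BaseEstimator):
    def __init__(self, func):
        # same name as the __init__ argument, so get_params() can find it
        self.func = func

    def fit(self, X, y=None):
        return self  # stateless: nothing to learn

    def transform(self, X):
        return self.func(X)


@TransformerWrapper
def double(X):
    # hypothetical stateless preprocessing step
    return X * 2


X = np.array([[1.0], [2.0], [3.0]])
y = np.array([2.0, 4.0, 6.0])

# `double` is now a transformer instance, so it can be a pipeline step
pipe = make_pipeline(double, LinearRegression())
pipe.fit(X, y)
print(pipe.predict(np.array([[4.0]])))  # approximately [8.]

# clone() (used by GridSearchCV, cross_val_score, ...) rebuilds each step
# from get_params(), which works because the attribute is named `func`
pipe_copy = clone(pipe)
```

Inheriting `TransformerMixin` is optional for a Pipeline (it falls back to `fit(...).transform(...)`), but it gives the wrapper a standard `fit_transform` for free.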

In fact, you can decorate with scikit-learn's own class as well:

from sklearn.preprocessing import FunctionTransformer

@FunctionTransformer
def foo(x):
    return x * 2
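A short sketch of the decorated FunctionTransformer in use, assuming scikit-learn and NumPy; `square` is a hypothetical example function:

```python
import numpy as np
from sklearn.preprocessing import FunctionTransformer


@FunctionTransformer
def square(X):
    # hypothetical stateless step: element-wise square
    return X ** 2


X = np.array([[1.0, 2.0], [3.0, 4.0]])

# `square` is now a FunctionTransformer instance, not a plain function
print(type(square).__name__)
print(square.fit_transform(X))  # element-wise squares of X
```

Keep in mind that after decorating (with either class), the name refers to a transformer object, so you call `square.transform(X)` or put it in a pipeline rather than calling `square(X)`.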

The sklearn.preprocessing.FunctionTransformer class can be used to build a scikit-learn transformer (usable, e.g., in a pipeline) from a user-provided function.
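For example, a sketch (assuming scikit-learn and NumPy; the toy data is invented) that turns `np.log1p` into a transformer. The optional `inverse_func` parameter supplies the function used by `inverse_transform`, and the same kind of object drops into a Pipeline like any other step:

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer

# log1p compresses large, skewed values; expm1 is its exact inverse
log_tf = FunctionTransformer(np.log1p, inverse_func=np.expm1)

X = np.array([[0.0, 9.0], [99.0, 999.0]])
X_log = log_tf.fit_transform(X)
X_back = log_tf.inverse_transform(X_log)  # round-trips back to X

# a FunctionTransformer used as an ordinary pipeline step
pipe = Pipeline([("log", FunctionTransformer(np.log1p)), ("ridge", Ridge())])
pipe.fit(X, np.array([1.0, 2.0]))
```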

I think it's worth mentioning the validate parameter of sklearn.preprocessing.FunctionTransformer. Older releases documented it as follows (since scikit-learn 0.22 the default has been validate=False):

validate : bool, optional default=True

Indicate that the input X array should be checked before calling func. If validate is false, there will be no input validation. If it is true, then X will be converted to a 2-dimensional NumPy array or sparse matrix. If this conversion is not possible or X contains NaN or infinity, an exception is raised.
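To see what validate changes, a sketch (assuming scikit-learn and NumPy; `describe` is a hypothetical helper) that reports the type the wrapped function receives and shows that NaN is rejected when validation is on. validate is passed explicitly, so the result does not depend on the version's default:

```python
import numpy as np
from sklearn.preprocessing import FunctionTransformer


def describe(X):
    # hypothetical helper: report what the wrapped function actually receives
    return type(X).__name__


rows = [[1.0, 2.0], [3.0, 4.0]]  # a plain list of lists

print(FunctionTransformer(describe, validate=True).fit_transform(rows))   # ndarray
print(FunctionTransformer(describe, validate=False).fit_transform(rows))  # list

# with validation on, NaN is rejected before func is ever called
try:
    FunctionTransformer(np.abs, validate=True).fit_transform([[np.nan, 1.0]])
except ValueError as err:
    print("rejected:", err)
```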

So if you are going to pass non-numerical features to FunctionTransformer, make sure that you explicitly set validate=False; otherwise it will fail with the following exception:

ValueError: could not convert string to float: 'your non-numerical value'
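A sketch of the failure and the fix, assuming scikit-learn and NumPy; `text_lengths` is a hypothetical feature function. Newer scikit-learn releases may word the message differently from the one above, but validation of string input raises a ValueError either way:

```python
import numpy as np
from sklearn.preprocessing import FunctionTransformer


def text_lengths(X):
    # hypothetical feature: length of each string, as a 2-D float array
    return np.array([[float(len(s)) for s in row] for row in X])


words = [["apple"], ["fig"], ["banana"]]

# validate=False hands the raw strings to text_lengths unchanged
lengths = FunctionTransformer(text_lengths, validate=False).fit_transform(words)
print(lengths.ravel())  # [5. 3. 6.]

# validate=True tries to turn the strings into numbers and fails first
try:
    FunctionTransformer(text_lengths, validate=True).fit_transform(words)
except ValueError as err:
    print("ValueError:", err)
```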