Override method for a collection of classes implementing an interface

谁说胖子不能爱 submitted on 2019-12-13 15:26:15

Question


I am using scikit-learn to build a pipeline. Once the pipeline is built, I use GridSearchCV to find the optimal model. I am working with text data, so I am experimenting with different stemmers. I have created a class called Preprocessor that takes a stemmer class and a vectorizer class, then attempts to override the vectorizer's build_analyzer method to incorporate the given stemmer. However, I see that GridSearchCV's set_params just sets instance attributes directly; that is, it will not re-instantiate a vectorizer with a new analyzer, which is how I have been doing it:

import nltk


class Preprocessor(object):
    # hard code the stopwords for now
    stopwords = nltk.corpus.stopwords.words()

    def __init__(self, stemmer_cls, vectorizer_cls):
        self.stemmer = stemmer_cls()
        analyzer = self._build_analyzer(self.stemmer, vectorizer_cls)
        # the vectorizer is built once, here, so set_params cannot rebuild it
        self.vectorizer = vectorizer_cls(analyzer=analyzer,
                                         decode_error='ignore')

    def _build_analyzer(self, stemmer, vectorizer_cls):
        # the default analyzer tokenizes, lowercases and drops stop words
        analyzer = vectorizer_cls(stop_words=self.stopwords,
                                  decode_error='ignore').build_analyzer()
        return lambda doc: [stemmer.stem(w) for w in analyzer(doc)]

    def fit(self, X, y=None):
        self.vectorizer.fit(X)
        return self

    def transform(self, X):
        return self.vectorizer.transform(X)

    def fit_transform(self, X, y=None):
        return self.vectorizer.fit_transform(X)

So the question is: how can I override build_analyzer for every vectorizer class passed in?


Answer 1:


Yes, GridSearchCV sets the instance attributes directly (through set_params) and then calls fit on the estimator with the changed attributes.
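A minimal sketch of that behaviour, using a hypothetical `EagerEstimator` that, like the question's `Preprocessor`, derives state inside `__init__`. `set_params` only assigns the new parameter value, so the derived attribute keeps the value computed at construction time:

```python
from sklearn.base import BaseEstimator, clone


class EagerEstimator(BaseEstimator):
    """Hypothetical estimator that (wrongly) derives state in __init__."""

    def __init__(self, scale=1):
        self.scale = scale
        self.derived = scale * 10  # computed once, at construction time only


est = EagerEstimator(scale=1)
est.set_params(scale=2)        # assigns est.scale; __init__ is not re-run
print(est.scale, est.derived)  # prints "2 10": derived is stale

# GridSearchCV clones the base estimator and then calls set_params on the
# clone, so __init__ runs with the original value before the update.
candidate = clone(EagerEstimator(scale=1)).set_params(scale=3)
print(candidate.scale, candidate.derived)  # prints "3 10"
```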

Every estimator in scikit-learn is built so that __init__ only stores its parameters as attributes, and every dependent object needed for further work (like the analyzer from _build_analyzer in your case) is constructed inside the fit method. So store vectorizer_cls as an attribute alongside stemmer_cls, and construct the objects that depend on them inside fit. GridSearchCV also needs get_params and set_params on your class, which you get by inheriting from sklearn.base.BaseEstimator.

Something like:

import nltk
from sklearn.base import BaseEstimator, TransformerMixin


class Preprocessor(BaseEstimator, TransformerMixin):
    # hard code the stopwords for now
    stopwords = nltk.corpus.stopwords.words()

    def __init__(self, stemmer_cls, vectorizer_cls):
        # only store the parameters; GridSearchCV may overwrite them via set_params
        self.stemmer_cls = stemmer_cls
        self.vectorizer_cls = vectorizer_cls

    def _build_analyzer(self, stemmer, vectorizer_cls):
        # the default analyzer of any vectorizer class tokenizes, lowercases
        # and drops stop words; wrap it so every token is stemmed
        analyzer = vectorizer_cls(stop_words=self.stopwords,
                                  decode_error='ignore').build_analyzer()
        return lambda doc: [stemmer.stem(w) for w in analyzer(doc)]

    def fit(self, X, y=None):
        # build the dependent objects from the current parameter values
        analyzer = self._build_analyzer(self.stemmer_cls(), self.vectorizer_cls)
        # keep the fitted vectorizer separate so vectorizer_cls stays a class
        self.vectorizer_ = self.vectorizer_cls(analyzer=analyzer,
                                               decode_error='ignore')
        self.vectorizer_.fit(X)
        return self

    def transform(self, X):
        return self.vectorizer_.transform(X)

    # fit_transform is inherited from TransformerMixin: fit(X, y).transform(X)


Source: https://stackoverflow.com/questions/32343024/override-method-for-a-collection-of-classes-implementing-an-interface
