Scikit-learn: how to check if a model (e.g. TfidfVectorizer) has already been fit

Submitted by 魔方 西西 on 2021-01-27 13:40:25

Question


For feature extraction from text, how can I check whether a vectorizer (e.g. TfidfVectorizer or CountVectorizer) has already been fit on training data?
In particular, I want the code to figure out automatically whether the vectorizer has already been fit.

from sklearn.feature_extraction.text import TfidfVectorizer

vectorizer = TfidfVectorizer()

def vectorize_data(texts):
    if not is_fitted(vectorizer):  # pseudocode: is_fitted() is the check I am asking about
        return vectorizer.fit_transform(texts)
    else:
        return vectorizer.transform(texts)

Answer 1:


You can use check_is_fitted (from sklearn.utils.validation), which is made for exactly this purpose.

You can see how it is used in the source of TfidfVectorizer.transform():

def transform(self, raw_documents, copy=True):

    # This is what you need.
    check_is_fitted(self, '_tfidf', 'The tfidf vector is not fitted')

    X = super(TfidfVectorizer, self).transform(raw_documents)
    return self._tfidf.transform(X, copy=False)
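Because transform() runs this fitted check itself, calling it on a vectorizer that has never been fit raises sklearn.exceptions.NotFittedError. A minimal sketch to observe this, assuming a reasonably recent scikit-learn (the exact message wording differs between versions):

```python
from sklearn.exceptions import NotFittedError
from sklearn.feature_extraction.text import TfidfVectorizer

vectorizer = TfidfVectorizer()

try:
    vectorizer.transform(["some unseen text"])
except NotFittedError as exc:
    # The fitted check inside transform() rejects the unfitted vectorizer.
    error_message = str(exc)
    print("NotFittedError:", error_message)
else:
    # Not reached for a fresh vectorizer; kept so the name is always defined.
    error_message = None
```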

So in your case, you can do this:

from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted

def vectorize_data(texts):

    try:
        # Raises NotFittedError if the vectorizer has not been fit yet
        check_is_fitted(vectorizer, '_tfidf', msg='The tfidf vector is not fitted')
    except NotFittedError:
        vectorizer.fit(texts)

    # At this point the vectorizer is fitted in every case, so just call transform()
    return vectorizer.transform(texts)
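A minimal sketch of the same idea that avoids relying on the private _tfidf attribute, assuming scikit-learn 0.22 or newer, where check_is_fitted can be called without attribute names and then looks for any instance attribute ending in an underscore. The sample texts are made up for illustration:

```python
from sklearn.exceptions import NotFittedError
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.utils.validation import check_is_fitted

vectorizer = TfidfVectorizer()


def vectorize_data(texts):
    """Fit the vectorizer on the first call, then only transform."""
    try:
        # With no attribute names, check_is_fitted looks for any instance
        # attribute ending in "_" (such as vocabulary_), which fit() creates.
        check_is_fitted(vectorizer)
    except NotFittedError:
        # First call: learn the vocabulary and IDF weights, then transform.
        return vectorizer.fit_transform(texts)
    # Later calls: reuse the vocabulary learned on the first call.
    return vectorizer.transform(texts)


train = ["the cat sat", "the dog barked"]
test = ["the cat barked loudly"]

X_train = vectorize_data(train)  # fits and transforms
X_test = vectorize_data(test)    # transform only; "loudly" is not in the vocabulary
print(X_train.shape, X_test.shape)
```

Calling fit_transform() on the first call means the training texts are tokenized once, whereas fit() followed by transform() processes them twice.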



Answer 2:


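I propose two ways to check this.

First way, custom code that should cover all scikit-learn models, based on the convention that attributes learned during fit() have names ending in an underscore: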

def my_inspector(model):
    # Attributes learned by fit() end with an underscore (e.g. vocabulary_).
    # vars() only lists attributes stored on the instance, so methods that
    # happen to end with an underscore are not counted.
    return any(k.endswith('_') and not k.startswith('__') for k in vars(model))

Now let's test this code:

from sklearn.feature_extraction.text import TfidfVectorizer

vectorizer = TfidfVectorizer()

my_inspector(vectorizer)
# False
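my_inspector relies on the scikit-learn convention that attributes learned by fit() are stored on the instance with a trailing underscore. A minimal sketch that lists those attributes before and after fitting; learned_attributes is a helper name chosen for this sketch, and the exact list after fitting depends on the scikit-learn version:

```python
from sklearn.feature_extraction.text import TfidfVectorizer


def learned_attributes(model):
    """Instance attributes that follow the trailing-underscore convention."""
    # vars() only returns attributes stored on the instance itself,
    # so methods and properties defined on the class are not included.
    return sorted(k for k in vars(model) if k.endswith("_") and not k.startswith("__"))


vectorizer = TfidfVectorizer()
before = learned_attributes(vectorizer)

vectorizer.fit(["the cat sat", "the dog barked"])
after = learned_attributes(vectorizer)

print(before)  # []
print(after)   # includes 'vocabulary_'; other entries vary by version
```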

Second way, using check_is_fitted, which returns None if the model is fitted and raises NotFittedError otherwise:

from sklearn.utils.validation import check_is_fitted

check_is_fitted(vectorizer, '_tfidf', msg='The tfidf vector is not fitted')
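Since check_is_fitted signals an unfitted model by raising instead of returning False, it can be wrapped into a boolean test. A minimal sketch, assuming scikit-learn 0.22 or newer (for the call without attribute names); is_fitted is a helper name chosen for this sketch, not something imported from scikit-learn:

```python
from sklearn.exceptions import NotFittedError
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.utils.validation import check_is_fitted


def is_fitted(model):
    """Return True if check_is_fitted accepts the model, False otherwise."""
    try:
        check_is_fitted(model)
    except NotFittedError:
        return False
    return True


vectorizer = TfidfVectorizer()
print(is_fitted(vectorizer))  # False

vectorizer.fit(["the cat sat", "the dog barked"])
print(is_fitted(vectorizer))  # True
```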


Source: https://stackoverflow.com/questions/51369709/scikit-learn-how-to-check-if-model-e-g-tfidfvectorizer-has-been-already-fit
