Sklearn: Is there any way to debug Pipelines?

甜味超标 2021-01-01 22:26

I have created some pipelines for a classification task and I want to check what information is present/stored at each stage (e.g. text_stats, ngram_tfidf). How coul

2 answers
  • 2021-01-01 23:12

    You can traverse your Pipeline() tree using its steps and named_steps attributes. The former is a list of ('step_name', Step()) tuples, while the latter is a dictionary built from that list.

    The contents of a FeatureUnion() can be explored the same way through its transformer_list attribute.
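As a minimal sketch of the traversal described above: the pipeline below and its step names ("features", "pca", "kbest", "svc") are hypothetical and chosen only for illustration, not taken from the question.

```python
from sklearn.decomposition import PCA
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.pipeline import FeatureUnion, Pipeline
from sklearn.svm import SVC

# Hypothetical pipeline, built only to have something to inspect.
features = FeatureUnion([
    ("pca", PCA(n_components=2)),
    ("kbest", SelectKBest(f_classif, k=3)),
])
pipe = Pipeline([("features", features), ("svc", SVC(kernel="linear"))])

# steps is a list of (name, estimator) tuples.
step_names = [name for name, _ in pipe.steps]

# named_steps gives dictionary-style access by step name.
svc_step = pipe.named_steps["svc"]

# A FeatureUnion exposes its parts through transformer_list.
union_names = [name for name, _ in pipe.named_steps["features"].transformer_list]

print(step_names)
print(type(svc_step).__name__)
print(union_names)
```

No fitting is needed for this kind of inspection. After fit(), the same objects also carry their learned attributes, so they can be examined the same way.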

  • 2021-01-01 23:13

    I sometimes find it useful to temporarily add a debugging step that prints the information you are interested in. Building on an example from the sklearn documentation, you could do the following to print, for example, the first 5 rows, the shape, or whatever else you need to look at before the classifier is called:

    import pandas as pd
    from sklearn import svm
    from sklearn.base import BaseEstimator, TransformerMixin
    from sklearn.datasets import make_classification
    from sklearn.feature_selection import SelectKBest, f_regression
    from sklearn.pipeline import Pipeline


    class Debug(BaseEstimator, TransformerMixin):
        """Pass-through step that prints the data it receives."""

        def fit(self, X, y=None, **fit_params):
            # Nothing to learn; the step only observes the data.
            return self

        def transform(self, X):
            # Show the first rows and the shape, then hand X on unchanged.
            print(pd.DataFrame(X).head())
            print(X.shape)
            return X


    X, y = make_classification(n_informative=5, n_redundant=0, random_state=42)
    anova_filter = SelectKBest(f_regression, k=5)
    clf = svm.SVC(kernel='linear')
    # The Debug step sits between feature selection and the classifier.
    anova_svm = Pipeline([('anova', anova_filter), ('dbg', Debug()), ('svc', clf)])
    anova_svm.set_params(anova__k=10, svc__C=.1).fit(X, y)

    # Debug prints again here, because predict() also runs the transformers.
    prediction = anova_svm.predict(X)
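If adding a step feels too intrusive, one alternative is to slice the fitted pipeline and run only the transformers. This is a sketch rather than part of the answer above. It assumes a scikit-learn version in which Pipeline supports slicing (0.21 or later), and it uses f_classif as the scoring function.

```python
from sklearn.datasets import make_classification
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC

X, y = make_classification(n_informative=5, n_redundant=0, random_state=42)
pipe = Pipeline([
    ("anova", SelectKBest(f_classif, k=10)),
    ("svc", SVC(kernel="linear", C=0.1)),
])
pipe.fit(X, y)

# pipe[:-1] is a sub-pipeline with every step except the final estimator.
# It reuses the already fitted transformers, so no refit is needed.
Xt = pipe[:-1].transform(X)
print(Xt.shape)
```

Here Xt is the data exactly as the classifier receives it, which can then be inspected with whatever tools are convenient.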
    