How do I find which attributes my tree splits on, when using scikit-learn?

余生分开走 2020-12-23 02:35

I have been exploring scikit-learn, building decision trees with both the entropy and gini splitting criteria and comparing the results.

My question is, how can I "o

3 answers
  •  刺人心 (OP)
    2020-12-23 03:06

    scikit-learn introduced a delicious new function called export_text in version 0.21 (May 2019) to view all the rules from a tree. See the sklearn.tree.export_text documentation for details.

    Once you've fit your model, you just need two lines of code. First, import export_text:

    # Import from sklearn.tree; the older sklearn.tree.export path was later removed.
    from sklearn.tree import export_text
    

    Second, call export_text, which returns the rules as a plain string. To make the rules more readable, pass a list of your feature names as the feature_names argument. For example, if your model is called model and your training features are in a DataFrame called X_train, you could create a string called tree_rules:

    tree_rules = export_text(model, feature_names=list(X_train))
    

    Then just print or save tree_rules. Your output will look like this:

    |--- Age <= 0.63
    |   |--- EstimatedSalary <= 0.61
    |   |   |--- Age <= -0.16
    |   |   |   |--- class: 0
    |   |   |--- Age >  -0.16
    |   |   |   |--- EstimatedSalary <= -0.06
    |   |   |   |   |--- class: 0
    |   |   |   |--- EstimatedSalary >  -0.06
    |   |   |   |   |--- EstimatedSalary <= 0.40
    |   |   |   |   |   |--- EstimatedSalary <= 0.03
    |   |   |   |   |   |   |--- class: 1
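
For anyone who wants to try this end to end, here is a minimal sketch. It assumes the bundled iris dataset and a shallow entropy tree, neither of which comes from the answer above, and the names iris, feature_names and split_features are only illustrative. The last step reads model.tree_.feature, which holds the feature index used at each node (leaves carry a negative placeholder), as one possible way to list just the attributes the tree actually splits on.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

# Assumption: the iris dataset stands in for your own X_train / y_train.
iris = load_iris()
feature_names = list(iris.feature_names)

# Fit a small tree; criterion can be "gini" or "entropy".
model = DecisionTreeClassifier(criterion="entropy", max_depth=3, random_state=0)
model.fit(iris.data, iris.target)

# export_text returns the rules as a single string.
tree_rules = export_text(model, feature_names=feature_names)
print(tree_rules)

# tree_.feature has one entry per node; leaf nodes hold a negative
# placeholder, so keep only the non-negative indices.
split_features = sorted({feature_names[i] for i in model.tree_.feature if i >= 0})
print(split_features)
```

Swapping criterion between "gini" and "entropy" and comparing the two printed rule sets is a quick way to see where the criteria pick different splits.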
    
