Plot trees for a Random Forest in Python with Scikit-Learn

遇见更好的自我 2020-12-08 08:34

I want to plot a decision tree of a random forest, so I created the following code:

from sklearn.ensemble import RandomForestClassifier
import pydotplus
import s
clf = RandomForestClassifier(n_estimators=100)


        
5 answers
  •  醉话见心
    2020-12-08 09:03

    To access a single decision tree from a random forest in scikit-learn, use the estimators_ attribute:

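    from sklearn.datasets import load_iris
    from sklearn.ensemble import RandomForestClassifier
    X, y = load_iris(return_X_y=True)  # example data
    rf = RandomForestClassifier().fit(X, y)  # estimators_ exists only after fit
    # first decision tree (a fitted DecisionTreeClassifier)
    rf.estimators_[0]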
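    A minimal, self-contained sketch of the point above, assuming the Iris dataset bundled with scikit-learn as example data (the dataset, n_estimators=10 and random_state=0 are illustrative choices, not part of the original answer). estimators_ is only populated after fit(), and it holds one fitted DecisionTreeClassifier per tree in the forest:

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier

# Example data (an assumption; any classification dataset works)
X, y = load_iris(return_X_y=True)

# estimators_ is populated only after fit(); fit() returns the estimator itself
rf = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)

# One fitted tree per estimator
print(len(rf.estimators_))  # 10
first_tree = rf.estimators_[0]
print(isinstance(first_tree, DecisionTreeClassifier))  # True
```

    Each tree is trained on a bootstrap sample with a random subset of features considered at each split, so rf.estimators_[0] is just one of many different trees and does not summarize the whole forest.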
    

    Then you can use any of the standard ways to visualize a decision tree:

    • print a text representation of the tree with sklearn's export_text function
    • export it to Graphviz with sklearn's export_graphviz function and render it
    • plot it with matplotlib using sklearn's plot_tree function
    • use the dtreeviz package for tree plotting
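    A sketch of the first three options applied to one tree from the forest, assuming the Iris dataset as example data. export_graphviz(out_file=None) returns the DOT source as a string; rendering that string (for example with pydotplus, as in the question, or with the graphviz package) needs Graphviz installed and is not shown. dtreeviz is left out because its API has changed between major versions. The Agg backend and the in-memory PNG buffer are choices made here so the sketch runs without a display:

```python
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend: no display needed
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import export_graphviz, export_text, plot_tree

iris = load_iris()
rf = RandomForestClassifier(n_estimators=10, random_state=0).fit(iris.data, iris.target)
tree = rf.estimators_[0]  # a single fitted DecisionTreeClassifier

# 1) Plain-text representation of the tree
text_repr = export_text(tree, feature_names=list(iris.feature_names))
print(text_repr)

# 2) Graphviz DOT source; out_file=None returns it as a string
dot_source = export_graphviz(
    tree,
    out_file=None,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
    filled=True,
)

# 3) Matplotlib plot, written to an in-memory PNG buffer
fig, ax = plt.subplots(figsize=(12, 8))
plot_tree(
    tree,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
    filled=True,
    ax=ax,
)
png_buffer = io.BytesIO()
fig.savefig(png_buffer, format="png")
plt.close(fig)
```

    To write the image to disk instead, pass a file name (for example "tree.png") to fig.savefig in place of the buffer.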

    The code, with example output, is described in this post.

    The important thing to keep in mind when plotting a single decision tree from a random forest is that it may be fully grown (with the default hyper-parameters, max_depth=None), which means the tree can be very deep. For me, a tree deeper than 6 levels is very hard to read, so when I need tree visualizations I build the random forest with max_depth < 7. You can check the example visualization in this post.
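    A sketch of the max_depth advice, assuming the breast-cancer dataset bundled with scikit-learn as a somewhat larger example (the dataset, n_estimators=20 and random_state=0 are illustrative choices). It compares the depths of the trees in a forest with the default max_depth=None against a forest capped at max_depth=6, using DecisionTreeClassifier.get_depth():

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier

# Example data (an assumption)
X, y = load_breast_cancer(return_X_y=True)

# Default max_depth=None: nodes are expanded until leaves are pure
# or contain fewer than min_samples_split samples
deep_rf = RandomForestClassifier(n_estimators=20, random_state=0).fit(X, y)

# Capped depth, following the "max_depth < 7" rule of thumb
shallow_rf = RandomForestClassifier(
    n_estimators=20, max_depth=6, random_state=0
).fit(X, y)

deep_depths = [t.get_depth() for t in deep_rf.estimators_]
shallow_depths = [t.get_depth() for t in shallow_rf.estimators_]
print("deepest tree, default max_depth:", max(deep_depths))
print("deepest tree, max_depth=6:", max(shallow_depths))
```

    If you only need a readable picture rather than a smaller model, plot_tree, export_graphviz and export_text also accept a max_depth argument that truncates only the rendered output, leaving the trained forest unchanged.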
