Plot trees for a Random Forest in Python with Scikit-Learn

遇见更好的自我 2020-12-08 08:34

I want to plot a decision tree of a random forest, so I wrote the following code:

from sklearn.ensemble import RandomForestClassifier

clf = RandomForestClassifier(n_estimators=100)
import pydotplus
import s

5 Answers
  •  野趣味 (OP)
    2020-12-08 08:59

    Assuming your Random Forest model is already fitted, you should first import the export_graphviz function:

    from sklearn.tree import export_graphviz
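    To make "already fitted" concrete, here is a minimal sketch that fits a small forest on the iris data set (a stand-in for your own X and y) and passes one of its trees to export_graphviz. The fitted trees are stored in clf.estimators_, and with out_file=None the function returns the DOT source as a string instead of writing a file. The data set, n_estimators=10 and random_state=0 are illustrative assumptions, not part of the original answer.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import export_graphviz

# Example data (assumption): iris stands in for your own X, y.
iris = load_iris()
clf = RandomForestClassifier(n_estimators=10, random_state=0)
clf.fit(iris.data, iris.target)

# Each fitted tree in the forest is a DecisionTreeClassifier in clf.estimators_.
first_tree = clf.estimators_[0]

# With out_file=None, export_graphviz returns the DOT source as a string.
dot_source = export_graphviz(
    first_tree,
    out_file=None,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
    filled=True,
    rounded=True,
)
print(dot_source.splitlines()[0])  # prints "digraph Tree {"
```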
    

    Inside your for loop over the fitted trees (clf.estimators_), you could do the following to generate the dot file:

    export_graphviz(tree_in_forest,
                    out_file='tree.dot',  # write the DOT source to this file
                    feature_names=X.columns,
                    filled=True,
                    rounded=True)
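    A sketch of the complete loop, assuming iris as example data and a hypothetical forest_trees output directory. Writing every tree to the same tree.dot means each iteration overwrites the previous one, so this version gives each tree its own file name. Note that feature_names=X.columns in the answer assumes X is a pandas DataFrame; the sketch uses the plain list iris.feature_names instead.

```python
import os

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import export_graphviz

# Example data (assumption): a small forest so only a few files are written.
iris = load_iris()
clf = RandomForestClassifier(n_estimators=3, random_state=0)
clf.fit(iris.data, iris.target)

# Hypothetical output directory; change it to suit your project.
out_dir = "forest_trees"
os.makedirs(out_dir, exist_ok=True)

dot_paths = []
for i, tree_in_forest in enumerate(clf.estimators_):
    # One file per tree, so later iterations do not overwrite earlier ones.
    dot_path = os.path.join(out_dir, f"tree_{i}.dot")
    export_graphviz(
        tree_in_forest,
        out_file=dot_path,
        feature_names=iris.feature_names,
        filled=True,
        rounded=True,
    )
    dot_paths.append(dot_path)

print(dot_paths)
```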
    

    The next line converts the dot file into a PNG image (this needs the Graphviz dot program to be installed):

    import os
    os.system('dot -Tpng tree.dot -o tree.png')
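    os.system only returns an exit status, and the line above does not check it, so a missing or broken Graphviz install is easy to overlook. One alternative is subprocess.run with check=True, sketched below with a hypothetical helper dot_to_png and a throwaway example.dot file. It assumes the Graphviz dot executable is on your PATH and simply returns False when it is not.

```python
import shutil
import subprocess


def dot_to_png(dot_path, png_path):
    """Render a DOT file to PNG with Graphviz's `dot` command (hypothetical helper).

    Returns False without doing anything when `dot` is not on PATH.
    """
    if shutil.which("dot") is None:
        return False
    # An argument list avoids the shell; check=True raises
    # subprocess.CalledProcessError if dot exits with an error.
    subprocess.run(["dot", "-Tpng", dot_path, "-o", png_path], check=True)
    return True


# Usage with a tiny hand-written DOT file (hypothetical file names).
with open("example.dot", "w", encoding="utf-8") as f:
    f.write("digraph G { a -> b }\n")

if dot_to_png("example.dot", "example.png"):
    print("wrote example.png")
else:
    print("Graphviz dot not found; install Graphviz first")
```

    Inside the per-tree loop you would call the helper once per .dot file, for example with the same index in the PNG name.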
    
