Plot trees for a Random Forest in Python with Scikit-Learn

遇见更好的自我 2020-12-08 08:34

I want to plot a decision tree from a random forest, so I wrote the following code:

clf = RandomForestClassifier(n_estimators=100)
import pydotplus
import s         


        
5 Answers
  •  不知归路
    2020-12-08 08:44

    After you fit a random forest model in scikit-learn, you can visualize individual decision trees from a random forest. The code below first fits a random forest model.

    import matplotlib.pyplot as plt
    from sklearn.datasets import load_breast_cancer
    from sklearn import tree
    import pandas as pd
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import train_test_split
    
    # Load the Breast Cancer Dataset
    data = load_breast_cancer()
    df = pd.DataFrame(data.data, columns=data.feature_names)
    df['target'] = data.target
    
    # Arrange Data into Features Matrix and Target Vector
    X = df.loc[:, df.columns != 'target']
    y = df.loc[:, 'target'].values
    
    # Split the data into training and testing sets
    X_train, X_test, Y_train, Y_test = train_test_split(X, y, random_state=0)
    
    # Random Forests in `scikit-learn` (with N = 100)
    rf = RandomForestClassifier(n_estimators=100,
                                random_state=0)
    rf.fit(X_train, Y_train)
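The trees only exist after fitting. `fit` fills `rf.estimators_` with one fitted `DecisionTreeClassifier` per estimator, and an unfitted forest has no `estimators_` attribute at all, so there is nothing to plot before `fit` is called. A minimal sketch to inspect this, assuming the same breast cancer data (the variable names are illustrative):

```python
# Sketch: check what a random forest stores after fitting.
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, random_state=0
)

# Before fit() no trees have been built, so estimators_ does not exist yet.
unfitted = RandomForestClassifier(n_estimators=100, random_state=0)
print(hasattr(unfitted, "estimators_"))  # False

rf = RandomForestClassifier(n_estimators=100, random_state=0)
rf.fit(X_train, y_train)

# After fit(), estimators_ is a list holding one fitted tree per estimator.
print(len(rf.estimators_))  # 100
print(type(rf.estimators_[0]).__name__)  # DecisionTreeClassifier

# Each entry is an ordinary decision tree with its own depth and leaf count.
first = rf.estimators_[0]
print(first.get_depth(), first.get_n_leaves())
```

Any element of `rf.estimators_` can therefore be passed to the tree plotting functions exactly like a standalone decision tree.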
    

    You can now visualize individual trees. The code below visualizes the first decision tree.

    fn = data.feature_names
    cn = data.target_names
    fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(4, 4), dpi=800)
    tree.plot_tree(rf.estimators_[0],
                   feature_names=fn,
                   class_names=cn,
                   filled=True,
                   ax=axes)
    fig.savefig('rf_individualtree.png')
    

    The resulting image is saved as rf_individualtree.png.
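Since the question imports pydotplus, the Graphviz route may also be of interest. `sklearn.tree.export_graphviz` accepts any single tree from `rf.estimators_`, and with `out_file=None` it returns the DOT source as a string. A sketch, assuming pydotplus and the Graphviz `dot` executable are installed for the final rendering step; that step is commented out so the block runs without them, the forest is fit on the full dataset for brevity, and the file name rf_tree_graphviz.png is illustrative:

```python
from sklearn import tree
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier

data = load_breast_cancer()
rf = RandomForestClassifier(n_estimators=100, random_state=0)
rf.fit(data.data, data.target)

# Export the first tree of the forest as DOT source (returned as a string).
dot_data = tree.export_graphviz(
    rf.estimators_[0],
    out_file=None,  # return the DOT text instead of writing a .dot file
    feature_names=data.feature_names,
    class_names=data.target_names,
    filled=True,
    rounded=True,
)
print(dot_data.splitlines()[0])  # digraph Tree {

# Optional rendering step (requires pydotplus and Graphviz's `dot` on PATH):
# import pydotplus
# graph = pydotplus.graph_from_dot_data(dot_data)
# graph.write_png("rf_tree_graphviz.png")
```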

    Since the question asks about plotting trees (plural), you can also visualize several, or all, of the estimators (decision trees) in the forest. The code below visualizes the first 5 trees of the random forest fitted above.

    # This may not be the best way to view each estimator, as each subplot is small
    fn = data.feature_names
    cn = data.target_names
    fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(10, 2), dpi=900)
    for index in range(0, 5):
        tree.plot_tree(rf.estimators_[index],
                       feature_names=fn,
                       class_names=cn,
                       filled=True,
                       ax=axes[index])
        axes[index].set_title('Estimator: ' + str(index), fontsize=11)
    fig.savefig('rf_5trees.png')
    

    The resulting image is saved as rf_5trees.png.
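Five full trees in one row are hard to read. One option is to draw only the top levels of each tree with the `max_depth` argument of `plot_tree` (deeper branches are collapsed into placeholders) and to give the figure more room. A sketch; the depth of 2, the figure size, the font size and the output file name are arbitrary choices for illustration:

```python
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # render to files only, no GUI window
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier

data = load_breast_cancer()
rf = RandomForestClassifier(n_estimators=100, random_state=0)
rf.fit(data.data, data.target)

fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(20, 4), dpi=150)
for index, ax in enumerate(axes):
    tree.plot_tree(
        rf.estimators_[index],
        feature_names=data.feature_names,
        class_names=data.target_names,
        filled=True,
        max_depth=2,  # draw only the root and the next two levels
        fontsize=6,
        ax=ax,
    )
    ax.set_title(f"Estimator: {index}", fontsize=11)

out = Path("rf_5trees_top_levels.png")  # illustrative file name
fig.savefig(out)
plt.close(fig)
```

The trade-off is that only the first few splits of each tree are visible; the full trees are still available in `rf.estimators_` if you need them.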

    The code was adapted from this post.
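To look at every estimator rather than only the first five, one option is to give each tree its own figure and file, closing each figure after saving so memory stays bounded (matplotlib warns once more than 20 figures are open). A sketch, assuming a hypothetical output folder named rf_trees; a 10-tree forest keeps it quick, and the loop is unchanged for n_estimators=100:

```python
import os

import matplotlib
matplotlib.use("Agg")  # non-interactive backend: write files, open no windows
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier

data = load_breast_cancer()
# A small forest keeps the example fast; the loop works the same for 100 trees.
rf = RandomForestClassifier(n_estimators=10, random_state=0)
rf.fit(data.data, data.target)

out_dir = "rf_trees"  # hypothetical output folder
os.makedirs(out_dir, exist_ok=True)

saved = []
for index, estimator in enumerate(rf.estimators_):
    fig, ax = plt.subplots(figsize=(12, 8), dpi=100)
    tree.plot_tree(
        estimator,
        feature_names=data.feature_names,
        class_names=data.target_names,
        filled=True,
        ax=ax,
    )
    ax.set_title(f"Estimator: {index}")
    path = os.path.join(out_dir, f"tree_{index:03d}.png")
    fig.savefig(path)
    plt.close(fig)  # release the figure before drawing the next tree
    saved.append(path)

print(len(saved))  # 10
```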
