How to visualize an XGBoost tree from GridSearchCV output?

Submitted by 大憨熊 on 2020-06-13 05:56:06

Question


I am using XGBRegressor with GridSearchCV to fit the model, and I want to visualize the trees.

Here is the link I followed (in case this is a duplicate): how to plot a decision tree from gridsearchcv?

from xgboost import XGBRegressor
from sklearn.model_selection import GridSearchCV

xgb = XGBRegressor(learning_rate=0.02, n_estimators=600, silent=True, nthread=1)
folds = 5
grid = GridSearchCV(estimator=xgb, param_grid=params, scoring='neg_mean_squared_error', cv=folds, n_jobs=4, verbose=3)
model = grid.fit(X_train, y_train)

Approach 1:

from sklearn import tree

dot_data = tree.export_graphviz(model.best_estimator_, out_file=None,
        filled=True, rounded=True, feature_names=X_train.columns)
dot_data

Error: NotFittedError: This XGBRegressor instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.

Approach 2:

import subprocess
tree.export_graphviz(best_clf, out_file='tree.dot', feature_names=X_train.columns, leaves_parallel=True)
subprocess.call(['dot', '-Tpdf', 'tree.dot', '-o', 'tree.pdf'])

Same error.


Answer 1:


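scikit-learn's tree.export_graphviz will not work here, because your best_estimator_ is not a single tree but a whole ensemble of boosted trees. export_graphviz only accepts scikit-learn decision tree estimators, so it rejects the XGBRegressor, which is in fact fitted, with a misleading NotFittedError.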

Here is how you can do it using XGBoost's own plot_tree and a regression dataset bundled with scikit-learn:

from xgboost import XGBRegressor, plot_tree
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import load_diabetes  # load_boston was removed in scikit-learn 1.2
import matplotlib.pyplot as plt

X, y = load_diabetes(return_X_y=True)

params = {'learning_rate':[0.1, 0.5], 'n_estimators':[5, 10]} # dummy, for demonstration only

xgb = XGBRegressor(learning_rate=0.02, n_estimators=600,silent=True, nthread=1)
grid = GridSearchCV(estimator=xgb, param_grid=params, scoring='neg_mean_squared_error', n_jobs=4)

grid.fit(X, y)

Our best estimator is:

grid.best_estimator_
# result (details may differ with your data, XGBoost version, and randomness):
XGBRegressor(base_score=0.5, booster='gbtree', colsample_bylevel=1,
             colsample_bynode=1, colsample_bytree=1, gamma=0,
             importance_type='gain', learning_rate=0.5, max_delta_step=0,
             max_depth=3, min_child_weight=1, missing=None, n_estimators=10,
             n_jobs=1, nthread=1, objective='reg:linear', random_state=0,
             reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,
             silent=True, subsample=1, verbosity=1)

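Having done that, and building on the answer from a related SO thread, we can plot, say, the tree at index 4 (num_trees is a zero-based index, so this is the fifth tree). Note that plot_tree needs both the graphviz Python package and the Graphviz binaries installed: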

fig, ax = plt.subplots(figsize=(30, 30))
plot_tree(grid.best_estimator_, num_trees=4, ax=ax)
plt.show()

Similarly, for the tree at index 1:

fig, ax = plt.subplots(figsize=(30, 30))
plot_tree(grid.best_estimator_, num_trees=1, ax=ax)
plt.show()



Source: https://stackoverflow.com/questions/62176516/how-to-visualize-an-xgboost-tree-from-gridsearchcv-output
