Where does scikit-learn hold the decision labels of each leaf node in its tree structure?

Posted by 夙愿已清 on 2019-12-23 07:26:27

Question


I have trained a random forest model using scikit-learn, and now I want to save its tree structures in a text file so I can use them elsewhere. According to this link, a tree object consists of a number of parallel arrays, each of which holds some information about the nodes of the tree (e.g. left child, right child, which feature it examines, ...). However, there seems to be no information about the class label corresponding to each leaf node! It's not even mentioned in the examples provided in the link above.

Does anyone know where the class labels are stored in the scikit-learn decision tree structure?


Answer 1:


Take a look at the docs for sklearn.tree.DecisionTreeClassifier.tree_.value:

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier(random_state=0)
iris = load_iris()

clf.fit(iris.data, iris.target)

print(clf.classes_)

[0 1 2]

print(clf.tree_.value)

[[[ 50.  50.  50.]]

 [[ 50.   0.   0.]]

 [[  0.  50.  50.]]

 [[  0.  49.   5.]]

 [[  0.  47.   1.]]

 [[  0.  47.   0.]]

 [[  0.   0.   1.]]

 [[  0.   2.   4.]]

 [[  0.   0.   3.]]

 [[  0.   2.   1.]]

 [[  0.   2.   0.]]

 [[  0.   0.   1.]]

 [[  0.   1.  45.]]

 [[  0.   1.   2.]]

 [[  0.   0.   2.]]

 [[  0.   1.   0.]]

 [[  0.   0.  43.]]]

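Each row in clf.tree_.value "contains the constant prediction value of each node" (see help(clf.tree_)). The columns of a row correspond index-to-index to clf.classes_, so the class label of a leaf is the entry of clf.classes_ at the position of the largest value in that leaf's row. For example, the last row above, [[ 0. 0. 43.]], belongs to a leaf whose label is clf.classes_[2], i.e. 2.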
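As a rough sketch of how the label of every leaf could be read out of these arrays: leaves are the nodes whose entry in clf.tree_.children_left is -1, and the label is found by taking the argmax of the node's row in clf.tree_.value. The names is_leaf, node_labels and leaf_labels are made up for this illustration, and the code assumes a single-output classifier (hence the index 0 on the second axis).

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)
tree = clf.tree_

# A node is a leaf when it has no children; scikit-learn marks that with -1.
is_leaf = tree.children_left == -1

# tree.value has shape (node_count, n_outputs, max_n_classes).
# Assumption: single-output problem, so only output 0 is inspected.
# argmax gives the column of the dominant class, which indexes clf.classes_.
node_labels = clf.classes_[np.argmax(tree.value[:, 0, :], axis=1)]

# Map leaf node id -> class label (hypothetical helper structure).
leaf_labels = {int(i): node_labels[i] for i in np.flatnonzero(is_leaf)}
for node_id, label in leaf_labels.items():
    print(node_id, label)

# Sanity check: the label of the leaf each sample lands in
# should agree with what clf.predict returns.
leaf_ids = clf.apply(iris.data)
assert np.array_equal(node_labels[leaf_ids], clf.predict(iris.data))
```

Using argmax rather than reading the raw numbers means the lookup does not depend on whether the rows hold sample counts or class fractions.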

See this answer for (barely) more details.
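For the original goal of writing a random forest to a text file, one possible approach is to collect the parallel arrays of every tree, plus a per-node label, into JSON. This is only a sketch: tree_to_dict and the dictionary keys are invented for the example, and it assumes a single-output classifier. Note that the labels are looked up in forest.classes_, since the individual trees inside a forest are fitted on encoded class indices rather than on the original labels.

```python
import json
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()
forest = RandomForestClassifier(n_estimators=3, random_state=0)
forest.fit(iris.data, iris.target)


def tree_to_dict(tree, classes):
    """Collect the parallel arrays of one fitted tree into plain lists.

    Hypothetical helper; `classes` is the label array used to decode
    the argmax of each node's row in tree.value.
    """
    majority = np.argmax(tree.value[:, 0, :], axis=1)  # single output assumed
    return {
        "children_left": tree.children_left.tolist(),    # -1 marks a leaf
        "children_right": tree.children_right.tolist(),  # -1 marks a leaf
        "feature": tree.feature.tolist(),                # -2 at leaves
        "threshold": tree.threshold.tolist(),            # -2.0 at leaves
        "label": classes[majority].tolist(),             # majority class per node
    }


dump = [tree_to_dict(est.tree_, forest.classes_) for est in forest.estimators_]

# The resulting string can be written to any text file.
text = json.dumps(dump)
print(len(text), "characters for", len(dump), "trees")
```

The "label" entry is only meaningful as a prediction at leaf nodes; for internal nodes it is just the majority class of the samples that pass through.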



Source: https://stackoverflow.com/questions/44158993/where-does-scikit-learn-hold-the-decision-labels-of-each-leaf-node-in-its-tree-s
