What does scikit-learn DecisionTreeClassifier.tree_.value do?

耗尽温柔 submitted on 2019-12-02 13:28:40

Question


I am working on a DecisionTreeClassifier model and I want to understand the path chosen by the model. So I need to know what the values in the following attribute represent:

DecisionTreeClassifier.tree_.value

Thank you,


Answer 1:


Well, you are right that the documentation is rather obscure on this (though, to be honest, I am not sure how useful the attribute is, either).

Let's replicate the example from the documentation with the iris data:

from sklearn.datasets import load_iris
from sklearn import tree
iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)

Asking for clf.tree_.value, we get:

array([[[ 50.,  50.,  50.]],
       [[ 50.,   0.,   0.]],
       [[  0.,  50.,  50.]],
       [[  0.,  49.,   5.]],
       [[  0.,  47.,   1.]],
       [[  0.,  47.,   0.]],
       [[  0.,   0.,   1.]],
       [[  0.,   2.,   4.]],
       [[  0.,   0.,   3.]],
       [[  0.,   2.,   1.]],
       [[  0.,   2.,   0.]],
       [[  0.,   0.,   1.]],
       [[  0.,   1.,  45.]],
       [[  0.,   1.,   2.]],
       [[  0.,   1.,   0.]],
       [[  0.,   0.,   2.]],
       [[  0.,   0.,  43.]]])

and

len(clf.tree_.value)
# 17
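The double brackets in that output hint that the array is three-dimensional. A minimal check, refitting the same iris model so the snippet runs on its own; random_state=0 is an assumption added for reproducibility, and the exact node count may differ slightly from 17 depending on the seed and scikit-learn version:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
# random_state=0 is an assumption, only to make the fit reproducible
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

value = clf.tree_.value
# Shape is (number of nodes, number of outputs, number of classes)
print(value.shape)
# len(value) is the same as the tree's node count
print(len(value), clf.tree_.node_count)
# One output (the species) and three classes for iris
print(clf.tree_.n_outputs, clf.n_classes_)
```

The middle axis has size 1 because there is a single target; that is where the inner pair of brackets comes from.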

To understand what exactly this array represents, it is useful to look at the visualization of this same tree in the documentation:

As we can see, the tree has 17 nodes; looking closer, the value shown in each node is exactly one element of our clf.tree_.value array.
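Without the picture, the same correspondence can be seen by walking the tree yourself. This is a sketch, not the code behind the documentation's figure (that figure comes from sklearn.tree.plot_tree): it uses the tree_ arrays children_left, children_right, feature and threshold, and relies on the default depth-first builder numbering nodes in pre-order (left subtree first), so the walk visits node ids 0, 1, 2, ... in the same order as the rows of tree_.value. The helper name describe is made up for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)
t = clf.tree_

def describe(node=0, depth=0, lines=None):
    """Hypothetical helper: one text line per node, visited in pre-order."""
    if lines is None:
        lines = []
    indent = "  " * depth
    left, right = t.children_left[node], t.children_right[node]
    row = t.value[node][0].tolist()  # [0] selects the single output
    if left == right:  # a leaf: both child ids are -1
        lines.append(f"{indent}node {node}: leaf, value={row}")
    else:
        name = iris.feature_names[t.feature[node]]
        lines.append(f"{indent}node {node}: {name} <= {t.threshold[node]:.2f}, value={row}")
        describe(left, depth + 1, lines)   # samples satisfying the test go left
        describe(right, depth + 1, lines)  # the rest go right
    return lines

lines = describe()
print("\n".join(lines))
```

Each printed value matches the row of clf.tree_.value with the same index, which is the same correspondence visible in the drawn tree.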

So, to make a long story short:

  • clf.tree_.value is an array of arrays (strictly, a 3-D NumPy array of shape (n_nodes, n_outputs, n_classes)), of length equal to the number of nodes in the tree
  • Each of its elements (which corresponds to a tree node) holds one row per output; with a single output, as here, that gives the double brackets [[ ... ]], and the row has length equal to the number of classes (here 3)
  • Each of these 3-element rows gives the number of training samples from each class that end up in the respective node.
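A hedged sketch of turning these rows into per-class counts and probabilities. One caveat I am not fully sure of: recent scikit-learn releases (1.4 and later, if I read the changelog correctly) store weighted class fractions in tree_.value instead of raw counts, so the code normalizes each row first and then rescales by tree_.weighted_n_node_samples, which should give counts under either convention. It also checks that predict_proba is just the normalized value row of the leaf a sample lands in (found with clf.apply).

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)
t = clf.tree_

# Drop the single-output axis: one row per node, one column per class
per_node = t.value[:, 0, :]

# Normalize each row, so this works whether value holds counts or fractions
fractions = per_node / per_node.sum(axis=1, keepdims=True)

# Rescale by the (weighted) number of training samples reaching each node
counts = fractions * t.weighted_n_node_samples[:, np.newaxis]
print(np.round(counts[0]))  # root node: all 150 flowers, split by class

# predict_proba is the normalized value row of the leaf each sample lands in
leaves = clf.apply(iris.data)
print(np.allclose(clf.predict_proba(iris.data), fractions[leaves]))  # True
```

With no sample_weight passed to fit, the weighted and unweighted node sizes coincide, so the counts are plain sample counts.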

To clarify the last point with an example, consider the second element of the array, [[ 50., 0., 0.]] (which corresponds to the orange-colored node): it says that 50 samples of class #0 end up in this node, and zero samples from each of the other two classes (#1 and #2).
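Since the original goal was to follow the path the model chooses, here is a sketch that combines tree_.value with decision_path, which returns a sparse indicator matrix of the nodes each sample visits. Picking the first flower (a class #0 sample) is an arbitrary choice; for it the path should be the root followed by the pure [50, 0, 0] node discussed above.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)
t = clf.tree_

sample = iris.data[:1]             # first flower (class 0); slicing keeps it 2-D
path = clf.decision_path(sample)   # sparse indicator matrix, shape (1, node_count)
node_ids = path.indices            # ids of the nodes this one sample visits

for node in node_ids:
    print(node, t.value[node][0])  # class mix at each node on the way down

# The predicted class is the majority class in the leaf's value row
leaf = clf.apply(sample)[0]
predicted = clf.classes_[t.value[leaf][0].argmax()]
print(predicted, clf.predict(sample)[0])
```

Reading the value rows along the path shows how each split narrows the class mix until a leaf decides the prediction.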

Hope this helps...



Source: https://stackoverflow.com/questions/47719001/what-does-scikit-learn-decisiontreeclassifier-tree-value-do
