How to explain the decision tree from scikit-learn

The value line in each box tells you how many samples at that node fall into each category, in order. That's why, in each box, the numbers in value add up to the number shown in samples. For instance, in your red box, 91 + 212 + 113 = 416. So if you reach this node, there were 91 data points in category 1, 212 in category 2, and 113 in category 3.

If you were going to predict the outcome for a new data point that reached that leaf in the decision tree, you would predict category 2, because that is the most common category for samples at that node.
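To see those per-class counts programmatically, you can read them straight off the fitted tree. Below is a minimal sketch (fitting a fresh classifier on iris, mirroring the code further down; node_id = 0 is just the root, chosen for illustration). Depending on your scikit-learn version, tree_.value holds raw weighted counts or per-node fractions, but the argmax picks the majority class either way.

import numpy as np
from sklearn.datasets import load_iris
from sklearn import tree

iris = load_iris()
clf2 = tree.DecisionTreeClassifier().fit(iris.data, iris.target)

node_id = 0                                  # the root node
values = clf2.tree_.value[node_id][0]        # one entry per class, in order
print(values)                                # counts (older releases) or fractions (newer ones)
print(iris.target_names[np.argmax(values)])  # the majority class, i.e. the prediction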

First question: Yes, your logic is correct. The left node is True and the right node is False. This can feel counter-intuitive: the True branch may well hold the smaller share of the samples.
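You can check that convention on the tree's internal arrays: at node i the test is X[feature[i]] <= threshold[i], and the True branch is children_left[i]. A small sketch, reusing clf2 from the snippet above:

t = clf2.tree_
i = 0  # the root node
print(f"if X[{t.feature[i]}] <= {t.threshold[i]:.2f}: "
      f"go left to node {t.children_left[i]} (True)")
print(f"else: go right to node {t.children_right[i]} (False)")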

Second question: This problem is best solved by visualizing the tree as a graph with pydotplus. The class_names parameter of tree.export_graphviz() adds a declaration of the majority class to each node. The code below is meant to run in an IPython (Jupyter) notebook.

from sklearn.datasets import load_iris
from sklearn import tree

iris = load_iris()
clf2 = tree.DecisionTreeClassifier()
clf2 = clf2.fit(iris.data, iris.target)

with open("iris.dot", 'w') as f:
    tree.export_graphviz(clf2, out_file=f)  # note: clf2, the fitted model

import os
os.unlink('iris.dot')  # remove the .dot file once it is no longer needed

import pydotplus
dot_data = tree.export_graphviz(clf2, out_file=None)
graph2 = pydotplus.graph_from_dot_data(dot_data)
graph2.write_pdf("iris.pdf")

from IPython.display import Image
dot_data = tree.export_graphviz(clf2, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,  # leaves_parallel=True,
                                special_characters=True)
graph2 = pydotplus.graph_from_dot_data(dot_data)

## Color the nodes according to their majority class
colors = {0: [255, 255, 224], 1: [255, 224, 255], 2: [224, 255, 255]}

nodes = graph2.get_node_list()
for node in nodes:
    if node.get_label():
        # Parse the class counts out of the node label, e.g. 'value = [50, 0, 0]'
        values = [int(v) for v in
                  node.get_label().split('value = [')[1].split(']')[0].split(',')]
        rgb = colors[values.index(max(values))]  # RGB for the majority class
        node.set_fillcolor('#{:02x}{:02x}{:02x}'.format(*rgb))

Image(graph2.create_png())

As for determining the class at the leaf: unlike the iris data set, your example doesn't have leaves containing a single class. This is common, and forcing pure leaves usually means over-fitting the model. For many cross-validated models, a distribution of classes at each leaf is the best result you can expect.
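That leaf distribution is exactly what predict_proba reports: for a sample landing in an impure leaf, the predicted probabilities are the class fractions at that leaf. A small sketch, deliberately capping max_depth so the leaves stay mixed (the cap is only for illustration):

clf_shallow = tree.DecisionTreeClassifier(max_depth=2).fit(iris.data, iris.target)
sample = iris.data[:1]
print(clf_shallow.predict_proba(sample))  # class fractions at the sample's leaf
print(clf_shallow.predict(sample))        # the argmax of that distribution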

Enjoy the code!

According to the book "Learning scikit-learn: Machine Learning in Python", the decision tree represents a series of decisions based on the training data.

![decision tree from the book](http://i.imgur.com/vM9fJLy.png)

To classify an instance, we should answer the question at each node. For example: is sex <= 0.5 (i.e., are we talking about a woman)? If the answer is yes, you go to the left child node in the tree; otherwise you go to the right child node. You keep answering questions (was she in the third class? was she in the first class? was she below 13 years old?) until you reach a leaf. Once there, the prediction corresponds to the target class with the most instances.
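That walk from root to leaf is easy to reproduce in code. Here is a minimal sketch of a manual traversal, run on the fitted iris tree clf2 from above since the Titanic data isn't bundled with scikit-learn; the logic is the same for the tree in the figure:

import numpy as np

def classify(clf, x):
    # Follow one sample from the root to a leaf, answering each node's question.
    t = clf.tree_
    node = 0
    while t.children_left[node] != -1:      # -1 marks a leaf
        if x[t.feature[node]] <= t.threshold[node]:
            node = t.children_left[node]    # answer "yes": go left
        else:
            node = t.children_right[node]   # answer "no": go right
    return np.argmax(t.value[node][0])      # majority class at the leaf

x = iris.data[0]
print(iris.target_names[classify(clf2, x)])  # agrees with clf2.predict([x])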
