How to explore a decision tree built using scikit learn

Submitted by 偶尔善良 on 2019-12-20 09:19:57

Question


I am building a decision tree using

clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, Y_train)

This all works fine. However, how do I then explore the decision tree?

For example, how do I find which entries from X_train appear in a particular leaf?


Answer 1:


You need to use the predict method.

After training the tree, you feed the X values to predict their output.

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state=0)
iris = load_iris()
tree = clf.fit(iris.data, iris.target)
tree.predict(iris.data) 

output:

>>> tree.predict(iris.data)
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

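To get details on the tree structure, we can use tree.tree_.__getstate__() (a pickling helper, so the exact fields it returns can vary between scikit-learn versions).

The tree structure translated into an "ASCII art" picture: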

              0  
        _____________
        1           2
               ______________
               3            12
            _______      _______
            4     7      13   16
           ___   ______        _____
           5 6   8    9        14 15
                      _____
                      10 11

The tree structure as an array:

In [38]: tree.tree_.__getstate__()['nodes']
Out[38]: 
array([(1, 2, 3, 0.800000011920929, 0.6666666666666667, 150, 150.0),
       (-1, -1, -2, -2.0, 0.0, 50, 50.0),
       (3, 12, 3, 1.75, 0.5, 100, 100.0),
       (4, 7, 2, 4.949999809265137, 0.16803840877914955, 54, 54.0),
       (5, 6, 3, 1.6500000953674316, 0.04079861111111116, 48, 48.0),
       (-1, -1, -2, -2.0, 0.0, 47, 47.0), 
       (-1, -1, -2, -2.0, 0.0, 1, 1.0),
       (8, 9, 3, 1.5499999523162842, 0.4444444444444444, 6, 6.0),
       (-1, -1, -2, -2.0, 0.0, 3, 3.0),
       (10, 11, 2, 5.449999809265137, 0.4444444444444444, 3, 3.0),
       (-1, -1, -2, -2.0, 0.0, 2, 2.0), 
       (-1, -1, -2, -2.0, 0.0, 1, 1.0),
       (13, 16, 2, 4.850000381469727, 0.042533081285444196, 46, 46.0),
       (14, 15, 1, 3.0999999046325684, 0.4444444444444444, 3, 3.0),
       (-1, -1, -2, -2.0, 0.0, 2, 2.0), 
       (-1, -1, -2, -2.0, 0.0, 1, 1.0),
       (-1, -1, -2, -2.0, 0.0, 43, 43.0)], 
      dtype=[('left_child', '<i8'), ('right_child', '<i8'), 
             ('feature', '<i8'), ('threshold', '<f8'), 
             ('impurity', '<f8'), ('n_node_samples', '<i8'), 
             ('weighted_n_node_samples', '<f8')])

Where:

  • The first node [0] is the root node.
  • Internal nodes have left_child and right_child set to positive node indices that are greater than the current node's index.
  • Leaves have -1 for both left_child and right_child.
  • Nodes 1, 5, 6, 8, 10, 11, 14, 15 and 16 are leaves.
  • Nodes are numbered in depth-first (pre-order) order.
  • The feature field gives the index of the iris.data feature used to split at that node (-2 at leaves, where there is no split).
  • The threshold field gives the split value: samples whose feature value is <= threshold go to the left child, the others go to the right child.
  • Impurity reaches 0 at the leaves here, because this unpruned tree keeps splitting until all samples in a leaf belong to the same class.
  • n_node_samples gives the number of training samples that reach each node.
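The same per-node fields are also exposed as public NumPy arrays on clf.tree_ (children_left, children_right, feature, threshold, impurity, n_node_samples), which avoids relying on __getstate__(). A minimal sketch, assuming the same iris setup with random_state=0, that walks the node ids and prints each split or leaf:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)
t = clf.tree_  # low-level Tree object; each array below is indexed by node id

leaves = []
for node in range(t.node_count):
    left, right = int(t.children_left[node]), int(t.children_right[node])
    samples = int(t.n_node_samples[node])
    if left == -1:  # a leaf: no children on either side
        leaves.append(node)
        print(f"node {node:2d}: leaf, samples={samples}, "
              f"impurity={float(t.impurity[node]):.3f}")
    else:  # an internal node: rows with X[:, feature] <= threshold go left
        print(f"node {node:2d}: X[:, {int(t.feature[node])}] <= "
              f"{float(t.threshold[node]):.3f} -> left node {left}, "
              f"right node {right} (samples={samples})")

print("leaves:", leaves)
```

Each internal node's n_node_samples equals the sum of its two children's, and child ids are always larger than the parent's, consistent with the depth-first numbering.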

Using this information, we could track each sample X to the leaf where it eventually lands by following the split rules and thresholds in a script. Additionally, n_node_samples would let us write unit tests checking that each node receives the expected number of samples. Then, using the output of tree.predict, we could map each leaf to its associated class.
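The tracking described above can be sketched as follows, assuming the iris data again; find_leaf and leaf_to_class are hypothetical names for illustration, not scikit-learn APIs. The data is cast to float32 because scikit-learn's tree code converts X to float32 before comparing it with the thresholds, so doing the same keeps the hand-traced comparisons consistent with the library's:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris.data.astype(np.float32)  # mirror scikit-learn's internal float32 features
clf = DecisionTreeClassifier(random_state=0).fit(X, iris.target)
t = clf.tree_


def find_leaf(tree, sample):
    """Hypothetical helper: follow the split rules from the root to a leaf."""
    node = 0
    while tree.children_left[node] != -1:  # -1 marks a leaf
        # a sample goes left when its feature value is <= the node's threshold
        if sample[tree.feature[node]] <= tree.threshold[node]:
            node = tree.children_left[node]
        else:
            node = tree.children_right[node]
    return int(node)


# Leaf reached by every training sample, traced by hand.
manual_leaves = np.array([find_leaf(t, x) for x in X])

# Unit-test style checks: agree with apply(), and match n_node_samples per leaf.
assert np.array_equal(manual_leaves, clf.apply(X))
for leaf in np.unique(manual_leaves):
    assert np.sum(manual_leaves == leaf) == t.n_node_samples[leaf]

# Map each leaf to the class that predict() assigns to the samples in it.
preds = clf.predict(X)
leaf_to_class = {int(leaf): int(preds[manual_leaves == leaf][0])
                 for leaf in np.unique(manual_leaves)}
print(leaf_to_class)
```

In practice clf.apply(X) returns the same leaf ids directly; the manual walk is only there to make the rule-following explicit.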




Answer 2:


NOTE: This is not an answer, only a hint on possible solutions.

I encountered a similar problem recently in my project. My goal is to extract the corresponding chain of decisions for some particular samples. I think your problem is a subset of mine, since you just need to record the last step in the decision chain.

So far, the only viable solution seems to be writing a custom predict method in Python that keeps track of the decisions along the way. The predict method provided by scikit-learn cannot do this out of the box (as far as I know), and, to make it worse, it wraps a compiled Cython/C implementation that is hard to customize.

Customization is fine for my problem, since I'm dealing with an unbalanced dataset and the samples I care about (the positive ones) are rare. So I can first filter them out using scikit-learn's predict and then get the decision chain with my custom code.

However, this may not work for you if you have a large dataset: parsing the tree and predicting in pure Python runs at Python speed and will not scale easily. You may have to fall back to customizing the C implementation.
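The filter-then-trace approach described above can be sketched in pure Python as follows, assuming the iris data and treating class 2 as a stand-in for the rare positive class; decision_chain is a hypothetical helper name, not a scikit-learn API. Newer scikit-learn releases (0.18 and later) also provide clf.decision_path(X), which returns the visited nodes from compiled code and can be used to cross-check or replace a sketch like this:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris.data.astype(np.float32)  # scikit-learn's trees compare features as float32
clf = DecisionTreeClassifier(random_state=0).fit(X, iris.target)


def decision_chain(tree, sample):
    """Hypothetical pure-Python predict that records each decision.

    Returns ([(node_id, feature, threshold, went_left), ...], leaf_id).
    """
    chain, node = [], 0
    while tree.children_left[node] != -1:  # stop at a leaf
        feat = int(tree.feature[node])
        thr = tree.threshold[node]  # keep float64 so the comparison matches
        went_left = bool(sample[feat] <= thr)
        chain.append((int(node), feat, float(thr), went_left))
        node = tree.children_left[node] if went_left else tree.children_right[node]
    return chain, int(node)


# Step 1: filter with the fast compiled predict (class 2 plays "rare positive").
rare = np.flatnonzero(clf.predict(X) == 2)
# Step 2: trace only those samples in Python.
chains = {int(i): decision_chain(clf.tree_, X[i]) for i in rare}

first = int(rare[0])
chain, leaf = chains[first]
for node, feat, thr, went_left in chain:
    op = "<=" if went_left else ">"
    print(f"node {node}: {iris.feature_names[feat]} {op} {thr:.3f}")
print("leaf:", leaf)
```

Only the filtered rows pay the Python-speed cost, which is what makes this workable for a rare class.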




Answer 3:


The code below prints and plots your (up to) ten most important features:

import numpy as np
import matplotlib.pyplot as plt

# clf is the fitted DecisionTreeClassifier from the question
importances = clf.feature_importances_
indices = np.argsort(importances)[::-1]  # feature indices, most important first
top = indices[:10]                        # at most ten features

# Print the feature ranking
print("Feature ranking:")
for rank, f in enumerate(top, start=1):
    print("%d. feature %d (%f)" % (rank, f, importances[f]))

# Plot the feature importances of the tree.
# A single tree has no spread across estimators, so no error bars are drawn.
plt.figure()
plt.title("Feature importances")
plt.bar(range(len(top)), importances[top], color="r", align="center")
plt.xticks(range(len(top)), top)
plt.xlim([-1, len(top)])
plt.show()

Taken from here and modified slightly to fit the DecisionTreeClassifier.

This doesn't let you explore the tree itself, but it does tell you which features the tree relies on most.




Answer 4:


This code does what you want. Here, n is the number of observations in X_train (which must be a NumPy array). At the end, the (n, number_of_leaves) boolean array leaf_observations holds, in each column, a mask for indexing into X_train to get the observations in that leaf. Each column of leaf_observations corresponds to an element of leaves, which holds the node IDs of the leaves.

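import numpy as np

# X_train (a NumPy array) and the fitted clf come from the question
n = X_train.shape[0]  # number of observations

# get the nodes which are leaves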
leaves = clf.tree_.children_left == -1
leaves = np.arange(0,clf.tree_.node_count)[leaves]

# loop through each leaf and figure out the data in it
leaf_observations = np.zeros((n,len(leaves)),dtype=bool)
# build a simpler tree as a nested list: [split feature, split threshold, left node, right node]
thistree = [clf.tree_.feature.tolist()]
thistree.append(clf.tree_.threshold.tolist())
thistree.append(clf.tree_.children_left.tolist())
thistree.append(clf.tree_.children_right.tolist())
# get the decision rules for each leaf node & apply them
for (ind,nod) in enumerate(leaves):
    # get the decision rules in numeric list form
    rules = []
    RevTraverseTree(thistree, nod, rules)
    # convert & apply to the data by sequentially &ing the rules
    thisnode = np.ones(n,dtype=bool)
    for rule in rules:
        if rule[1] == 1:
            thisnode = np.logical_and(thisnode,X_train[:,rule[0]] > rule[2])
        else:
            thisnode = np.logical_and(thisnode,X_train[:,rule[0]] <= rule[2])
    # get the observations that obey all the rules - they are the ones in this leaf node
    leaf_observations[:,ind] = thisnode

This needs the helper function below (define it before running the loop), which recursively walks from a given node up to the root to build the decision rules.

def RevTraverseTree(tree, node, rules):
    '''
    Traverse a scikit-learn decision tree from a node (presumably a leaf node)
    up to the root, building the decision rules. The rules should be
    passed in as an empty list, which will be modified in place. The result
    is a list of tuples: (feature, direction (left=-1, right=1), threshold).
    The "tree" is a nested list of simplified tree attributes:
    [split feature, split threshold, left node, right node]
    '''
    # now find the node as either a left or right child of something
    # first try to find it as a left node
    try:
        prevnode = tree[2].index(node)
        leftright = -1
    except ValueError:
        # failed, so find it as a right node - if this also causes an exception, something's really f'd up
        prevnode = tree[3].index(node)
        leftright = 1
    # now let's get the rule that caused prevnode to -> node
    rules.append((tree[0][prevnode],leftright,tree[1][prevnode]))
    # if we've not yet reached the top, go up the tree one more step
    if prevnode != 0:
        RevTraverseTree(tree, prevnode, rules)
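To check (or replace) the rule-based reconstruction above, a minimal sketch assuming the tree's apply() method and the iris data standing in for X_train: apply() returns the leaf id for every row, and broadcasting it against leaves yields the same kind of (n, number_of_leaves) boolean matrix.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train = iris.data  # iris stands in for the question's X_train
clf = DecisionTreeClassifier(random_state=0).fit(X_train, iris.target)

# Node ids of the leaves, as in the code above.
leaves = np.flatnonzero(clf.tree_.children_left == -1)

# apply() returns the leaf id reached by each row; comparing it against every
# leaf id gives an (n, number_of_leaves) boolean membership matrix.
leaf_ids = clf.apply(X_train)
leaf_observations = leaf_ids[:, None] == leaves[None, :]

# Rows of X_train that land in the first leaf.
print(X_train[leaf_observations[:, 0]])
```

This avoids the repeated list.index() lookups of the recursive helper, at the cost of hiding the actual decision rules.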



Answer 5:


I've slightly modified the code Dr. Drew posted.
Given a fitted decision tree and a data frame, the following code returns:

  • rules_list: a list of rules, one list of conditions per leaf
  • values_path: for each leaf, the per-class values (from tree_.value) of every node along its path

    import numpy as np  
    import pandas as pd  
    from sklearn.tree import DecisionTreeClassifier 
    
    def get_rules(dtc, df):
        rules_list = []
        values_path = []
        values = dtc.tree_.value
    
        def RevTraverseTree(tree, node, rules, pathValues):
            '''
            Traverse a scikit-learn decision tree from a node (presumably a leaf node)
            up to the root, building the decision rules. rules and pathValues
            should be passed in as empty lists; they are modified in place.
            rules ends up as a list of strings such as 'feature <= threshold',
            pathValues as the tree_.value entries of the ancestor nodes.
            The "tree" is a nested list of simplified tree attributes:
            [split feature, split threshold, left node, right node]
            '''
            # now find the node as either a left or right child of something
            # first try to find it as a left node            
    
            try:
                prevnode = tree[2].index(node)           
                leftright = '<='
                pathValues.append(values[prevnode])
            except ValueError:
                # failed, so find it as a right node - if this also causes an exception, something's really f'd up
                prevnode = tree[3].index(node)
                leftright = '>'
                pathValues.append(values[prevnode])
    
            # now let's get the rule that caused prevnode to -> node
            p1 = df.columns[tree[0][prevnode]]    
            p2 = tree[1][prevnode]    
            rules.append(str(p1) + ' ' + leftright + ' ' + str(p2))
    
            # if we've not yet reached the top, go up the tree one more step
            if prevnode != 0:
                RevTraverseTree(tree, prevnode, rules, pathValues)
    
        # get the nodes which are leaves
        leaves = dtc.tree_.children_left == -1
        leaves = np.arange(0,dtc.tree_.node_count)[leaves]
    
        # build a simpler tree as a nested list: [split feature, split threshold, left node, right node]
        thistree = [dtc.tree_.feature.tolist()]
        thistree.append(dtc.tree_.threshold.tolist())
        thistree.append(dtc.tree_.children_left.tolist())
        thistree.append(dtc.tree_.children_right.tolist())
    
        # get the decision rules for each leaf node & apply them
        for (ind,nod) in enumerate(leaves):
    
            # get the decision rules
            rules = []
            pathValues = []
            RevTraverseTree(thistree, nod, rules, pathValues)
    
            pathValues.insert(0, values[nod])      
            pathValues = list(reversed(pathValues))
    
            rules = list(reversed(rules))
    
            rules_list.append(rules)
            values_path.append(pathValues)
    
        return (rules_list, values_path)
    

Here is an example:

from sklearn.model_selection import train_test_split

df = pd.read_csv('df.csv')

X = df[df.columns[:-1]]
y = df['classification']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

dtc = DecisionTreeClassifier(max_depth=2)
dtc.fit(X_train, y_train)

The fitted decision tree looks like this: [figure: the fitted decision tree, max_depth=2]

At this point, just calling the function:

get_rules(dtc, df)

This is what the function returns:

rules = [  
    ['first <= 63.5', 'first <= 43.5'],  
    ['first <= 63.5', 'first > 43.5'],  
    ['first > 63.5', 'second <= 19.700000762939453'],  
    ['first > 63.5', 'second > 19.700000762939453']
]

values = [
    [array([[ 1568.,  1569.]]), array([[ 636.,  241.]]), array([[ 284.,  57.]])],
    [array([[ 1568.,  1569.]]), array([[ 636.,  241.]]), array([[ 352.,  184.]])],
    [array([[ 1568.,  1569.]]), array([[  932.,  1328.]]), array([[ 645.,  620.]])],
    [array([[ 1568.,  1569.]]), array([[  932.,  1328.]]), array([[ 287.,  708.]])]
]

Note that each path in values also ends with the leaf's own values.
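In scikit-learn 0.21 and later, export_text produces a similar root-to-leaf rule listing without a custom traversal. A minimal sketch, assuming the iris data in place of the post's df.csv (which is not available) and max_depth=2 to mirror the example:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
dtc = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# One line per split condition, indented by depth; leaves are shown as "class: ...".
report = export_text(dtc, feature_names=list(iris.feature_names))
print(report)
```

It prints text rather than returning lists, so get_rules above remains handy when you need the rules and node values as Python objects.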




Answer 6:


I think an easy option would be to use the apply method of the trained decision tree: train the tree, apply it to the training data, and build a lookup table from the returned leaf indices:

import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris

iris = load_iris()
clf = DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)

# apply training data to decision tree
leaf_indices = clf.apply(iris.data)
lookup = {}

# build lookup table: leaf index -> list of training rows in that leaf
for i, leaf_index in enumerate(leaf_indices):
    lookup.setdefault(leaf_index, []).append(iris.data[i])

# test
unknown_sample = [[4., 3.1, 6.1, 1.2]]
index = clf.apply(unknown_sample)
print(lookup[index[0]])
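A variant of the same idea, assuming NumPy: store the row indices of the training samples per leaf rather than copies of the rows, so each leaf maps straight back into iris.data (or X_train) and its labels. The iris setup and random_state=0 are assumptions for illustration:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

# apply() gives, for every training row, the id of the leaf it ends up in.
leaf_indices = clf.apply(iris.data)

# Lookup table: leaf id -> row indices (into iris.data) of the samples in it.
lookup = {int(leaf): np.flatnonzero(leaf_indices == leaf)
          for leaf in np.unique(leaf_indices)}

unknown_sample = [[4., 3.1, 6.1, 1.2]]
leaf = int(clf.apply(unknown_sample)[0])
print("leaf", leaf, "holds training rows", lookup[leaf])
print("their labels:", iris.target[lookup[leaf]])
```

Since every leaf of a fitted tree holds at least one training sample, any new sample's leaf is guaranteed to be a key in lookup.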



Answer 7:


Have you tried dumping your decision tree into a Graphviz .dot file [1] and then loading it with graph_tool [2]? For example:

from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.datasets import load_iris
from graph_tool.all import load_graph

iris = load_iris()
clf = DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)

export_graphviz(clf, out_file='tree.dot')

# load graph with graph_tool and explore structure as you please
g = load_graph('tree.dot')

for v in g.vertices():
    for e in v.out_edges():
        print(e)
    for w in v.out_neighbours():
        print(w)
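If installing graph_tool is inconvenient, a lighter sketch (assuming only the standard library's re module) is to ask export_graphviz for the DOT source as a string via out_file=None and pull the parent -> child edges out of it. The regular expression relies on the current layout of export_graphviz output, where each edge statement starts a line, so treat it as an assumption rather than a stable format:

```python
import re

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

# With out_file=None, export_graphviz returns the DOT source as a string.
dot = export_graphviz(clf, out_file=None)

# Edge statements begin a line as "parent -> child" (node ids are integers).
edges = [(int(a), int(b)) for a, b in re.findall(r"^(\d+) -> (\d+)", dot, flags=re.M)]

# Adjacency list: parent id -> list of child ids.
children = {}
for parent, child in edges:
    children.setdefault(parent, []).append(child)
print(children)
```

For anything beyond simple traversal, the arrays on clf.tree_ carry the same structure without any parsing.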

[1] http://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html

[2] https://graph-tool.skewed.de/



Source: https://stackoverflow.com/questions/32506951/how-to-explore-a-decision-tree-built-using-scikit-learn
