Getting decision path to a node in sklearn

轻奢々 2020-12-31 23:04

I want the decision path (i.e., the set of rules) from the root node to a given node (which I supply) in a decision tree (DecisionTreeClassifier) in scikit-learn. clf.

2 answers
  •  不知归路
    2020-12-31 23:30

    If you pass out_file=None to export_graphviz, it returns the tree as a string in Graphviz DOT format instead of writing it to a file.

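    from sklearn.datasets import load_iris
    from sklearn import tree

    iris = load_iris()
    clf = tree.DecisionTreeClassifier()
    clf = clf.fit(iris.data, iris.target)

    # out_file=None makes export_graphviz return the DOT source as a string.
    # feature_names and class_names produce the readable labels shown below.
    string_data = tree.export_graphviz(clf,
        out_file=None,
        feature_names=iris.feature_names,
        class_names=iris.target_names)

    print(string_data)

    # Output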
    digraph Tree {
    node [shape=box] ;
    0 [label="petal length (cm) <= 2.45\ngini = 0.667\nsamples = 150\nvalue = [50, 50, 50]\nclass = setosa"] ;
    1 [label="gini = 0.0\nsamples = 50\nvalue = [50, 0, 0]\nclass = setosa"] ;
    0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
    2 [label="petal width (cm) <= 1.75\ngini = 0.5\nsamples = 100\nvalue = [0, 50, 50]\nclass = versicolor"] ;
    0 -> 2 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
    3 [label="petal length (cm) <= 4.95\ngini = 0.168\nsamples = 54\nvalue = [0, 49, 5]\nclass = versicolor"] ;
    2 -> 3 ;
    4 [label="petal width (cm) <= 1.65\ngini = 0.041\nsamples = 48\nvalue = [0, 47, 1]\nclass = versicolor"] ;
    3 -> 4 ;
    5 [label="gini = 0.0\nsamples = 47\nvalue = [0, 47, 0]\nclass = versicolor"] ;
    4 -> 5 ;
    6 [label="gini = 0.0\nsamples = 1\nvalue = [0, 0, 1]\nclass = virginica"] ;
    4 -> 6 ;
    7 [label="petal width (cm) <= 1.55\ngini = 0.444\nsamples = 6\nvalue = [0, 2, 4]\nclass = virginica"] ;
    3 -> 7 ;
    8 [label="gini = 0.0\nsamples = 3\nvalue = [0, 0, 3]\nclass = virginica"] ;
    7 -> 8 ;
    9 [label="sepal length (cm) <= 6.95\ngini = 0.444\nsamples = 3\nvalue = [0, 2, 1]\nclass = versicolor"] ;
    7 -> 9 ;
    10 [label="gini = 0.0\nsamples = 2\nvalue = [0, 2, 0]\nclass = versicolor"] ;
    9 -> 10 ;
    11 [label="gini = 0.0\nsamples = 1\nvalue = [0, 0, 1]\nclass = virginica"] ;
    9 -> 11 ;
    12 [label="petal length (cm) <= 4.85\ngini = 0.043\nsamples = 46\nvalue = [0, 1, 45]\nclass = virginica"] ;
    2 -> 12 ;
    13 [label="sepal length (cm) <= 5.95\ngini = 0.444\nsamples = 3\nvalue = [0, 1, 2]\nclass = virginica"] ;
    12 -> 13 ;
    14 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1, 0]\nclass = versicolor"] ;
    13 -> 14 ;
    15 [label="gini = 0.0\nsamples = 2\nvalue = [0, 0, 2]\nclass = virginica"] ;
    13 -> 15 ;
    16 [label="gini = 0.0\nsamples = 43\nvalue = [0, 0, 43]\nclass = virginica"] ;
    12 -> 16 ;
    }
    

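    This string contains what you need: each numbered node line carries that node's split condition, and each "a -> b" line records a parent-to-child edge. You can parse it and walk from the node you supply back up to the root, collecting the rules along the way.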
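
    As a rough illustration of that parsing step, here is a minimal sketch. The helper name rules_to_node and the two regular expressions are my own, not part of scikit-learn. It assumes the DOT layout shown above: node lines of the form N [label="..."], edge lines of the form A -> B, and the True (<=) child of each node listed before the False child. Only the root's edges are labelled True/False, so the edge order is what the sketch relies on.

```python
import re

# Patterns for the two kinds of lines we care about in the DOT source.
NODE_RE = re.compile(r'^\s*(\d+) \[label="(.*?)"')
EDGE_RE = re.compile(r'^\s*(\d+) -> (\d+)')


def rules_to_node(dot_source, target):
    """Return the split conditions on the path from the root to `target`."""
    conditions = {}  # node id -> first label line (the split test for internal nodes)
    parent = {}      # child id -> (parent id, True if child is the "<=" branch)
    n_children = {}  # parent id -> number of outgoing edges seen so far

    for line in dot_source.splitlines():
        edge = EDGE_RE.match(line)
        if edge:
            src, dst = int(edge.group(1)), int(edge.group(2))
            # Assumption: the first edge out of a node is its True (<=) branch.
            is_true_branch = n_children.get(src, 0) == 0
            n_children[src] = n_children.get(src, 0) + 1
            parent[dst] = (src, is_true_branch)
            continue
        node = NODE_RE.match(line)
        if node:
            # Label fields are separated by a literal backslash followed by "n".
            conditions[int(node.group(1))] = node.group(2).split("\\n")[0]

    if target not in conditions:
        raise KeyError(f"node {target} not found in DOT source")

    # Walk upward from the target, then reverse to get root-to-node order.
    rules = []
    current = target
    while current in parent:
        src, is_true_branch = parent[current]
        test = conditions[src]
        # Taking the False branch negates the parent's "<=" test.
        rules.append(test if is_true_branch else test.replace("<=", ">", 1))
        current = src
    return rules[::-1]


# A trimmed copy of the output above (the raw string keeps "\n" as two characters).
SAMPLE_DOT = r'''digraph Tree {
node [shape=box] ;
0 [label="petal length (cm) <= 2.45\ngini = 0.667\nsamples = 150\nvalue = [50, 50, 50]\nclass = setosa"] ;
1 [label="gini = 0.0\nsamples = 50\nvalue = [50, 0, 0]\nclass = setosa"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="petal width (cm) <= 1.75\ngini = 0.5\nsamples = 100\nvalue = [0, 50, 50]\nclass = versicolor"] ;
0 -> 2 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
3 [label="petal length (cm) <= 4.95\ngini = 0.168\nsamples = 54\nvalue = [0, 49, 5]\nclass = versicolor"] ;
2 -> 3 ;
12 [label="petal length (cm) <= 4.85\ngini = 0.043\nsamples = 46\nvalue = [0, 1, 45]\nclass = virginica"] ;
2 -> 12 ;
}'''

print(rules_to_node(SAMPLE_DOT, 12))
# ['petal length (cm) > 2.45', 'petal width (cm) > 1.75']
```

    With a fitted classifier you would pass string_data instead of SAMPLE_DOT. If the DOT layout differs in your scikit-learn version (for example with filled=True or node_ids=True), the patterns would likely need adjusting.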
