How to extract rules from a decision tree in Spark MLlib

Submitted by 一曲冷凌霜 on 2019-12-01 16:12:49

Question


I am using Spark MLlib 1.4.1 to create a DecisionTree model. Now I want to extract the rules from the decision tree.

How can I extract the rules?


Answer 1:


You can get the full model as a string by calling model.toDebugString(), or save it by calling model.save(sc, filePath), which writes the model metadata as JSON and the node data as Parquet.

The documentation is here; it includes an example with a small sample dataset so you can inspect the output format on the command line. Below is the script, formatted so you can paste and run it directly.

from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree

# Toy training set: label 0.0 when the single feature is 0.0, label 1.0 otherwise
data = [
    LabeledPoint(0.0, [0.0]),
    LabeledPoint(1.0, [1.0]),
    LabeledPoint(1.0, [2.0]),
    LabeledPoint(1.0, [3.0])
]

# `sc` is the SparkContext (created automatically in the pyspark shell)
model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {})
print(model)                  # one-line summary
print(model.toDebugString())  # full tree as nested If/Else rules

The output is:

DecisionTreeModel classifier of depth 1 with 3 nodes
DecisionTreeModel classifier of depth 1 with 3 nodes
  If (feature 0 <= 0.0)
   Predict: 0.0
  Else (feature 0 > 0.0)
   Predict: 1.0 
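If you also want to persist the model, a minimal save-and-reload round trip might look like this (a sketch of mine, not part of the original answer; "myModelPath" is a placeholder path and sc is your SparkContext):

from pyspark.mllib.tree import DecisionTreeModel

# save() writes metadata (JSON) and node data (Parquet) under "myModelPath"
model.save(sc, "myModelPath")
sameModel = DecisionTreeModel.load(sc, "myModelPath")
print(sameModel.toDebugString())  # same tree, reloaded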

In a real application the model can be very large and consist of many lines, so printing dtModel.toDebugString() directly can cause an IPython notebook to hang. I suggest writing it out to a text file instead.

Here is example code for exporting a model dtModel to a text file. Suppose we trained dtModel like this:

dtModel = DecisionTree.trainClassifier(parsedTrainData, numClasses=7, categoricalFeaturesInfo={}, impurity='gini', maxDepth=20, maxBins=24)



Then write the debug string to a file:

import os

modelFile = os.path.expanduser("~/decisionTreeModel.txt")  # open() does not expand "~" on its own
with open(modelFile, "w") as f:
    f.write(dtModel.toDebugString())

Here is an example of the output the above script produced for my dtModel:

DecisionTreeModel classifier of depth 20 with 20031 nodes
  If (feature 0 <= -35.0)
   If (feature 24 <= 176.0)
    If (feature 0 <= -200.0)
     If (feature 29 <= 109.0)
      If (feature 6 <= -156.0)
       If (feature 9 <= 0.0)
        If (feature 20 <= -116.0)
         If (feature 16 <= 203.0)
          If (feature 11 <= 163.0)
           If (feature 5 <= 384.0)
            If (feature 15 <= 325.0)
             If (feature 13 <= -248.0)
              If (feature 20 <= -146.0)
               Predict: 0.0
              Else (feature 20 > -146.0)
               If (feature 19 <= -58.0)
                Predict: 6.0
               Else (feature 19 > -58.0)
                Predict: 0.0
             Else (feature 13 > -248.0)
              If (feature 9 <= -26.0)
               Predict: 0.0
              Else (feature 9 > -26.0)
               If (feature 10 <= 218.0)
...
...
...
...
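If you need the rules as data rather than as free text, one option (a rough sketch of my own, not part of the original answer) is to parse the indentation of the toDebugString() output, keeping a stack of the conditions that lead to each Predict line:

def extract_rules(debug_string):
    # Sketch: turn a toDebugString() dump into (conditions, prediction) pairs,
    # relying only on the If/Else/Predict lines and their indentation depth.
    conditions = []  # (depth, condition) pairs on the path to the current line
    rules = []
    for line in debug_string.splitlines():
        stripped = line.strip()
        depth = len(line) - len(line.lstrip())  # indentation encodes tree depth
        # drop conditions recorded at this depth or deeper (branches we have left)
        conditions = [c for c in conditions if c[0] < depth]
        if stripped.startswith(("If", "Else")):
            conditions.append((depth, stripped.split("(", 1)[1].rstrip(")")))
        elif stripped.startswith("Predict:"):
            rules.append((" AND ".join(c for _, c in conditions),
                          stripped.split(":", 1)[1].strip()))
    return rules

for cond, pred in extract_rules(dtModel.toDebugString()):
    print("IF {0} THEN predict {1}".format(cond, pred))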



Answer 2:


import networkx as nx

Load the model data; it will be present in Hadoop at that location if you previously called model.save(location):

modeldf = spark.read.parquet(location+"/data/*")

noderows = modeldf.select("id","prediction","leftChild","rightChild","split").collect()
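If you want to sanity-check the column layout first (my own suggestion; the columns used below match what Spark's DataFrame-based tree models write under data/, but printSchema() lets you confirm this on your version):

modeldf.printSchema()            # expect id, prediction, leftChild, rightChild, split, ...
modeldf.show(5, truncate=False)  # peek at a few node rows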

Create a dummy feature-name array (this example assumes 700 features):

features = ["feature"+str(i) for i in range(0,700)]

Initialize the graph, adding one graph node per tree node:

G = nx.DiGraph()
for rw in noderows:
    if rw['leftChild'] < 0 and rw['rightChild'] < 0:
        # Leaf node: only carries the predicted class
        G.add_node(rw['id'], cat="Prediction", predval=rw['prediction'])
    else:
        # Internal node: carries the split definition and child pointers
        G.add_node(rw['id'], cat="splitter",
                   featureIndex=rw['split']['featureIndex'],
                   thresh=rw['split']['leftCategoriesOrThreshold'],
                   leftChild=rw['leftChild'], rightChild=rw['rightChild'],
                   numCat=rw['split']['numCategories'])

# Add an edge from every internal node to each of its children,
# labelled with the split condition that leads there
for rw in modeldf.where("leftChild > 0 and rightChild > 0").collect():
    tempnode = G.nodes[rw['id']]  # attribute dict (networkx >= 2.0; use G.node[...] on 1.x)
    G.add_edge(rw['id'], rw['leftChild'],
               reason="{0} less than {1}".format(features[tempnode['featureIndex']], tempnode['thresh']))
    G.add_edge(rw['id'], rw['rightChild'],
               reason="{0} greater than {1}".format(features[tempnode['featureIndex']], tempnode['thresh']))

The code above converts the whole tree into a directed graph. To print all the rules in if/else form, find the path from the root to every leaf node and join the edge reasons along the way:

leaves = [x for x in G.nodes() if G.out_degree(x) == 0 and G.in_degree(x) == 1]

for n in leaves:
    p = nx.shortest_path(G, 0, n)  # the unique path from the root (node 0) to this leaf
    print("Rule No:", n)
    print(" & ".join([G.get_edge_data(p[i], p[i+1])['reason'] for i in range(0, len(p) - 1)]))

The output looks something like this:

('Rule No:', 5)

feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 less than [1.0]

('Rule No:', 8)

feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 greater than [1.0] & feature318 less than [0.0] & feature385 less than [0.0]

('Rule No:', 9)

feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 greater than [1.0] & feature318 less than [0.0] & feature385 greater than [0.0]

('Rule No:', 11)

feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 greater than [1.0] & feature318 greater than [0.0] & feature266 less than [0.0]

('Rule No:', 12)

feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 greater than [1.0] & feature318 greater than [0.0] & feature266 greater than [0.0]

('Rule No:', 16)

feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 greater than [1.0] & feature158 less than [1.0] & feature274 less than [0.0] & feature89 less than [1.0]

('Rule No:', 17)

feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 greater than [1.0] & feature158 less than [1.0] & feature274 less than [0.0] & feature89 greater than [1.0]
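These rules list only the path conditions; if you also want each leaf's predicted class (a small extension of mine, using the predval attribute stored when the graph was built above), append it when printing:

for n in leaves:
    p = nx.shortest_path(G, 0, n)
    rule = " & ".join(G.get_edge_data(p[i], p[i + 1])['reason'] for i in range(len(p) - 1))
    # predval was attached to leaf nodes in the graph-building loop above
    print("Rule No: {0} -> predict {1} when {2}".format(n, G.nodes[n]['predval'], rule))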

Modified from the initial code present here.



Source: https://stackoverflow.com/questions/31782288/how-to-extract-rules-from-decision-tree-spark-mllib
