I have one question, though. I heard that in R you can use extra packages to extract the decision rules implemented in a random forest. I tried to google the same thing for Python, but without luck. Is there any way to achieve that? Thanks in advance!
Assuming that you use sklearn's RandomForestClassifier, you can find the individual decision trees in .estimators_. Each tree stores its decision nodes as a number of NumPy arrays under tree_.
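To make those arrays concrete, here is a minimal sketch that fits a tiny forest and prints the raw node arrays of its first tree. The variable names (rf, first, tree) and the fixed random_state are only illustrative choices, not part of the original answer.

```python
from sklearn import datasets, ensemble

# Small forest on the digits dataset; random_state is fixed only for repeatability.
digits = datasets.load_digits()
rf = ensemble.RandomForestClassifier(n_estimators=3, max_depth=2, random_state=0)
rf.fit(digits.data, digits.target)

first = rf.estimators_[0]   # a fitted DecisionTreeClassifier
tree = first.tree_          # the low-level tree object holding the node arrays

print(len(rf.estimators_))  # number of trees in the forest
print(tree.node_count)      # number of nodes (internal nodes plus leaves)

# The arrays below are parallel: entry i describes node i.
print(tree.children_left)   # index of the left child, -1 for a leaf
print(tree.children_right)  # index of the right child, -1 for a leaf
print(tree.feature)         # feature index tested at each node (placeholder at leaves)
print(tree.threshold)       # split threshold at each node (placeholder at leaves)
print(tree.value.shape)     # (node_count, n_outputs, n_classes)
```

Node 0 is the root, so any walk through the tree starts there.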
Here is some example code which just prints each node in order of the array. In a typical application one would instead traverse by following the children.
import numpy
from sklearn.model_selection import train_test_split
from sklearn import datasets, ensemble

def print_decision_rules(rf):
    for tree_idx, est in enumerate(rf.estimators_):
        tree = est.tree_
        assert tree.value.shape[1] == 1  # no support for multi-output
        print('TREE: {}'.format(tree_idx))
        iterator = enumerate(zip(tree.children_left, tree.children_right, tree.feature, tree.threshold, tree.value))
        for node_idx, data in iterator:
            left, right, feature, th, value = data
            # left: index of left child (-1 for a leaf)
            # right: index of right child (-1 for a leaf)
            # feature: index of the feature to check
            # th: the threshold to compare against
            # value: per-class values for the node; for a classifier,
            #   the largest entry marks the class to return
            class_idx = numpy.argmax(value[0])
            if left == -1 and right == -1:
                print('{} LEAF: return class={}'.format(node_idx, class_idx))
            else:
                # sklearn sends samples with feature <= threshold to the left child
                print('{} NODE: if feature[{}] <= {} then next={} else next={}'.format(node_idx, feature, th, left, right))

digits = datasets.load_digits()
Xtrain, Xtest, ytrain, ytest = train_test_split(digits.data, digits.target)
estimator = ensemble.RandomForestClassifier(n_estimators=3, max_depth=2)
estimator.fit(Xtrain, ytrain)
print_decision_rules(estimator)

Example output:

TREE: 0
0 NODE: if feature[33] <= 2.5 then next=1 else next=4
1 NODE: if feature[38] <= 0.5 then next=2 else next=3
2 LEAF: return class=2
3 LEAF: return class=9
4 NODE: if feature[50] <= 8.5 then next=5 else next=6
5 LEAF: return class=4
6 LEAF: return class=0
...
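The traversal by following the children, mentioned earlier, could look roughly like the sketch below. It walks each tree from the root and collects one rule per leaf, i.e. the list of conditions on the path to that leaf. The helper name extract_rules and the output format are assumptions made for illustration; only the tree_ arrays come from sklearn.

```python
from sklearn import datasets, ensemble

def extract_rules(tree):
    """Return one (conditions, class_idx) pair per leaf of a fitted sklearn tree_."""
    rules = []

    def recurse(node, conditions):
        left = tree.children_left[node]
        right = tree.children_right[node]
        if left == -1 and right == -1:
            # Leaf: the predicted class is the one with the largest value entry.
            class_idx = int(tree.value[node][0].argmax())
            rules.append((conditions, class_idx))
            return
        feature = int(tree.feature[node])
        th = float(tree.threshold[node])
        # sklearn sends samples with feature <= threshold to the left child.
        recurse(left, conditions + ['feature[{}] <= {}'.format(feature, th)])
        recurse(right, conditions + ['feature[{}] > {}'.format(feature, th)])

    recurse(0, [])  # node 0 is the root
    return rules

digits = datasets.load_digits()
rf = ensemble.RandomForestClassifier(n_estimators=3, max_depth=2, random_state=0)
rf.fit(digits.data, digits.target)

for tree_idx, est in enumerate(rf.estimators_):
    print('TREE: {}'.format(tree_idx))
    for conditions, class_idx in extract_rules(est.tree_):
        # A tree consisting of a single leaf has no conditions at all.
        rule = ' and '.join(conditions) or 'True'
        print('  if {} then class={}'.format(rule, class_idx))
```

The recursion is fine for shallow trees like these. For very deep trees an explicit stack would avoid Python's recursion limit.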
We use something similar in emtrees to compile a Random Forest to C code.
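As a rough illustration of the idea of compiling a tree to C, the sketch below emits nested if/else statements for a single tree. This is not emtrees' actual implementation or API. The function name tree_to_c, the generated C signature, and the choice of float features are all assumptions, and combining several trees into a forest prediction is left out.

```python
from sklearn import datasets, ensemble

def tree_to_c(tree, name='tree_predict'):
    """Emit C source for one fitted sklearn tree_ as nested if/else statements."""
    lines = ['int {}(const float *features) {{'.format(name)]

    def emit(node, depth):
        indent = '    ' * depth
        left = tree.children_left[node]
        right = tree.children_right[node]
        if left == -1 and right == -1:
            # Leaf: return the index of the class with the largest value entry.
            class_idx = int(tree.value[node][0].argmax())
            lines.append('{}return {};'.format(indent, class_idx))
            return
        feature = int(tree.feature[node])
        th = float(tree.threshold[node])
        # Same comparison as sklearn: feature <= threshold goes left.
        lines.append('{}if (features[{}] <= {!r}) {{'.format(indent, feature, th))
        emit(left, depth + 1)
        lines.append('{}}} else {{'.format(indent))
        emit(right, depth + 1)
        lines.append('{}}}'.format(indent))

    emit(0, 1)  # start at the root, indented one level inside the function
    lines.append('}')
    return '\n'.join(lines)

digits = datasets.load_digits()
rf = ensemble.RandomForestClassifier(n_estimators=3, max_depth=2, random_state=0)
rf.fit(digits.data, digits.target)

print(tree_to_c(rf.estimators_[0].tree_))
```

The returned integer is an index into the forest's classes_ array, which for the digits dataset happens to coincide with the digit label.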
Source: https://stackoverflow.com/questions/50600290/how-extraction-decision-rules-of-random-forest-in-python