Changing colors for decision tree plot created using export_graphviz

自闭症网瘾萝莉.ら, submitted on 2019-12-04 08:17:22
  • You can get a list of all the edges via graph.get_edge_list()
  • Each source node should have two target nodes; the one with the lower index is evaluated as True, the one with the higher index as False
  • Colors can be assigned via set_fillcolor()

import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree
import collections

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()

clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf,
                                feature_names=iris.feature_names,
                                out_file=None,
                                filled=True,
                                rounded=True)
graph = pydotplus.graph_from_dot_data(dot_data)

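# Colors for the True branch (lower index) and the False branch (higher index)
colors = ('brown', 'forestgreen')
edges = collections.defaultdict(list)

# Group the destination node ids by their source node
for edge in graph.get_edge_list():
    edges[edge.get_source()].append(int(edge.get_destination()))

# Sort each pair of children so the True child comes first, then color both
for edge in edges:
    edges[edge].sort()
    for i in range(2):
        dest = graph.get_node(str(edges[edge][i]))[0]
        dest.set_fillcolor(colors[i])

graph.write_png('tree.png')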
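
The coloring above relies on the lower-index child being the True branch. A minimal sketch for checking that assumption on the same iris classifier, using the fitted tree's children_left and children_right arrays (the variable names split_nodes and lower_is_true are only illustrative):

```python
from sklearn.datasets import load_iris
from sklearn import tree

iris = load_iris()
clf = tree.DecisionTreeClassifier(random_state=42).fit(iris.data, iris.target)

left = clf.tree_.children_left    # child taken when "feature <= threshold" is True
right = clf.tree_.children_right  # child taken when the test is False

# Leaves are marked with -1 in both arrays; keep only the split nodes.
split_nodes = [n for n in range(clf.tree_.node_count) if left[n] != -1]

# For every split, the True child should carry the lower node index.
lower_is_true = all(left[n] < right[n] for n in split_nodes)
print(lower_is_true)
```

If this prints False for a tree built with other settings, sorting the destination ids is not a safe way to tell the branches apart, and the children arrays should be used directly.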

Also, I've seen some trees where the length of the lines connecting nodes is proportional to the % variance explained by the split. I'd love to be able to do that too, if possible!

You could play with set_weight() and set_len(), but that is a bit trickier and needs some fiddling to get right. Here is some code to get you started:

import re

for edge in edges:
    edges[edge].sort()
    src = graph.get_node(edge)[0]
    # Read the sample count from the label text; the regex works whether the
    # label lines are separated by \n or by <br/>
    total_weight = int(re.search(r'samples = (\d+)', src.get_attributes()['label']).group(1))
    for i in range(2):
        dest = graph.get_node(str(edges[edge][i]))[0]
        weight = int(re.search(r'samples = (\d+)', dest.get_attributes()['label']).group(1))
        # Index with i so that both outgoing edges are updated, not only the first
        out_edge = graph.get_edge(edge, str(edges[edge][i]))[0]
        out_edge.set_weight((1 - weight / total_weight) * 100)
        out_edge.set_len(weight / total_weight)
        out_edge.set_minlen(weight / total_weight)
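
The edge code above depends on pulling the sample count out of each node's label text. A minimal sketch of that parsing step as a standalone helper, assuming the labels look like the ones export_graphviz writes (lines separated by a literal \n, or by <br/> when special_characters=True). The name samples_from_label and the two example labels are hypothetical:

```python
import re


def samples_from_label(label):
    """Return the integer that follows 'samples = ' in a node label."""
    match = re.search(r'samples = (\d+)', label)
    if match is None:
        raise ValueError('no sample count in label: %r' % label)
    return int(match.group(1))


# Label shapes assumed from export_graphviz output; the values are illustrative.
plain = r'"petal width (cm) <= 0.8\ngini = 0.667\nsamples = 150\nvalue = [50, 50, 50]"'
html = '<petal width (cm) &le; 0.8<br/>gini = 0.667<br/>samples = 150<br/>value = [50, 50, 50]>'

parent = samples_from_label(plain)
child = 50                  # e.g. the count parsed from one child node
fraction = child / parent   # share of the parent's samples reaching that child
```

Matching the digits with a regular expression avoids depending on which separator follows the number, so the same helper can serve both label styles.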