问题
i'm trying to draw a tree diagram from a CHAID output. There are lots of examples out there but they only seem to work with binary splits. The tree i'm producing is a CHAID tree with more than 1 splits.
I've tried a few solutions but they always show the output as text as opposed to producing a diagram of the tree:
from CHAID import Tree
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from IPython.display import SVG
from graphviz import Source
from IPython.display import display
import random
from io import StringIO
def to_graphviz(self, filename=None, shape='circle', graph='digraph'):
"""Exports the tree in the dot format of the graphviz software"""
nodes, connections = [], []
if self.nodes:
for n in self.expand_tree(mode=self.WIDTH):
nid = self[n].identifier
state = '"{0}" [label="{1}", shape={2}]'.format(
nid, self[n].tag, shape)
nodes.append(state)
for c in self.children(nid):
cid = c.identifier
connections.append('"{0}" -> "{1}"'.format(nid, cid))
# write nodes and connections to dot format
is_plain_file = filename is not None
if is_plain_file:
f = codecs.open(filename, 'w', 'utf-8')
else:
f = StringIO()
f.write(graph + ' tree {\n')
for n in nodes:
f.write('\t' + n + '\n')
if len(connections) > 0:
f.write('\n')
for c in connections:
f.write('\t' + c + '\n')
f.write('}')
if not is_plain_file:
print(f.getvalue())
f.close()
data = ["a", "b", "c", "d"]
df=pd.DataFrame(([random.sample(data, 4) for _ in range(500)]))
df.columns=[['a','b','c','d']]
## set the CHAID input parameters
independent_variable_columns = ['a', 'b', 'c']
dep_variable = 'd'
tree = Tree.from_pandas_df(df, dict(zip(independent_variable_columns, ['nominal'] * 3)), dep_variable,min_child_node_size=0)
tree.print_tree()
this produces :
([], {'a': 130.0, 'b': 113.0, 'c': 134.0, 'd': 123.0}, (('b',), p=1.888741762227604e-33, score=177.262344476939, groups=[['a'], ['b'], ['c'], ['d']]), dof=9))
|-- (['a'], {'a': 0, 'b': 31.0, 'c': 34.0, 'd': 39.0}, (('c',), p=8.919582934418456e-11, score=52.9052990371776, groups=[['b'], ['c'], ['d']]), dof=4))
| |-- (['b'], {'a': 0, 'b': 0, 'c': 16.0, 'd': 20.0}, <Invalid Chaid Split> - the max depth has been reached)
| |-- (['c'], {'a': 0, 'b': 13.0, 'c': 0, 'd': 19.0}, <Invalid Chaid Split> - the max depth has been reached)
| +-- (['d'], {'a': 0, 'b': 18.0, 'c': 18.0, 'd': 0}, <Invalid Chaid Split> - the max depth has been reached)
|-- (['b'], {'a': 41.0, 'b': 0, 'c': 51.0, 'd': 41.0}, (('a',), p=2.2336212592454823e-14, score=70.03309116568674, groups=[['a'], ['c'], ['d']]), dof=4)) etc
as opposed to display a diagram with boxes and labels etc.
I also tried :
treeG=tree.to_tree()
display(SVG(to_graphviz(treeG).pipe(format='svg')))
but this produces roughly the same thing:
digraph tree {
"0" [label="([], {'a': 130.0, 'b': 113.0, 'c': 134.0, 'd': 123.0}, (('b',), p=1.888741762227604e-33, score=177.262344476939, groups=[['a'], ['b'], ['c'], ['d']]), dof=9))", shape=circle]
"1" [label="(['a'], {'a': 0, 'b': 31.0, 'c': 34.0, 'd': 39.0}, (('c',), p=8.919582934418456e-11, score=52.9052990371776, groups=[['b'], ['c'], ['d']]), dof=4))", shape=circle] etc..
If someone would know how to turn this into a proper diagram it would be much appreciated.
Thank you + BR
来源:https://stackoverflow.com/questions/57711320/how-to-draw-tree-diagram-from-chaid-tree-output