How to draw a tree diagram from CHAID Tree output

Submitted by 筅森魡賤 on 2020-02-07 03:35:11

Question


I'm trying to draw a tree diagram from CHAID output. There are lots of examples out there, but they only seem to work with binary splits. The tree I'm producing is a CHAID tree whose nodes can split into more than two branches.

I've tried a few solutions, but they always show the output as text instead of producing a diagram of the tree:

import codecs
import random
from io import StringIO

import pandas as pd
from CHAID import Tree
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from IPython.display import SVG, display
from graphviz import Source

def to_graphviz(self, filename=None, shape='circle', graph='digraph'):
        """Exports the tree in the dot format of the graphviz software"""
        nodes, connections = [], []
        if self.nodes:

            for n in self.expand_tree(mode=self.WIDTH):
                nid = self[n].identifier
                state = '"{0}" [label="{1}", shape={2}]'.format(
                    nid, self[n].tag, shape)
                nodes.append(state)

                for c in self.children(nid):
                    cid = c.identifier
                    connections.append('"{0}" -> "{1}"'.format(nid, cid))

        # write nodes and connections to dot format
        is_plain_file = filename is not None
        if is_plain_file:
            f = codecs.open(filename, 'w', 'utf-8')
        else:
            f = StringIO()

        f.write(graph + ' tree {\n')
        for n in nodes:
            f.write('\t' + n + '\n')

        if len(connections) > 0:
            f.write('\n')

        for c in connections:
            f.write('\t' + c + '\n')

        f.write('}')

        if not is_plain_file:
            print(f.getvalue())

        f.close()
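
data = ["a", "b", "c", "d"]
# 500 rows, each a random permutation of the four categories
df = pd.DataFrame([random.sample(data, 4) for _ in range(500)])
# A flat list gives plain column labels; a nested list would be read as MultiIndex levels
df.columns = ['a', 'b', 'c', 'd']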
## set the CHAID input parameters
independent_variable_columns = ['a', 'b', 'c']
dep_variable = 'd'
tree = Tree.from_pandas_df(df, dict(zip(independent_variable_columns, ['nominal'] * 3)), dep_variable,min_child_node_size=0)
tree.print_tree()

This produces:

([], {'a': 130.0, 'b': 113.0, 'c': 134.0, 'd': 123.0}, (('b',), p=1.888741762227604e-33, score=177.262344476939, groups=[['a'], ['b'], ['c'], ['d']]), dof=9))
|-- (['a'], {'a': 0, 'b': 31.0, 'c': 34.0, 'd': 39.0}, (('c',), p=8.919582934418456e-11, score=52.9052990371776, groups=[['b'], ['c'], ['d']]), dof=4))
|   |-- (['b'], {'a': 0, 'b': 0, 'c': 16.0, 'd': 20.0}, <Invalid Chaid Split> - the max depth has been reached)
|   |-- (['c'], {'a': 0, 'b': 13.0, 'c': 0, 'd': 19.0}, <Invalid Chaid Split> - the max depth has been reached)
|   +-- (['d'], {'a': 0, 'b': 18.0, 'c': 18.0, 'd': 0}, <Invalid Chaid Split> - the max depth has been reached)
|-- (['b'], {'a': 41.0, 'b': 0, 'c': 51.0, 'd': 41.0}, (('a',), p=2.2336212592454823e-14, score=70.03309116568674, groups=[['a'], ['c'], ['d']]), dof=4)) etc

rather than a diagram with boxes, labels and so on.

I also tried:
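
Each printed node above carries three pieces of information: the categories of the parent's split that lead to it, the counts of the dependent variable, and the split chosen at that node. As a minimal sketch, assuming those three pieces are available as plain Python values, a compact multi-line label for a diagram box could be built as follows. The `short_label` helper and its parameters are hypothetical and not part of the CHAID package.

```python
def short_label(choices, members, split_column=None, p_value=None):
    """Build a compact, multi-line label for one tree node.

    All parameters are illustrative assumptions:
    choices      -- categories of the parent's split that lead to this node
    members      -- dict mapping dependent-variable values to counts
    split_column -- name of the column this node splits on, or None for a leaf
    p_value      -- p-value of that split, if known
    """
    lines = []
    # The root has no incoming choices, so label it explicitly.
    lines.append("root" if not choices else "in: " + ", ".join(map(str, choices)))
    # One "value=count" entry per dependent-variable category.
    lines.append(" ".join("{}={:g}".format(k, v) for k, v in sorted(members.items())))
    if split_column is not None:
        text = "split on " + str(split_column)
        if p_value is not None:
            text += " (p={:.2g})".format(p_value)
        lines.append(text)
    # Graphviz reads the two-character sequence \n inside a label as a line break.
    return "\\n".join(lines)


# Values copied from the second row of the printed tree.
label = short_label(['a'], {'a': 0, 'b': 31.0, 'c': 34.0, 'd': 39.0},
                    split_column='c', p_value=8.919582934418456e-11)
```

Labels of this size fit inside a box, whereas the full tuple representation makes every node as wide as the printed line.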

treeG=tree.to_tree()
display(SVG(to_graphviz(treeG).pipe(format='svg')))

but this produces roughly the same thing:

digraph tree {
    "0" [label="([], {'a': 130.0, 'b': 113.0, 'c': 134.0, 'd': 123.0}, (('b',), p=1.888741762227604e-33, score=177.262344476939, groups=[['a'], ['b'], ['c'], ['d']]), dof=9))", shape=circle]
    "1" [label="(['a'], {'a': 0, 'b': 31.0, 'c': 34.0, 'd': 39.0}, (('c',), p=8.919582934418456e-11, score=52.9052990371776, groups=[['b'], ['c'], ['d']]), dof=4))", shape=circle] etc..
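
Note that the `to_graphviz` helper above prints the DOT text and has no `return` statement, so it returns `None` and there is nothing to call `.pipe()` on. A minimal sketch of an alternative follows, assuming the nodes have already been reduced to short labels. `tree_to_dot` and its `labels` and `edges` arguments are hypothetical names, and the sample data is hand-copied from the printed tree above. It returns the DOT source as a string, draws boxes instead of circles, and accepts any number of children per node.

```python
def tree_to_dot(labels, edges, shape="box"):
    """Return DOT source for a tree as a string.

    labels -- dict: node id -> label text (hypothetical input format)
    edges  -- list of (parent_id, child_id, edge_label) tuples
    """
    def quote(text):
        # Escape double quotes so the value stays a valid DOT string.
        return '"' + str(text).replace('"', '\\"') + '"'

    lines = ["digraph tree {", "\tnode [shape={}]".format(shape)]
    for node_id, label in labels.items():
        lines.append("\t{} [label={}]".format(quote(node_id), quote(label)))
    for parent, child, edge_label in edges:
        lines.append("\t{} -> {} [label={}]".format(
            quote(parent), quote(child), quote(edge_label)))
    lines.append("}")
    return "\n".join(lines)


# Hand-copied from the first rows of the printed tree above.
labels = {
    0: "split on b\\nn=500",
    1: "split on c\\nn=104",
    2: "split on a\\nn=133",
}
edges = [(0, 1, "b = a"), (0, 2, "b = b")]
dot = tree_to_dot(labels, edges)
```

In a notebook, the returned string can then be handed to the `graphviz` package that the question already imports, for example `display(SVG(Source(dot).pipe(format='svg')))`. This requires the Graphviz binaries to be installed on the system.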

If anyone knows how to turn this into a proper diagram, it would be much appreciated.

Thank you + BR

Source: https://stackoverflow.com/questions/57711320/how-to-draw-tree-diagram-from-chaid-tree-output
