graph.write_pdf("iris.pdf") AttributeError: 'list' object has no attribute 'write_pdf'

深忆病人 2020-12-08 19:00

My code follows Google's machine learning class, and the two pieces of code are the same, so I don't know why it shows this error. Maybe the type of a variable is wrong, but Google's code is sa…

10 answers
  • 2020-12-08 19:21

    I think you are using a newer version of pydot, whose graph_from_dot_data returns a list of graphs rather than a single graph. Please try pydotplus instead:

    import pydotplus
    ...
    # pydotplus returns a single graph object here, so write_pdf can be called directly
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
    graph.write_pdf("iris.pdf")
    

    This should do it.
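    The error in the question comes from a difference in return types: pydotplus.graph_from_dot_data returns one graph object, while pydot.graph_from_dot_data (since pydot 1.2) returns a list of graphs, and a list has no write_pdf method. As a minimal sketch, a small helper can accept either form; first_graph is a hypothetical name, and the None check is only a defensive assumption about parse failures.

```python
def first_graph(result):
    """Return one graph object from a graph_from_dot_data() result.

    Assumption: pydot >= 1.2 returns a list of graphs, while pydotplus
    returns the graph object itself. None is treated as a parse failure
    (a defensive check, not a documented guarantee).
    """
    if result is None:
        raise ValueError("the DOT data could not be parsed")
    if isinstance(result, list):
        if len(result) == 0:
            raise ValueError("the DOT data contained no graphs")
        # DOT text from export_graphviz describes exactly one graph,
        # so the first list element is the one we want.
        return result[0]
    # pydotplus hands back the graph directly.
    return result
```

    With it, graph = first_graph(pydot.graph_from_dot_data(dot_data.getvalue())) followed by graph.write_pdf("iris.pdf") works with either library.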

  • 2020-12-08 19:21

    I hope this helps; I was having a similar issue. I decided not to use pydot/pydotplus but graphviz instead. I modified the code (barely) and it works wonders! :)

    # 2. Train classifier
    import numpy as np
    from sklearn.datasets import load_iris
    from sklearn import tree
    iris = load_iris()
    # Testing data: examples used to "test" the classifier's accuracy, not part of the training data.
    # Grab one example of each flower (in this data set each species happens to
    # begin at index 0, 50 and 100)
    test_idx = [0, 50, 100]

    # Training data: everything except the 3 test examples
    train_target = np.delete(iris.target, test_idx)      # training labels without the 3 test examples
    train_data = np.delete(iris.data, test_idx, axis=0)  # training rows without the 3 test examples

    # Testing data: only the 3 held-out examples
    test_target = iris.target[test_idx]  # testing labels
    test_data = iris.data[test_idx]      # testing rows

    # Create a decision tree classifier and train it on the training data
    clf = tree.DecisionTreeClassifier()
    clf.fit(train_data, train_target)

    # Predict labels for the held-out flowers and compare them with the true labels
    print(test_target)
    print(clf.predict(test_data))

    # Visualize the tree
    from io import StringIO  # sklearn.externals.six was removed in scikit-learn 0.23
    import graphviz
    dot_data = StringIO()
    tree.export_graphviz(clf,
            out_file=dot_data,
            feature_names=iris.feature_names,
            class_names=iris.target_names,
            filled=True, rounded=True,
            impurity=False)
    graph = graphviz.Source(dot_data.getvalue())
    # render() appends the format extension itself, so "iris" produces iris.pdf
    graph.render("iris", format="pdf", view=True)
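    A more compact variant, as a sketch: with out_file=None, export_graphviz returns the DOT text as a string, so the StringIO buffer can be dropped. iris_tree_source is a hypothetical helper name, training on the full data set with random_state=0 is an illustrative choice, and rendering (not building the Source object) assumes the Graphviz dot executable is installed.

```python
from sklearn import tree
from sklearn.datasets import load_iris
import graphviz


def iris_tree_source():
    """Fit a decision tree on iris and wrap its DOT export in graphviz.Source."""
    iris = load_iris()
    clf = tree.DecisionTreeClassifier(random_state=0)
    clf.fit(iris.data, iris.target)
    # With out_file=None, export_graphviz returns the DOT text as a string
    # instead of writing it to a file or buffer.
    dot_source = tree.export_graphviz(
        clf,
        out_file=None,
        feature_names=iris.feature_names,
        class_names=iris.target_names,
        filled=True,
        rounded=True,
        impurity=False,
    )
    return graphviz.Source(dot_source)
```

    Calling iris_tree_source().render("iris", format="pdf", view=True) should then write the DOT source to a file named iris and the rendered tree to iris.pdf; passing cleanup=True as well deletes the intermediate source file.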
    
  • 2020-12-08 19:27

    To export one graph for each tree in an ensemble such as a random forest (n_estimators trees in total), you can do:

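    from io import StringIO
    import pydotplus
    from sklearn import tree

    # clf is a fitted ensemble; clf.estimators_ holds its n_estimators trees
    for i, estimator in enumerate(clf.estimators_):
        dot_data = StringIO()
        tree.export_graphviz(estimator, out_file=dot_data, feature_names=iris.feature_names,
                             class_names=iris.target_names, filled=True, rounded=True,
                             impurity=False)
        graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
        graph.write_pdf("iris%s.pdf" % i)  # one PDF per tree: iris0.pdf, iris1.pdf, ...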
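    If you only need the DOT text for each tree, a minimal sketch can collect one string per estimator without any buffer. It assumes a small RandomForestClassifier fitted on iris; forest_dot_sources is a hypothetical helper name, and n_estimators=3 and random_state=0 are illustrative choices.

```python
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier


def forest_dot_sources(n_estimators=3):
    """Return one DOT string per tree of a random forest fitted on iris."""
    iris = load_iris()
    forest = RandomForestClassifier(n_estimators=n_estimators, random_state=0)
    forest.fit(iris.data, iris.target)
    # estimators_ holds the fitted DecisionTreeClassifier objects, one per
    # tree, so iterating over it avoids hard-coding n.
    return [
        tree.export_graphviz(
            estimator,
            out_file=None,  # return the DOT text as a string
            feature_names=iris.feature_names,
            class_names=iris.target_names,
            filled=True,
            rounded=True,
            impurity=False,
        )
        for estimator in forest.estimators_
    ]
```

    Each string can then be passed to pydotplus.graph_from_dot_data and written with write_pdf("iris%s.pdf" % i), as in the loop above.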
    

    You could also replace the line

    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
    

    with this one, which unpacks the single graph from the list that pydot returns:

    (graph,) = pydot.graph_from_dot_data(dot_data.getvalue())
    graph.write_pdf("iris.pdf")
    

    and it would still work, as long as you import pydot instead of pydotplus.

  • 2020-12-08 19:31

    I installed scikit-learn via conda and none of the above worked. First, I had to install libtool:

    brew install libtool --universal
    

    Then I followed this sklearn guide and changed the Python file to this code:

    clf = clf.fit(train_data, train_target)
    tree.export_graphviz(clf, out_file='tree.dot')  # writes the tree's DOT description to tree.dot
    

    Finally, convert it to PNG in the terminal:

    dot -Tpng tree.dot -o tree.png
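    If you prefer to run that same conversion from Python, a minimal sketch with the standard library's subprocess module could look like this. dot_command and convert_dot are hypothetical helper names, and running the conversion assumes Graphviz's dot executable is on PATH.

```python
import shutil
import subprocess
from pathlib import Path


def dot_command(dot_path, out_format="png"):
    """Build the argument list for: dot -T<format> <name>.dot -o <name>.<format>."""
    src = Path(dot_path)
    out = src.with_suffix("." + out_format)
    return ["dot", "-T" + out_format, str(src), "-o", str(out)]


def convert_dot(dot_path, out_format="png"):
    """Run Graphviz's dot on dot_path and return the output file path."""
    if shutil.which("dot") is None:
        raise FileNotFoundError("Graphviz 'dot' executable not found on PATH")
    cmd = dot_command(dot_path, out_format)
    # check=True raises CalledProcessError if dot exits with an error.
    subprocess.run(cmd, check=True)
    return cmd[-1]
```

    convert_dot("tree.dot") would then produce tree.png next to tree.dot.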
    