graph.write_pdf("iris.pdf") AttributeError: 'list' object has no attribute 'write_pdf'

深忆病人 2020-12-08 19:00

My code follows Google's machine learning course, and the two pieces of code are the same. I don't know why it shows an error. Maybe the variable has the wrong type. But Google's code is sa…
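A likely cause, assuming the course code builds the graph with pydot: in newer pydot releases (1.2 onward), `pydot.graph_from_dot_data` returns a list of graphs instead of a single graph, so the list has no `write_pdf` method and you need its first element. A minimal sketch of a helper that accepts either return shape (`first_graph` is a hypothetical name, not part of pydot):

```python
def first_graph(result):
    """Return a single graph whether the parser gave back a list of graphs or one graph.

    Some pydot versions return a list from graph_from_dot_data; others
    return the graph object itself. A parse failure may yield None or [].
    """
    if result is None or (isinstance(result, list) and not result):
        raise ValueError("no graph could be parsed from the DOT data")
    # Unwrap the list case; pass a single graph object through unchanged
    return result[0] if isinstance(result, list) else result


# Usage with pydot (needs pydot plus the Graphviz binaries, so shown as comments):
# import pydot
# graph = first_graph(pydot.graph_from_dot_data(dot_data.getvalue()))
# graph.write_pdf("iris.pdf")
```

With a list-returning pydot, `graph_from_dot_data(...)[0].write_pdf("iris.pdf")` is the one-line equivalent; the helper just keeps the same script working on versions that return the graph directly.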

10 answers
  •  谎友^ (OP)
    2020-12-08 19:21

    I hope this helps; I was having a similar issue. I decided not to use pydot / pydotplus, but the graphviz package instead. I modified the code only slightly, and it works wonders! :)

    # 2. Train classifier
    # Testing Data
    # Examples used to "test" the classifier's accuracy
    # Not part of the training data
    import numpy as np
    from sklearn.datasets import load_iris
    from sklearn import tree
    iris = load_iris()
    # Hold out one example of each flower for testing (in the data set, the
    # three species happen to start at indices 0, 50, and 100)
    test_idx = [0, 50, 100]
    
    # training data
    train_target = np.delete(iris.target, test_idx)     # Remove the 3 test examples from the training labels
    train_data = np.delete(iris.data, test_idx, axis=0) # Remove the 3 test examples from the training features
    
    # testing data
    test_target = iris.target[test_idx] # Get testing target data
    test_data = iris.data[test_idx]     # Get testing data
    
    # Create a decision tree classifier and train it on the training data
    clf = tree.DecisionTreeClassifier()
    clf.fit(train_data, train_target)
    
    # Compare the true labels of the test flowers with the predicted labels
    print(test_target)
    print(clf.predict(test_data))
    
    # Visualize the tree
    # (sklearn.externals.six was removed in scikit-learn 0.23; io.StringIO is the drop-in replacement)
    from io import StringIO
    import graphviz
    dot_data = StringIO()
    tree.export_graphviz(clf,
            out_file=dot_data,
            feature_names=iris.feature_names,
            class_names=iris.target_names,
            filled=True, rounded=True,
            impurity=False)
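    graph = graphviz.Source(dot_data.getvalue())
    # render() appends the output format's extension, so "iris" yields iris.pdf
    # (the DOT source itself is saved to a file named "iris")
    graph.render("iris", view=True)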
    
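The split above removes the three held-out examples from the training set with `np.delete` and selects them for testing with fancy indexing. A numpy-only sketch of the same idea, using a hypothetical helper `split_holdout` and a toy array standing in for the iris data, makes it easy to check which rows end up where:

```python
import numpy as np


def split_holdout(data, target, test_idx):
    """Hold out the rows in test_idx; return (train_data, train_target, test_data, test_target)."""
    # np.delete returns a copy WITHOUT the listed entries; axis=0 drops whole rows
    train_data = np.delete(data, test_idx, axis=0)
    train_target = np.delete(target, test_idx)
    # Fancy indexing keeps ONLY the listed rows
    test_data = data[test_idx]
    test_target = target[test_idx]
    return train_data, train_target, test_data, test_target


# Toy stand-in for iris: 6 samples with 2 features; classes 0, 1, 2 start at rows 0, 2, 4
data = np.arange(12).reshape(6, 2)
target = np.array([0, 0, 1, 1, 2, 2])

train_data, train_target, test_data, test_target = split_holdout(data, target, [0, 2, 4])
print(train_data.shape, test_data.shape)  # (3, 2) (3, 2)
print(train_target, test_target)          # [0 1 2] [0 1 2]
```

On the real data the same call with `test_idx = [0, 50, 100]` leaves 147 training rows and 3 test rows, one per species.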
