Python CART decision tree API: an introduction and example

时光毁灭记忆、已成空白, submitted on 2020-01-20 02:46:53

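CART (Classification And Regression Tree) is one branch of the decision-tree family of algorithms: a nonlinear model that can be used for both classification and regression. In Python, scikit-learn provides the classifier, which you import with from sklearn.tree import DecisionTreeClassifier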
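
A minimal sketch of the fit/predict workflow, on toy data made up purely for illustration. The same module also provides DecisionTreeRegressor, the regression counterpart, so both are shown side by side:

```python
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

# Toy 1-D data (made up for illustration): x <= 4 is class 0, x >= 5 is class 1
X = [[1], [2], [3], [4], [5], [6], [7], [8]]
y_class = [0, 0, 0, 0, 1, 1, 1, 1]
# A continuous target with two levels, around 1.0 and around 3.0
y_reg = [1.0, 1.1, 0.9, 1.0, 3.0, 3.1, 2.9, 3.0]

# Classification tree: each leaf predicts a class label
clf = DecisionTreeClassifier(random_state=0).fit(X, y_class)
print(clf.predict([[2.5], [6.5]]))  # [0 1]

# Regression tree limited to one split: each leaf predicts the mean target of its samples
reg = DecisionTreeRegressor(max_depth=1, random_state=0).fit(X, y_reg)
print(reg.predict([[2.5], [6.5]]))
```

The classifier's leaves store class labels and the regressor's leaves store mean target values. Apart from that, the two estimators share the same constructor parameters and the same fit/predict interface.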

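First, let's look at the parameters DecisionTreeClassifier accepts. In IPython or Jupyter, appending ? to a name shows its signature and docstring: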
DecisionTreeClassifier?
Init signature:
DecisionTreeClassifier(criterion='gini',
                       splitter='best',
                       max_depth=None,
                       min_samples_split=2,
                       min_samples_leaf=1,
                       min_weight_fraction_leaf=0.0,
                       max_features=None,
                       random_state=None,
                       max_leaf_nodes=None,
                       min_impurity_decrease=0.0,
                       min_impurity_split=None,
                       class_weight=None,
                       presort='deprecated',
                       ccp_alpha=0.0)

This output comes from scikit-learn 0.22. Later releases removed presort and min_impurity_split, so the exact list depends on the installed version.
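
The ? syntax only works in IPython/Jupyter. Here is a minimal plain-Python alternative: every scikit-learn estimator has get_params(), which returns its constructor parameters and current values for the installed version. The list may therefore differ from the signature shown above. The standard library's inspect.signature also works on the class:

```python
import inspect

from sklearn.tree import DecisionTreeClassifier

# get_params() returns a dict mapping constructor arguments to their current values
params = DecisionTreeClassifier().get_params()
for name in sorted(params):
    print(f"{name} = {params[name]!r}")

# inspect.signature works on any Python callable, including estimator classes
print(inspect.signature(DecisionTreeClassifier))
```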
						
Important parameters:
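criterion: the function that measures the quality of a split at each node. The default is "gini" (Gini impurity); information entropy ("entropy") can be used instead.
min_samples_split: the minimum number of samples required to split an internal node (default 2).
min_samples_leaf: the minimum number of samples required at a leaf node (default 1). A split is only considered if it leaves at least this many samples in both the left and right child.
min_impurity_decrease: a node is split only if the split decreases the weighted impurity (Gini or entropy, depending on criterion) by at least this value (default 0.0). Stopping once a node's own impurity falls below a threshold was the job of the older min_impurity_split parameter.
ccp_alpha: the complexity parameter for minimal cost-complexity pruning (post-pruning). The subtree with the largest cost complexity that is smaller than ccp_alpha is chosen; the default 0.0 means no pruning.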
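
To make criterion and min_impurity_decrease concrete, consider a node t with class proportions p_k. Its Gini impurity is 1 - Σ p_k² and its entropy is -Σ p_k log₂ p_k. The scikit-learn documentation defines the weighted impurity decrease of a split as N_t / N * (impurity - N_t_R / N_t * right_impurity - N_t_L / N_t * left_impurity). Here N is the total number of samples, N_t the number at the node, and N_t_L and N_t_R the numbers in the left and right children. The node is split only if this value is at least min_impurity_decrease.

The sketch below recomputes these quantities by hand with NumPy. The helper names gini, entropy and weighted_impurity_decrease are hypothetical, not part of scikit-learn, and the sketch ignores sample weights:

```python
import numpy as np

def gini(labels):
    """Gini impurity of a node: 1 - sum_k p_k**2."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def entropy(labels):
    """Entropy in bits: -sum_k p_k*log2(p_k), written as sum_k p_k*log2(1/p_k)."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()  # only classes present in the node, so p > 0
    return np.sum(p * np.log2(1.0 / p))

def weighted_impurity_decrease(parent, left, right, n_total, impurity=gini):
    """N_t / N * (I(t) - N_t_R / N_t * I(right) - N_t_L / N_t * I(left)), unweighted."""
    n_t = len(parent)
    return (n_t / n_total) * (
        impurity(parent)
        - len(right) / n_t * impurity(right)
        - len(left) / n_t * impurity(left)
    )

# Hypothetical node: 4 samples of class 0 and 4 of class 1, split perfectly in two
parent = np.array([0, 0, 0, 0, 1, 1, 1, 1])
left, right = parent[:4], parent[4:]

print(gini(parent))     # 0.5
print(entropy(parent))  # 1.0
# At the root (N_t == N) a perfect split removes all of the parent's impurity
print(weighted_impurity_decrease(parent, left, right, n_total=8))                    # 0.5
print(weighted_impurity_decrease(parent, left, right, n_total=8, impurity=entropy))  # 1.0
# The same split deeper in a tree trained on N = 80 samples counts for less
print(weighted_impurity_decrease(parent, left, right, n_total=80))                   # 0.05
```

The N_t / N factor explains why a given min_impurity_decrease is easier to satisfy near the root than deep in the tree, where each node holds only a small fraction of the samples.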
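Python example: classifying handwritten digits with CART. The data come from load_digits, scikit-learn's bundled dataset of 1,797 8x8 digit images; it is a small relative of MNIST, not MNIST itself.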
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn import metrics
from sklearn.tree import DecisionTreeClassifier
from sklearn import preprocessing
import matplotlib.pyplot as plt  # only needed for the optional visualization below

def main():
    digit = load_digits()
    data = digit.data
    label = digit.target
    # Split into training (70%) and test (30%) sets
    train_x, test_x, train_y, test_y = train_test_split(data, label, test_size = 0.3, random_state = 0)
    # print(digit.target[0])
    # plt.imshow(digit.images[0],cmap='gray')
    # plt.show()

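    # Preprocessing: standardize with StandardScaler, i.e. (x - mean) / std per feature.
    # Fit the scaler on the training set only and reuse its statistics for the test set.
    # (Decision trees are insensitive to per-feature scaling, so this step is optional here.)
    scaler = preprocessing.StandardScaler().fit(train_x)
    train_x = scaler.transform(train_x)
    test_x = scaler.transform(test_x)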

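    # Train CART; features are randomly permuted at each split, so fix random_state for reproducible results
    CART = DecisionTreeClassifier(random_state=0)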
    CART.fit(train_x, train_y)
    print(CART)
    predicted = CART.predict(test_x)
    print(metrics.classification_report(test_y, predicted))
    print(metrics.confusion_matrix(test_y, predicted))
    print("CART accuracy: %.6f" % metrics.accuracy_score(test_y, predicted))

if __name__ == '__main__':
    main()
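
The parameter list above describes ccp_alpha only briefly. Here is a minimal sketch of how to choose candidate values on the same digits split. cost_complexity_pruning_path (available since scikit-learn 0.22) returns the effective alphas at which subtrees get pruned away, and refitting with a larger ccp_alpha yields a smaller tree. Taking every 10th alpha is an arbitrary choice to keep the run short. Feature scaling is skipped because trees do not need it:

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Same split as the example above, without scaling
X, y = load_digits(return_X_y=True)
train_x, test_x, train_y, test_y = train_test_split(X, y, test_size=0.3, random_state=0)

# Effective alphas along the minimal cost-complexity pruning path (ascending)
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(train_x, train_y)
ccp_alphas = path.ccp_alphas

# Every 10th alpha plus the last one, which prunes the tree down to its root.
# The clip guards against any tiny negative values from floating-point round-off.
candidates = np.unique(np.clip(np.append(ccp_alphas[::10], ccp_alphas[-1]), 0.0, None))

results = []
for alpha in candidates:
    clf = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha).fit(train_x, train_y)
    results.append((alpha, clf.tree_.node_count, clf.score(test_x, test_y)))

for alpha, n_nodes, acc in results:
    print(f"ccp_alpha={alpha:.5f}  nodes={n_nodes:4d}  test accuracy={acc:.4f}")
```

Picking the alpha with the best test accuracy from this printout would leak the test set into model selection. In practice, choose ccp_alpha by cross-validation on the training set, for example with GridSearchCV, and keep the test set for the final evaluation.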