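CART (Classification And Regression Tree) is one branch of the decision-tree family of algorithms: a nonlinear model that can be used for both classification and regression. In Python, scikit-learn (which implements an optimized version of CART) exposes the classifier via from sklearn.tree import DecisionTreeClassifier; DecisionTreeRegressor is the regression counterpart.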
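To make "classification and regression" concrete, here is a minimal sketch that fits both scikit-learn tree estimators on the same single feature. The four toy points and targets are made up purely for illustration; the classifier predicts discrete labels, while the fully grown regressor predicts the target value stored in the leaf a sample falls into.

```python
# Minimal sketch: the CART family in scikit-learn covers both tasks.
# The toy data below is invented for illustration only.
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

X = [[0.0], [1.0], [2.0], [3.0]]
y_class = [0, 0, 1, 1]          # discrete labels -> classification
y_reg = [0.0, 0.1, 0.9, 1.0]    # continuous targets -> regression

clf = DecisionTreeClassifier(random_state=0).fit(X, y_class)
reg = DecisionTreeRegressor(random_state=0).fit(X, y_reg)

# Query points chosen away from the split thresholds (midpoints between training values).
print(clf.predict([[0.2], [2.8]]))  # class labels
print(reg.predict([[0.2], [2.8]]))  # mean target of the leaf each point lands in
```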
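First, check which parameters DecisionTreeClassifier accepts. In IPython or Jupyter, appending ? to a name shows its signature and docstring (the output below comes from a scikit-learn 0.22-era release; presort and min_impurity_split have since been removed):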
DecisionTreeClassifier?
Init signature:
DecisionTreeClassifier(criterion='gini',
splitter='best',
max_depth=None,
min_samples_split=2,
min_samples_leaf=1,
min_weight_fraction_leaf=0.0,
max_features=None,
random_state=None,
max_leaf_nodes=None,
min_impurity_decrease=0.0,
min_impurity_split=None,
class_weight=None,
presort='deprecated',
ccp_alpha=0.0)
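Outside IPython, a similar overview can be obtained with the standard estimator method get_params(), which returns the constructor arguments and their values as a dict. This is a small sketch; the exact set of keys depends on the installed scikit-learn version.

```python
from sklearn.tree import DecisionTreeClassifier

# get_params() lists every constructor argument with its current (here: default) value.
params = DecisionTreeClassifier().get_params()
for name in sorted(params):
    print(f"{name} = {params[name]!r}")
```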
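Important parameters:
criterion: the function used to measure the quality of a split at each node. The default is "gini" (Gini impurity); information entropy ("entropy") can be used instead.
min_samples_split: the minimum number of samples required to split an internal node (default 2).
min_samples_leaf: the minimum number of samples required in each leaf node (default 1).
min_impurity_decrease: a node is split only if the split decreases the weighted impurity (Gini or entropy, depending on criterion) by at least this value (default 0.0).
ccp_alpha: the complexity parameter for minimal cost-complexity pruning; the subtree with the largest cost complexity smaller than ccp_alpha is chosen. The default 0.0 means no pruning.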
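The impurity measures behind criterion and min_impurity_decrease can be computed by hand. The sketch below is an illustration in NumPy, not scikit-learn's internal code: gini is 1 - sum of p_k squared, entropy is computed here in bits (log base 2), and the decrease follows the weighted formula N_t/N * (I(t) - N_tL/N_t * I(left) - N_tR/N_t * I(right)) given in the scikit-learn documentation for min_impurity_decrease (assuming unweighted samples). The helper names are hypothetical.

```python
import numpy as np

def gini(labels):
    """Gini impurity: 1 - sum_k p_k^2."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(1.0 - np.sum(p ** 2))

def entropy(labels):
    """Shannon entropy in bits: -sum_k p_k * log2(p_k)."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(-np.sum(p * np.log2(p)))

def weighted_impurity_decrease(parent, left, right, n_total, impurity=gini):
    """N_t/N * (I(t) - N_tL/N_t * I(left) - N_tR/N_t * I(right)),
    the quantity compared against min_impurity_decrease (unweighted samples)."""
    n_t = len(parent)
    return n_t / n_total * (
        impurity(parent)
        - len(left) / n_t * impurity(left)
        - len(right) / n_t * impurity(right)
    )

parent = [0, 0, 0, 1, 1, 1]
left, right = [0, 0, 0], [1, 1, 1]  # a perfect split of a balanced node
print(gini(parent), entropy(parent))                                # 0.5 1.0
print(weighted_impurity_decrease(parent, left, right, n_total=6))   # 0.5
```

With min_impurity_decrease=0.0 (the default), any split with a positive decrease is allowed; raising it prunes weak splits before they are made.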
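Python example: classifying handwritten digits with CART. Note that load_digits is scikit-learn's built-in 8×8 digits dataset (1,797 images, 64 features, 10 classes), a small MNIST-like dataset rather than MNIST itself.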
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn import metrics
from sklearn.tree import DecisionTreeClassifier
from sklearn import preprocessing
import matplotlib.pyplot as plt  # only needed for the commented-out visualization below

def main():
    digit = load_digits()
    data = digit.data
    label = digit.target
    # Split the dataset: 70% training, 30% test
    train_x, test_x, train_y, test_y = train_test_split(data, label, test_size=0.3, random_state=0)
    # print(digit.target[0])
    # plt.imshow(digit.images[0], cmap='gray')
    # plt.show()
    # Preprocessing: StandardScaler standardizes each feature as (x - mean) / standard deviation.
    # Fit the scaler on the training set only, then apply the same transform to the test set.
    # (Decision trees do not actually require feature scaling; the step is kept from the original.)
    scaler = preprocessing.StandardScaler()
    train_x = scaler.fit_transform(train_x)
    test_x = scaler.transform(test_x)
    # Train CART
    CART = DecisionTreeClassifier()
    CART.fit(train_x, train_y)
    print(CART)
    predicted = CART.predict(test_x)
    print(metrics.classification_report(test_y, predicted))
    print(metrics.confusion_matrix(test_y, predicted))
    print("CART accuracy: %0.6f" % metrics.accuracy_score(test_y, predicted))

if __name__ == '__main__':
    main()
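The example above uses the default ccp_alpha=0.0, i.e. an unpruned tree. One way to choose a pruning strength is sketched below, under the assumption that scikit-learn 0.22 or later is installed (for cost_complexity_pruning_path): take candidate alphas from the pruning path of a fully grown tree, score a handful of them with 5-fold cross-validation on the training data only, and refit with the best one. The choice of 10 evenly spaced candidates is arbitrary, and no scaling is applied because tree splits do not depend on feature scale.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_digits(return_X_y=True)
train_x, test_x, train_y, test_y = train_test_split(X, y, test_size=0.3, random_state=0)

# Effective alphas of the minimal cost-complexity pruning path of a fully grown tree.
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(train_x, train_y)
ccp_alphas = path.ccp_alphas[:-1]  # the largest alpha prunes the tree down to the root alone

# Score 10 evenly spaced candidates with 5-fold CV on the training set (test set untouched).
candidates = ccp_alphas[np.linspace(0, len(ccp_alphas) - 1, 10).astype(int)]
scores = [
    cross_val_score(DecisionTreeClassifier(random_state=0, ccp_alpha=a),
                    train_x, train_y, cv=5).mean()
    for a in candidates
]
best_alpha = candidates[int(np.argmax(scores))]

# Refit on the full training set with the chosen alpha and evaluate once on the test set.
pruned = DecisionTreeClassifier(random_state=0, ccp_alpha=best_alpha).fit(train_x, train_y)
print("best ccp_alpha:", best_alpha)
print("leaves:", pruned.get_n_leaves(), "test accuracy:", pruned.score(test_x, test_y))
```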
Source: CSDN
Author: Bug永流传
Link: https://blog.csdn.net/qq_33590385/article/details/104042866