- 决策树
决策树是数据挖掘领域中的常用模型,其基本思想是对预测变量进行二元分离,从而构造一颗可用于预测新样本单元所属类别的树
- 经典决策树
针对乳腺癌数据集中的良性/恶性,和一组预测变量对应9个细胞特征为基础
(1) 选定一个最佳预测变量将全部样本单元分为两类,实现两类中的纯度最大化(即一类中良性样本单元尽可能多,另一类中恶性样本尽可能的多)
a、如果预测变量连续,则选定一个分割点进行分类,使得两类纯度最大化;
b、如果预测变量为分类变量,则对各类别进行合并再分类
(2)对每一个子类分别继续执行比步骤(1)
(3)重复步骤(1)~(2),直到子类别中所含的样本单元数过少,或者没有分类法能将不纯度下线到一个给定阈值以下,最终集中的子类别即终端节点(terminal node)。根据每一个终端节点中样本单元数众数来判别这一终端节点的属性类别
(4)对任一样本单元执行决策树,得到其终端节点,即可根据步骤3得到模型预测的所属类别,
不过,上述算法通常会得到一棵过大的树,从而出现过度 拟合现象,导致对于训练集外单元的分类性能较差,可用 10折交叉验证法,这一 剪枝后的树用于预测。
R中的 repart() 函数构造决策树,prune() 函数对决策树进行减枝
- 创建决策树
#使用rpart()函数创建分类决策树
> library(rpart)
> set.seed(1234)
> dtree <- rpart(class ~.,data = df.train,method="class",parms = list(split="information")) #生成树
- 设定最终树的大小
rpart()返回的cptable值中包括不同大小的树对应的预测误差,因此可用于辅助设定最终的树的大小
a、复杂度参数(cp)用于惩罚过大的树
b、树的大小即分支数(nsplit),有 n 个分支的树将有 n+1 个终端节点
c、rel eeror栏中即各种树对应的误差
d、xerror即基于训练样本所得的10折交叉验证误差
e、xstd栏为交叉验证误差的标准差
> dtree$cptable
CP nsplit rel error xerror xstd
1 0.800000 0 1.00000 1.00000 0.06484605
2 0.046875 1 0.20000 0.30625 0.04150018
3 0.012500 3 0.10625 0.20625 0.03467089
4 0.010000 4 0.09375 0.18125 0.03264401
plotcp()函数可画出交叉验证误差与复杂度参数的关系图,对于所有交叉验证误差在最小交叉验证误差一个标准差范围内的树,最小的树即最优的树
本例子中 最小交叉误差(xerror)为0.18125,它标准差(xstd)为0.03264401,所以最优的树在 0.18 +- 0.0326 (0.15和0.21)之间的树
> plotcp(dtree) #绘制交叉验证误差与复杂度参数的关系图,借助关系图可以确定最优的树

图1 复杂度参数
复杂度参数与交叉验证误差,虚线是基于一个标准差准则得到上限(0.18+1*0.0326=0.21),从图像来卡,应选择虚线最左侧 cp 值对应的树
由 cptable 的结果可知,四个终端节点(即三次分割)的树满足要求(交叉验证误差为0.20625),根据图可以选的最优树,即三次分隔(四个节点)对应的树
- prune()剪枝
prune()函数根据复杂度参数剪掉最不重要的枝,从而将树的大小控制在理想范围内,从上述cptale中得到,三次分割对应的复杂度参数为0.0125,所以剪枝得到一个理想大小的树
prune(dtree,cp = 0.0125)
- prp() 画出最终的决策树
prp()z中有很多的参数(详见?prp),
type=2:画出每个节点下分隔的标签
extra = 104:画出每一类的概率以及每个节点处的样本占比
fallen.leaves:可在图的底端显示终端节点
对观测点分类时,从树的顶端开始,若满足条件则从左枝往下,否则从右枝往下,重复这个过程知道碰到一个终端节点为止,该终端节点即为这一观测点的所属类别
> dtree.pruned <- prune(dtree, cp=.0125)
> library(rpart.plot) #导入rpart.polt包中的prp()函数
> prp(dtree.pruned, type = 2, extra = 104,
+ fallen.leaves = TRUE, main="Decision Tree")

图2 用剪枝后的传统决策树预测癌症状态,从树的顶端开始如果条件成立则从左枝往下,否者从右枝往下,当观测点到达终端节点时,分类结束。每一个节点处都有对应类别的概率以及样本单元的占比
- predict()函数来
predict()函数用来验证集中的观测点分类
> dtree.pred <- predict(dtree.pruned, df.validate, type="class")
> dtree.perf <- table(df.validate$class, dtree.pred, #实际类别与预测类别的交叉表
+ dnn=c("Actual", "Predicted"))
> dtree.perf
Predicted
Actual benign malignant
benign 122 7
malignant 2 79
从整体来看,验证集中的准确率达到96%,与逻辑回归不同的是,验证集中的210个样本单元都可由最终树来分类。
值得注意的是,对于水平数很多或者缺失值很多的预测变量决策树可能会有偏
来源:oschina
链接:https://my.oschina.net/u/1785519/blog/1565245