问题
I have an output from ctree() (party package) that looks like the following. How do I get the list of splitting conditions for each terminal node, like sns <= 0, dta <= 1; sns <= 0, dta > 1 and so on?
1) sns <= 0; criterion = 1, statistic = 14655.021
2) dta <= 1; criterion = 1, statistic = 3286.389
3)* weights = 153682
2) dta > 1
4)* weights = 289415
1) sns > 0
5) dta <= 2; criterion = 1, statistic = 1882.439
6)* weights = 245457
5) dta > 2
7) dta <= 6; criterion = 1, statistic = 1170.813
8)* weights = 328582
7) dta > 6
Thanks
回答1:
This function should do the job
# Lists, for every terminal node of a party::ctree() fit, the chain of
# primary-split conditions that leads from the root to that node.
#
# Args:
#   ct:   Fitted tree. It is only accessed through where() and nodes().
#   data: Data frame whose rows line up with the node weight vectors
#         (presumably the data the tree was fitted on -- confirm at call site).
#
# Returns:
#   A data frame with one row per terminal node, in the order in which the
#   nodes first appear in where(ct):
#     Node - terminal node id, stored as text
#     Path - the split conditions joined by ", " (an empty string when the
#            tree consists of the root only)
CtreePathFunc <- function (ct, data) {
  TerminalNodes <- unique(where(ct))
  Paths <- character(length(TerminalNodes))

  for (k in seq_along(TerminalNodes)) {
    Node <- TerminalNodes[k]

    # Candidate inner nodes: every id below the terminal node that is not a
    # terminal node itself. seq_len() keeps this empty for a root-only tree,
    # where 1:(Node - 1) would have produced c(1, 0).
    NonTerminalNodes <- setdiff(seq_len(Node - 1), TerminalNodes)

    # Observation weights of the terminal node
    NodeWeights <- nodes(ct, Node)[[1]]$weights

    # An inner node belongs to the path when at least one observation has a
    # non-zero weight in the terminal node and weight 1 in the inner node.
    Path <- NULL
    for (i in NonTerminalNodes) {
      InnerWeights <- nodes(ct, i)[[1]][[2]]
      if (any(NodeWeights & InnerWeights == 1)) Path <- append(Path, i)
    }

    # Splitting criterion of every inner node on the path.
    # seq_along() skips the loop when Path is empty (1:length(Path) would
    # iterate over c(1, 0)).
    Conditions <- character(length(Path))
    for (i in seq_along(Path)) {
      # Node that follows Path[i] on the way down to the terminal node
      Child <- if (i == length(Path)) Node else Path[i + 1]
      ChildRows <- which(as.logical(nodes(ct, Child)[[1]]$weights))

      # Flattened primary split: position 3 is read as the split point and
      # the last position as the name of the split variable.
      Split <- as.character(unlist(nodes(ct, Path[i])[[1]][[5]]))
      VarName <- Split[length(Split)]
      CutPoint <- Split[3]

      # The direction is inferred from the data: if every observation of the
      # child lies at or below the split point, the child is the "<=" branch.
      SB <- if (all(data[ChildRows, VarName] <= as.numeric(CutPoint))) "<=" else ">"
      Conditions[i] <- paste(VarName, SB, CutPoint)
    }

    Paths[k] <- paste(Conditions, collapse = ", ")
  }

  # Built once instead of rbind()-ing inside the loop.
  data.frame(Node = as.character(TerminalNodes), Path = Paths)
}
Testing
library(party)

# Keep only the rows with an observed Ozone value
airq <- airquality[!is.na(airquality$Ozone), ]

# Fit the tree, then list the split path of every terminal node
ct <- ctree(
  Ozone ~ .,
  data = airq,
  controls = ctree_control(maxsurrogate = 3)
)
Result <- CtreePathFunc(ct, airq)
Result
## Node Path
## 1 5 Temp <= 82, Wind > 6.9, Temp <= 77
## 2 3 Temp <= 82, Wind <= 6.9
## 3 6 Temp <= 82, Wind > 6.9, Temp > 77
## 4 9 Temp > 82, Wind > 10.3
## 5 8 Temp > 82, Wind <= 10.3
回答2:
If you use the new recommended partykit implementation of ctree() rather than the old party package, then you can use the function .list.rules.party(). This is not yet officially exported, but it can be leveraged to extract the desired information.
library("partykit")

# Keep only the rows with an observed Ozone value
airq <- airquality[!is.na(airquality$Ozone), ]
ct <- ctree(Ozone ~ ., data = airq)

# Unexported helper, hence the triple colon
partykit:::.list.rules.party(ct)
## 3 5
## "Temp <= 82 & Wind <= 6.9" "Temp <= 82 & Wind > 6.9 & Temp <= 77"
## 6 8
## "Temp <= 82 & Wind > 6.9 & Temp > 77" "Temp > 82 & Wind <= 10.3"
## 9
## "Temp > 82 & Wind > 10.3"
回答3:
Because I needed this function for categorical data, I wrote the following functions, which also more or less answer @JoãoDaniel's question (I have only tested them with categorical predictor variables):
# Strips whitespace from both ends of each string in `x`.
# Approach taken from:
# http://stackoverflow.com/questions/2261079/how-to-trim-leading-and-trailing-whitespace-in-r
trim <- function (x) {
  gsub(pattern = "^\\s+|\\s+$", replacement = "", x = x)
}
# Returns the text that precedes the first whitespace character of each
# string in `x`; strings without whitespace are returned unchanged.
getVariable <- function (x) {
  sub(pattern = "(.*?)[[:space:]].*", replacement = "\\1", x = x)
}
# Returns the second whitespace-delimited token of each string in `x`
# (the comparison symbol of a rule such as "Temp <= { 82 }"); strings that
# do not contain two whitespace characters are returned unchanged.
getSimbolo <- function (x) {
  sub(pattern = "(.*?)[[:space:]](.*?)[[:space:]].*", replacement = "\\2", x = x)
}
# Collapses a ";"-separated rule string so that every variable/symbol
# combination (e.g. "Temp<=") appears only once: the last rule given for a
# combination replaces the earlier ones, and the combinations keep the order
# in which they first appear.
#
# Args:
#   elemento: A single rule string such as
#             "Temp <= { 82 }; Wind > { 6.9 }; Temp <= { 77 }".
#
# Returns:
#   The reduced rules joined by "; ". Blank rules are dropped, and an empty
#   input yields "" (the original version stopped with an error there).
#
# Relies on trim(), getVariable() and getSimbolo() from this file.
getReglaFinal <- function(elemento) {
  reglas <- unlist(strsplit(as.character(elemento), ";"), use.names = FALSE)
  reglas <- trim(reglas)
  reglas <- reglas[nzchar(reglas)]
  if (length(reglas) == 0) {
    return("")
  }

  # Key that identifies "same variable, same comparison symbol"
  tipo_corte <- paste0(getVariable(reglas), getSimbolo(reglas))

  # For every key, in order of first appearance, the index of its last rule
  ultimo <- vapply(
    unique(tipo_corte),
    function(corte) max(which(tipo_corte == corte)),
    integer(1),
    USE.NAMES = FALSE
  )

  paste(reglas[ultimo], collapse = "; ")
}#getReglaFinal
# Describes every terminal node of a party::ctree() fit by its split path,
# its prediction and its size, and renders the paths as a SQL CASE expression.
# Handles nominal, ordered-factor and numeric primary splits.
#
# Args:
#   ct: Fitted tree. It is accessed through where(), nodes(), weights() and
#       Predict().
#
# Returns:
#   A list with three elements:
#     PreTable - data frame with Node, Path (full path), Means, Counts and
#                TruePath (path reduced by getReglaFinal()), ordered by Means
#     Table    - the same table with Path replaced by the reduced path
#     SQL      - one string of the form " CASE WHEN ... THEN 'Nodo <id>' ... END "
#
# Relies on getReglaFinal() from this file.
CtreePathFuncAllCat <- function (ct) {
  TerminalNodes <- unique(where(ct))
  Paths <- character(length(TerminalNodes))

  for (k in seq_along(TerminalNodes)) {
    Node <- TerminalNodes[k]

    # Candidate inner nodes: every id below the terminal node that is not a
    # terminal node itself (seq_len() stays empty for a root-only tree).
    NonTerminalNodes <- setdiff(seq_len(Node - 1), TerminalNodes)

    # Observation weights of the terminal node
    NodeWeights <- nodes(ct, Node)[[1]]$weights

    # An inner node belongs to the path when at least one observation has a
    # non-zero weight in the terminal node and weight 1 in the inner node.
    Path <- NULL
    for (i in NonTerminalNodes) {
      if (any(NodeWeights & nodes(ct, i)[[1]][[2]] == 1)) Path <- append(Path, i)
    }

    # Splitting criterion of every inner node on the path
    Condiciones <- character(length(Path))
    variablesNombres <- character(0)
    variablesPuntos <- list()
    for (i in seq_along(Path)) {
      n <- nodes(ct, Path[i])[[1]]
      nextNodeID <- if (i == length(Path)) Node else Path[i + 1]
      haciaDerecha <- nextNodeID == n$right$nodeID

      vec_puntos <- as.vector(n[[5]]$splitpoint)
      vec_nombre <- n[[5]]$variableName
      vec_niveles <- attr(n[[5]]$splitpoint, "levels")

      # Kind of split, decided from the shape of the split point:
      #   no levels               -> numeric
      #   one entry per level     -> nominal (logical membership vector)
      #   levels, but other length -> ordered factor (index of the cut level)
      # An explicit flag replaces the former `index == 0` sentinel, which
      # misread a numeric split at exactly 0 as a nominal split.
      esNumerica <- length(vec_niveles) == 0
      esNominal <- !esNumerica && length(vec_puntos) == length(vec_niveles)

      if (esNominal) {
        # Levels routed to the branch that is actually taken
        vec_puntos <- if (haciaDerecha) !vec_puntos else !!vec_puntos
        # Intersect with earlier splits on the same variable
        for (j in seq_len(i - 1)) {
          if (variablesNombres[j] == vec_nombre) {
            vec_puntos <- vec_puntos * variablesPuntos[[j]]
          }
        }
        if (i != 1) {
          vec_puntos <- vec_puntos == 1
        }
        SB <- "="
      } else {
        if (esNumerica) {
          vec_puntos <- n[[5]]$splitpoint
        } else {
          # Ordered factor: mark the level at which the split cuts
          index <- vec_puntos
          vec_puntos <- vector(length = length(vec_niveles))
          vec_puntos[index] <- TRUE
        }
        SB <- if (haciaDerecha) ">" else "<="
      }

      variablesPuntos[[i]] <- vec_puntos
      variablesNombres[i] <- vec_nombre

      descripcion <- if (esNumerica) {
        vec_puntos
      } else {
        paste(vec_niveles[vec_puntos], collapse = ", ")
      }
      Condiciones[i] <- paste(c(vec_nombre, SB, "{", descripcion, "}"), collapse = " ")
    }

    Paths[k] <- paste(Condiciones, collapse = "; ")
  }

  # Built once instead of rbind()-ing inside the loop.
  ResulTable <- data.frame(Node = as.character(TerminalNodes), Path = Paths)

  # Size and prediction per terminal node, both taken from the first
  # observation that falls into the node so that they line up with
  # TerminalNodes. (Using unique() on the predictions dropped rows whenever
  # two nodes shared a prediction, e.g. in classification trees.)
  primero <- !duplicated(where(ct))
  tamanos <- vapply(weights(ct), function(w) sum(w), numeric(1))
  Counts <- as.matrix(unname(tamanos[primero]))
  c2 <- drop(Predict(ct))
  Means <- if (is.null(dim(c2))) {
    as.matrix(unname(c2[primero]))
  } else {
    # NOTE(review): predictions with more than one column keep the previous
    # unique()-based handling -- confirm it is adequate for such responses.
    as.matrix(unique(c2))
  }

  ResulTable <- data.frame(ResulTable, Means, Counts)
  ResulTable <- ResulTable[order(ResulTable$Means), ]
  ResulTable$TruePath <- vapply(
    as.character(ResulTable$Path),
    getReglaFinal,
    character(1),
    USE.NAMES = FALSE
  )

  # Turn "var SB { a, b }; ..." into "WHEN var SB ('a','b') AND ... THEN 'Nodo <id>'".
  # NOTE(review): values are pasted into the SQL text without escaping, so a
  # level that contains a quote character yields broken SQL.
  sql <- gsub("\\;", " AND ", ResulTable$TruePath)
  sql <- gsub("\\{ ", "('", sql)
  sql <- gsub(" \\}", "')", sql)
  sql <- gsub("\\, ", "','", sql)
  # Numbers are emitted without quotes
  sql <- gsub("\\'([-+]?([0-9]*\\.[0-9]+|[0-9]+))\\'", "\\1", sql)
  sql <- paste("WHEN ", sql, " THEN ")
  sql <- paste0(sql, "'Nodo ", as.character(ResulTable$Node))
  sql <- gsub("THEN'", "THEN '", gsub(" '", "'", paste(sql, "'")))

  Tabla <- ResulTable
  Tabla$Path <- Tabla$TruePath
  Tabla$TruePath <- NULL

  list(
    PreTable = ResulTable,
    Table = Tabla,
    SQL = paste(" CASE ", paste(sql, sep = "", collapse = " "), " END ", collapse = "")
  )
}#CtreePathFuncAllCat
Here is a test:
library(party)

# Example 1: ordered-factor predictors
TreeModel1 <- ctree(
  PB ~ ME + SYMPT + HIST + BSE + DECT,
  data = mammoexp
)
Result2 <- CtreePathFuncAllCat(TreeModel1)
Result2
##$PreTable
## Node Path Means Counts
##3 7 DECT > { Somewhat likely }; SYMPT > { Disagree } 6.526316 114
##2 6 DECT > { Somewhat likely }; SYMPT <= { Disagree } 7.640000 175
##1 4 DECT <= { Somewhat likely }; DECT > { Not likely } 8.161905 105
##4 3 DECT <= { Somewhat likely }; DECT <= { Not likely } 9.833333 18
## TruePath
##3 DECT > { Somewhat likely }; SYMPT > { Disagree }
##2 DECT > { Somewhat likely }; SYMPT <= { Disagree }
##1 DECT <= { Somewhat likely }; DECT > { Not likely }
##4 DECT <= { Not likely }
##
##$Table
## Node Path Means Counts
##3 7 DECT > { Somewhat likely }; SYMPT > { Disagree } 6.526316 114
##2 6 DECT > { Somewhat likely }; SYMPT <= { Disagree } 7.640000 175
##1 4 DECT <= { Somewhat likely }; DECT > { Not likely } 8.161905 105
##4 3 DECT <= { Not likely } 9.833333 18
##
##$SQL
##[1] " CASE WHEN DECT > ('Somewhat likely') AND SYMPT > ('Disagree') THEN 'Nodo 7' WHEN DECT > ('Somewhat likely') AND SYMPT <= ('Disagree') THEN 'Nodo 6' WHEN DECT <= ('Somewhat likely') AND DECT > ('Not likely') THEN 'Nodo 4' WHEN DECT <= ('Not likely') THEN 'Nodo 3' END "
# Example 2: an unordered-factor predictor
TreeModel2 <- ctree(count ~ spray, data = InsectSprays)
plot(TreeModel2, type = "simple")
Result2 <- CtreePathFuncAllCat(TreeModel2)
Result2
##$PreTable
##Node Path Means Counts TruePath
##2 5 spray = { C, D, E }; spray = { C, E } 2.791667 24 spray = { C, E }
##3 4 spray = { C, D, E }; spray = { D } 4.916667 12 spray = { D }
##1 2 spray = { A, B, F } 15.500000 36 spray = { A, B, F }
##
##$Table
##Node Path Means Counts
##2 5 spray = { C, E } 2.791667 24
##3 4 spray = { D } 4.916667 12
##1 2 spray = { A, B, F } 15.500000 36
##
##$SQL
##[1] " CASE WHEN spray = ('C','E') THEN 'Nodo 5' WHEN spray = ('D') THEN 'Nodo 4' WHEN spray = ('A','B','F') THEN 'Nodo 2' END "
# Example 3: continuous predictors
airq <- airquality[!is.na(airquality$Ozone), ]
TreeModel3 <- ctree(
  Ozone ~ .,
  data = airq,
  controls = ctree_control(maxsurrogate = 3)
)
Result2 <- CtreePathFuncAllCat(TreeModel3)
Result2
##$PreTable
## Node Path Means Counts
##1 5 Temp <= { 82 }; Wind > { 6.9 }; Temp <= { 77 } 18.47917 48
##3 6 Temp <= { 82 }; Wind > { 6.9 }; Temp > { 77 } 31.14286 21
##4 9 Temp > { 82 }; Wind > { 10.3 } 48.71429 7
##2 3 Temp <= { 82 }; Wind <= { 6.9 } 55.60000 10
##5 8 Temp > { 82 }; Wind <= { 10.3 } 81.63333 30
## TruePath
##1 Temp <= { 77 }; Wind > { 6.9 }
##3 Temp <= { 82 }; Wind > { 6.9 }; Temp > { 77 }
##4 Temp > { 82 }; Wind > { 10.3 }
##2 Temp <= { 82 }; Wind <= { 6.9 }
##5 Temp > { 82 }; Wind <= { 10.3 }
##
##$Table
## Node Path Means Counts
##1 5 Temp <= { 77 }; Wind > { 6.9 } 18.47917 48
##3 6 Temp <= { 82 }; Wind > { 6.9 }; Temp > { 77 } 31.14286 21
##4 9 Temp > { 82 }; Wind > { 10.3 } 48.71429 7
##2 3 Temp <= { 82 }; Wind <= { 6.9 } 55.60000 10
##5 8 Temp > { 82 }; Wind <= { 10.3 } 81.63333 30
##
##$SQL
##[1] " CASE WHEN Temp <= (77) AND Wind > (6.9) THEN 'Nodo 5' WHEN Temp <= (82) AND Wind > (6.9) AND Temp > (77) THEN 'Nodo 6' WHEN Temp > (82) AND Wind > (10.3) THEN 'Nodo 9' WHEN Temp <= (82) AND Wind <= (6.9) THEN 'Nodo 3' WHEN Temp > (82) AND Wind <= (10.3) THEN 'Nodo 8' END "
Update! Now the function supports a mix of categorical and numerical variables!
回答4:
Here is the CtreePathFunc function rewritten in a more Hadley-verse (and, I think, more comprehensible) style. It also handles categorical variables.
library(magrittr)
# Reads the split value out of a split object.
# When the split point carries a "levels" attribute, the level(s) selected by
# the split point are returned; otherwise the split point is returned as a
# number.
readSplitter <- function(nodeSplit) {
  splitPoint <- nodeSplit$splitpoint
  attributeNames <- names(attributes(splitPoint))
  if ("levels" %in% attributeNames) {
    splitLevels <- attr(splitPoint, "levels")
    splitLevels[splitPoint]
  } else {
    as.numeric(splitPoint)
  }
}
# Returns the positions of the observations with a non-zero weight in the
# node that follows path[pathNumber]: the terminal node itself when
# pathNumber is the last step of the path, the next inner node otherwise.
hasWeigths <- function(ct, path, terminalNode, pathNumber) {
  isLastStep <- pathNumber == length(path)
  followingNode <- ifelse(isLastStep, terminalNode, path[pathNumber + 1])
  followingWeights <- nodes(ct, followingNode)[[1]]$weights
  which(as.logical(followingWeights))
}
# Builds the textual split condition for step `pathNumber` of `path`.
# The fifth element of the inner node is handed to buildDataFilter() together
# with the data and the rows that continue down the path.
dataFilter <- function(ct, dts, path, terminalNode, pathNumber) {
  rowsOnPath <- hasWeigths(ct, path, terminalNode, pathNumber)
  innerNode <- nodes(ct, path[pathNumber])[[1]]
  buildDataFilter(innerNode[[5]], dts, rowsOnPath)
}
# S3 generic: the filter builder is chosen from the class of `nodeSplit`
# (methods for "nominalSplit" and "orderedSplit" are defined in this file).
buildDataFilter <- function(nodeSplit, ...) {
  UseMethod("buildDataFilter")
}
# Condition text for a nominal split: "<variable> == {<levels>}", where the
# levels are the distinct values the split variable takes in the rows
# `whichWeights` of `dts`.
buildDataFilter.nominalSplit <- function(nodeSplit, dts, whichWeights) {
  varName <- nodeSplit$variableName
  includedLevels <- unique(dts[whichWeights, varName])
  levelSet <- paste0("{", paste(includedLevels, collapse = ", "), "}")
  paste(varName, "==", levelSet)
}
# Condition text for an ordered split: "<variable> <= <split>" when every
# value of the split variable in the rows `whichWeights` of `dts` is at or
# below the split value, "<variable> > <split>" otherwise.
buildDataFilter.orderedSplit <- function(nodeSplit, dts, whichWeights) {
  varName <- nodeSplit$variableName
  splitter <- readSplitter(nodeSplit)
  observed <- dts[whichWeights, varName]
  allAtOrBelow <- all(observed <= splitter)
  comparison <- ifelse(allAtOrBelow, "<=", ">")
  paste(varName, comparison, splitter)
}
# Lists, for every terminal node of a party::ctree() fit, the split
# conditions that lead to it.
#
# Args:
#   ct:  Fitted tree. It is accessed through where() and nodes().
#   dts: Data frame passed on to dataFilter() (presumably the data the tree
#        was fitted on -- confirm at call site).
#
# Returns:
#   A data frame with one row per terminal node, in the order in which the
#   nodes first appear in where(ct): Node (node id) and Path (the conditions
#   joined by " & "; an empty string when the tree has no split).
#
# Relies on dataFilter() from this file.
readTerminalNodePaths <- function (ct, dts) {
  sgmnts <- unique(where(ct))

  nodeWeights <- function(node) nodes(ct, node)[[1]]$weights
  nodesFirstTreeWeightIsOne <- function(node) nodes(ct, node)[[1]][[2]] == 1

  # Inner nodes with an id below the selected terminal node. seq_len() keeps
  # the candidate set empty for node 1, where 1:(node - 1) gave c(1, 0).
  innerNodes <- function(node) {
    setdiff(seq_len(node - 1), sgmnts[sgmnts < node])
  }

  # Inner nodes that share a weighted observation with the terminal node
  pathForTerminalNode <- function(terminalNode) {
    terminalWeights <- nodeWeights(terminalNode)
    Filter(
      function(innerNode) {
        any(terminalWeights & nodesFirstTreeWeightIsOne(innerNode))
      },
      innerNodes(terminalNode)
    )
  }

  # One row per terminal node. seq_along() yields no steps for an empty path,
  # where seq(length(path)) produced c(1, 0).
  pathRows <- lapply(sgmnts, function(terminalNode) {
    path <- pathForTerminalNode(terminalNode)
    conditions <- unlist(lapply(seq_along(path), function(pathNumber) {
      dataFilter(ct, dts, path, terminalNode, pathNumber)
    }))
    data.frame(Node = terminalNode, Path = paste(conditions, collapse = " & "))
  })

  do.call(rbind, pathRows)
}
Testing
# Divides the leading part of a vector by a constant.
#
# Args:
#   vctr:       Numeric vector.
#   divideBy:   Divisor applied to the leading elements.
#   proportion: Share of the vector that is rescaled; the number of affected
#               elements is round(length(vctr) * proportion).
#
# Returns:
#   `vctr` with its first elements divided by `divideBy`. When the rounded
#   count is zero the vector is returned unchanged (seq(0) used to select
#   element 1 in that case).
shiftFirstPart <- function(vctr, divideBy, proportion = .5) {
  headSize <- round(length(vctr) * proportion)
  headIndex <- seq_len(headSize)
  vctr[headIndex] <- vctr[headIndex] / divideBy
  vctr
}
# Simulated data set: a two-level response, an ordered factor, an unordered
# factor and a numeric predictor. The first half of every uniform draw is
# scaled down by shiftFirstPart().
set.seed(11)
n <- 13000
gdt <- data.frame(
  is_buyer = runif(n) %>%
    shiftFirstPart(1.5) %>%
    round %>%
    factor(labels = c("no", "yes")),
  # cut()'s argument is spelled include.lowest; the former `include_lowest`
  # was silently absorbed by `...` and had no effect.
  age = runif(n) %>%
    shiftFirstPart(1.5) %>%
    cut(
      breaks = c(0, .3, .6, 1),
      include.lowest = TRUE,
      ordered_result = TRUE,
      labels = c("low", "mid", "high")
    ),
  city = runif(n) %>%
    shiftFirstPart(1.5) %>%
    cut(
      breaks = c(0, .3, .6, 1),
      include.lowest = TRUE,
      labels = c("Chigaco", "Boston", "Memphis")
    ),
  point = runif(n) %>% shiftFirstPart(1.2)
)

# Fit the tree and list the conditions of every terminal node
gct <- ctree(is_buyer ~ ., data = gdt)
readTerminalNodePaths(gct, gdt)
来源:https://stackoverflow.com/questions/21443203/ctree-how-to-get-the-list-of-splitting-conditions-for-each-terminal-node