I'm trying to use the `cforest` function (R, `party` package).
This is what I do to construct the forest:
# Load the party package (provides cforest() and cforest_unbiased()) and fix
# the RNG seed so the forest below is reproducible.
library("party")
set.seed(42)
readingSkills.cf <-
The solution proposed by @rcs in the update is interesting, but it does not work with `cforest`
when the dependent variable is numeric. The code:
# Reproduction of @rcs's hack on a forest whose response (score) is numeric.
set.seed(12345)
# Grow a 50-tree cforest on readingSkills, trying 2 variables per split.
y <- cforest(score ~ ., data = readingSkills,
control = cforest_unbiased(mtry = 2, ntree = 50))
# Convert the first tree of the ensemble into prettytree()'s named-list form,
# labelling variables with the names of the forest's input data.
tr <- party:::prettytree(y@ensemble[[1]], names(y@data@get("input")))
# NOTE(review): this is @rcs's one-argument update_tree(). The
# update_tree(x, dt) defined further down in this file requires `dt`; with
# that definition in scope this call would need to be
# update_tree(tr, y@data@get("input")).
tr_weights <- update_tree(tr)
# Wrap the tree in a BinaryTree and plot it. The grid error quoted after this
# snippet presumably comes from this plotting step.
plot(new("BinaryTree", tree=tr_weights, data=y@data, responses=y@responses))
generates the following error message:
R> Error in valid.data(rep(units, length.out = length(x)), data) :
no string supplied for 'strwidth/height' unit
and the following plot:
Below I suggest an improved version of the hack proposed by @rcs:
# Extract tree number `k` of a cforest() forest as a plottable BinaryTree.
#
# @param cf A forest returned by party::cforest().
# @param k  Index of the tree within cf@ensemble (default: the first tree).
# @return A "BinaryTree" object built from that tree, with node weights
#   rebuilt from the forest's input data by update_tree(), and carrying the
#   forest's data, responses, cond_distr_response and predict_response.
get_cTree <- function(cf, k = 1) {
  inputs <- cf@data@get("input")
  raw_tree <- cf@ensemble[[k]]
  named_tree <- party:::prettytree(raw_tree, names(inputs))
  weighted_tree <- update_tree(named_tree, inputs)
  new(
    "BinaryTree",
    tree = weighted_tree,
    data = cf@data,
    responses = cf@responses,
    cond_distr_response = cf@cond_distr_response,
    predict_response = cf@predict_response
  )
}
# Walk a prettytree()-style tree top-down, letting update_weights() set the
# children's weights at every node before descending into them.
#
# @param x  Root of the (sub)tree: a list with `terminal`, `left` and `right`
#   elements (plus whatever update_weights() reads).
# @param dt Input data, passed unchanged to update_weights().
# @return The (sub)tree with update_weights() applied to every node.
update_tree <- function(x, dt) {
  node <- update_weights(x, dt)
  if (node$terminal) {
    return(node)
  }
  node$left <- update_tree(node$left, dt)
  node$right <- update_tree(node$right, dt)
  node
}
# Recompute the case weights of node `x`'s two children from x's primary
# split, so a tree converted with party:::prettytree() can be plotted.
#
# @param x  A tree node: a list with elements `psplit` (the primary split,
#   carrying a class attribute, `variableName` and `splitpoint`), `weights`,
#   `left` and `right`.
# @param dt Data frame of input variables (e.g. cf@data@get("input")); the
#   weights produced have one entry per row of `dt`.
# @return `x` with x$left$weights and x$right$weights set (and, for nominal
#   splits, the "levels" attribute of x$psplit$splitpoint filled in). A node
#   whose psplit has no class attribute (e.g. NULL) is returned unchanged.
update_weights <- function(x, dt) {
  splt <- x$psplit
  # No classed primary split (terminal node): nothing to propagate.
  if (is.null(attr(splt, "class"))) {
    return(x)
  }
  splt_var <- dt[[splt$variableName]]
  if (inherits(splt, "nominalSplit")) {
    splt_lev <- levels(splt_var)
    # Store the level labels on the split point so it prints with names.
    attr(x$psplit$splitpoint, "levels") <- splt_lev
    # splitpoint flags, per level, whether that level goes to the left child.
    filt <- splt_var %in% splt_lev[as.logical(splt$splitpoint)]
  } else {
    # Non-nominal split: observations with value <= splitpoint go left.
    # NOTE(review): missing values in the split variable yield NA weights.
    filt <- splt_var <= splt$splitpoint
  }
  # A child can only contain observations that reached this node, so its
  # weights are this node's weights masked by the split. Masking with the
  # split alone would let every node below the root count all rows that
  # satisfy its last split, even rows that earlier splits sent elsewhere.
  # If this node carries no per-row weights (the root is never a child, so
  # its weights are not set here), every row is treated as present.
  # NOTE(review): if the root already holds per-row weights (e.g. bootstrap
  # case weights) they are propagated downward; confirm that is desired.
  parent_w <- x$weights
  if (length(parent_w) != length(filt)) {
    parent_w <- rep(1, length(filt))
  }
  x$left$weights <- parent_w * as.numeric(filt)
  x$right$weights <- parent_w * as.numeric(!filt)
  x
}
# Plot the first tree of the forest `y` grown above.
plot(get_cTree(cf = y, k = 1))