cforest prints empty tree

刺人心 2021-01-03 04:10

I'm trying to use the cforest function (R, party package).

This is what I do to construct the forest:

library("party")
set.seed(42)
readingSkills.cf <-          


        
2 Answers
  •  南方客 (original poster)
    2021-01-03 04:27

    The solution proposed by @rcs in the Update is interesting, but it does not work with cforest when the dependent variable is numeric. The code:

    set.seed(12345)
    y <- cforest(score ~ ., data = readingSkills,
           control = cforest_unbiased(mtry = 2, ntree = 50))
    tr <- party:::prettytree(y@ensemble[[1]], names(y@data@get("input")))
    tr_weights <- update_tree(tr)
    plot(new("BinaryTree", tree=tr_weights, data=y@data, responses=y@responses))
    

    generates the following error message:

    R> Error in valid.data(rep(units, length.out = length(x)), data) :
       no string supplied for 'strwidth/height' unit
    

    and the following plot: [plot image not preserved]

    Below I suggest an improved version of the hack proposed by @rcs:

    # Extract the k-th tree of a cforest as a plottable BinaryTree object.
    get_cTree <- function(cf, k = 1) {
      dt <- cf@data@get("input")
      tr <- party:::prettytree(cf@ensemble[[k]], names(dt))
      tr_updated <- update_tree(tr, dt)
      new("BinaryTree", tree = tr_updated, data = cf@data, responses = cf@responses,
          cond_distr_response = cf@cond_distr_response,
          predict_response = cf@predict_response)
    }

    # Recursively recompute the child weights for every node of the tree.
    update_tree <- function(x, dt) {
      x <- update_weights(x, dt)
      if (!x$terminal) {
        x$left <- update_tree(x$left, dt)
        x$right <- update_tree(x$right, dt)
      }
      x
    }

    # Set the weights of the two children of node x from its primary split.
    update_weights <- function(x, dt) {
      splt <- x$psplit
      spltClass <- attr(splt, "class")
      spltVarName <- splt$variableName
      spltVar <- dt[, spltVarName]
      spltVarLev <- levels(spltVar)
      if (!is.null(spltClass)) {
        if (spltClass == "nominalSplit") {
          # Factor split: the levels flagged in splitpoint go to the left child.
          attr(x$psplit$splitpoint, "levels") <- spltVarLev
          filt <- spltVar %in% spltVarLev[as.logical(x$psplit$splitpoint)]
        } else {
          # Ordered split: values up to the split point go to the left child.
          filt <- (spltVar <= splt$splitpoint)
        }
        x$left$weights <- as.numeric(filt)
        x$right$weights <- as.numeric(!filt)
      }
      x
    }
    
    plot(get_cTree(y, 1))
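
    To make the weight computation in update_weights easier to follow, here is a minimal base-R sketch of the same two filters on toy data. It does not call the party package. The data frame, its column names, the split value 8, and the 0/1 split vector are all made up for illustration, and the reading of the nominal split point as one left-child flag per factor level is an assumption taken from how the code above indexes the levels.

```r
# Toy data standing in for cf@data@get("input"); column names are hypothetical.
dt <- data.frame(
  age   = c(5, 7, 9, 11),
  group = factor(c("a", "b", "c", "a"))
)

# Ordered (numeric) split: observations with age <= 8 go to the left child.
# The split value 8 is an arbitrary choice for this example.
filt_num    <- dt$age <= 8
left_w_num  <- as.numeric(filt_num)
right_w_num <- as.numeric(!filt_num)

# Nominal split: assumed to be a 0/1 vector with one entry per factor level,
# where levels flagged 1 go to the left child (here "a" and "c").
splitpoint  <- c(1, 0, 1)
lev         <- levels(dt$group)
filt_nom    <- dt$group %in% lev[as.logical(splitpoint)]
left_w_nom  <- as.numeric(filt_nom)
right_w_nom <- as.numeric(!filt_nom)

# One column per observation; a 1 marks the child the observation is sent to.
print(rbind(left = left_w_num, right = right_w_num))
print(rbind(left = left_w_nom, right = right_w_nom))
```

    Note that each filter is evaluated on all rows of the data, exactly as in update_weights, so the weights describe the split rule itself and not only the observations that reached the node.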
    
