cforest prints empty tree

刺人心 2021-01-03 04:10

I'm trying to use the cforest function (R, party package).

This is what I do to construct the forest:

library("party")
set.seed(42)
readingSkills.cf <-          


        
2 Answers
  • 2021-01-03 04:27

    The solution proposed by @rcs in the Update is interesting but does not work with cforest when the dependent variable is numerical. The code:

    set.seed(12345)
    y <- cforest(score ~ ., data = readingSkills,
           control = cforest_unbiased(mtry = 2, ntree = 50))
    tr <- party:::prettytree(y@ensemble[[1]], names(y@data@get("input")))
    tr_weights <- update_tree(tr)
    plot(new("BinaryTree", tree=tr_weights, data=y@data, responses=y@responses))
    

    generates the following error message:

    R> Error in valid.data(rep(units, length.out = length(x)), data) :
       no string supplied for 'strwidth/height' unit
    

    and the following plot: [figure not preserved]

    Below I suggest an improved version of the hack proposed by @rcs:

    get_cTree <- function(cf, k=1) {
      dt <- cf@data@get("input")
      tr <- party:::prettytree(cf@ensemble[[k]], names(dt))
      tr_updated <- update_tree(tr, dt)
      new("BinaryTree", tree=tr_updated, data=cf@data, responses=cf@responses, 
          cond_distr_response=cf@cond_distr_response, predict_response=cf@predict_response)
    }
    
    update_tree <- function(x, dt) {
      x <- update_weights(x, dt)
      if(!x$terminal) {
        x$left <- update_tree(x$left, dt)
        x$right <- update_tree(x$right, dt)   
      } 
      x
    }
    
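    update_weights <- function(x, dt) {
      splt <- x$psplit
      spltClass <- attr(splt, "class")
      spltVarName <- splt$variableName
      spltVar <- dt[, spltVarName]
      spltVarLev <- levels(spltVar)
      # if the split has no class attribute, x is returned unchanged
      if (!is.null(spltClass)) {
        if (spltClass == "nominalSplit") {
          # nominal split: splitpoint flags the factor levels that go left
          attr(x$psplit$splitpoint, "levels") <- spltVarLev
          filt <- spltVar %in% spltVarLev[as.logical(x$psplit$splitpoint)]
        } else {
          # ordered split: values <= splitpoint go left
          filt <- (spltVar <= splt$splitpoint)
        }
        # note: filt is evaluated on all rows of dt, not only those reaching x
        x$left$weights <- as.numeric(filt)
        x$right$weights <- as.numeric(!filt)
      }
      x
    }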
    
    plot(get_cTree(y, 1))
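
    One thing to keep in mind with the code above: update_weights() evaluates each split on every row of dt, so nodes below the root can count rows that never reach them. A possible variant passes the parent's 0/1 indicator down the tree. The sketch below shows the idea on a hypothetical toy tree (nested lists with the same psplit fields, not an object produced by party); propagate_weights and the two splits are made up for illustration, and only ordered splits are handled.

```r
# Hypothetical toy tree: nested lists using the psplit fields that
# update_weights() reads (variableName, splitpoint). Not produced by party.
toy <- list(
  terminal = FALSE,
  psplit = list(variableName = "Petal.Length", splitpoint = 1.9),
  left = list(terminal = TRUE),
  right = list(
    terminal = FALSE,
    psplit = list(variableName = "Petal.Width", splitpoint = 1.7),
    left = list(terminal = TRUE),
    right = list(terminal = TRUE)
  )
)

# Recompute 0/1 weights top-down; `w` marks the rows that reach node x.
# Only ordered (numeric) splits are handled in this sketch.
propagate_weights <- function(x, dt, w = rep(1, nrow(dt))) {
  x$weights <- w
  if (!x$terminal) {
    goes_left <- dt[[x$psplit$variableName]] <= x$psplit$splitpoint
    x$left  <- propagate_weights(x$left,  dt, w * goes_left)
    x$right <- propagate_weights(x$right, dt, w * (!goes_left))
  }
  x
}

toy_w <- propagate_weights(toy, iris)
sum(toy_w$left$weights)          # rows with Petal.Length <= 1.9
sum(toy_w$right$weights)         # all remaining rows
sum(toy_w$right$left$weights)    # remaining rows with Petal.Width <= 1.7
```

    With this scheme the weights of the two children always add up to the weights of their parent, which is not guaranteed when each split is applied to the full data set.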
    

  • 2021-01-03 04:32

    Short answer: the case weights (weights) in each node are NULL, i.e. they are not stored. The prettytree function prints weights = 0 because sum(NULL) equals 0 in R.
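
    A quick way to convince yourself of the sum(NULL) behaviour in plain R (w is just a stand-in variable here, not a party object):

```r
w <- NULL           # stand-in for a node whose weights were not stored
s <- sum(w)         # sum() over nothing is zero, no error or NA
print(s)
print(is.null(w))   # is.null() is what distinguishes "not stored" from real zeros
print(length(w))
```

    So a printed "weights = 0" by itself does not tell you whether a node is empty or whether the weights are simply missing.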


    Consider the following ctree example:

    library("party")
    x <- ctree(Species ~ ., data=iris)
    plot(x, type="simple")
    

    [Figure: ctree plot]

    For the resulting object x (class BinaryTree) the case weights are stored in each node:

    R> sum(x@tree$left$weights)
    [1] 50
    R> sum(x@tree$right$weights)
    [1] 100
    R> sum(x@tree$right$left$weights)
    [1] 54
    R> sum(x@tree$right$right$weights)
    [1] 46
    

    Now let's take a closer look at cforest:

    y <- cforest(Species ~ ., data=iris, control=cforest_control(mtry=2))
    tr <- party:::prettytree(y@ensemble[[1]], names(y@data@get("input")))
    plot(new("BinaryTree", tree=tr, data=y@data, responses=y@responses))
    

    [Figure: cforest tree]

    The case weights are not stored in the tree ensemble, which can be seen by the following:

    fixInNamespace("print.TerminalNode", "party")
    

    Change the print method to:

    function (x, n = 1, ...)
    {
        print(names(x))
        print(x$weights)
        cat(paste(paste(rep(" ", n - 1), collapse = ""), x$nodeID,
            ")* ", sep = "", collapse = ""), "weights =", sum(x$weights),
            "\n")
    }
    

    Now we can observe that weights is NULL in every node:

    R> tr
    1) Petal.Width <= 0.4; criterion = 10.641, statistic = 10.641
     [1] "nodeID"     "weights"    "criterion"  "terminal"   "psplit"    
     [6] "ssplits"    "prediction" "left"       "right"      NA          
    NULL
      2)*  weights = 0 
    1) Petal.Width > 0.4
      3) Petal.Width <= 1.6; criterion = 8.629, statistic = 8.629
     [1] "nodeID"     "weights"    "criterion"  "terminal"   "psplit"    
     [6] "ssplits"    "prediction" "left"       "right"      NA          
    NULL
        4)*  weights = 0 
      3) Petal.Width > 1.6
     [1] "nodeID"     "weights"    "criterion"  "terminal"   "psplit"    
     [6] "ssplits"    "prediction" "left"       "right"      NA          
    NULL
        5)*  weights = 0 
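
    If you would rather not edit the print method interactively, the same check can be done by walking the node list yourself. The sketch below runs on a hypothetical toy tree that only reuses the field names from the names(x) output (nodeID, weights, terminal, left, right); leaf and node_report are made-up helpers, and whether the same walk works unchanged on a real prettytree result is an assumption.

```r
# Toy stand-in for a tree: nested lists with the field names shown in the
# names(x) output. Hypothetical data, not output of party.
leaf <- function(id, w = NULL) list(nodeID = id, weights = w, terminal = TRUE)
toy <- list(nodeID = 1, weights = NULL, terminal = FALSE,
            left = leaf(2),                      # weights not stored
            right = leaf(3, w = c(1, 0, 1)))     # weights stored

# Walk the tree and report, per terminal node, whether weights is NULL
# and what sum(weights) evaluates to.
node_report <- function(x) {
  if (isTRUE(x$terminal)) {
    return(data.frame(nodeID = x$nodeID,
                      weights_null = is.null(x$weights),
                      weight_sum = sum(x$weights)))
  }
  rbind(node_report(x$left), node_report(x$right))
}

report <- node_report(toy)
print(report)
```

    The weights_null column separates nodes with missing weights from nodes whose weights really sum to zero.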
    

    Update: this is a hack to display the sums of the case weights:

    update_tree <- function(x) {
      if(!x$terminal) {
        x$left <- update_tree(x$left)
        x$right <- update_tree(x$right)
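      } else {
        # hack: in terminal nodes, use element 9 (the slot printed as
        # "right" in the names(x) output above) as the weight value
        x$weights <- x[[9]]
        x$weights_ <- x[[9]]
      }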
      x
    }
    tr_weights <- update_tree(tr)
    plot(new("BinaryTree", tree=tr_weights, data=y@data, responses=y@responses))
    

    [Figure: cforest tree with case weights]
