Find the data elements in a data frame that pass the rule for a node in a tree model?

南旧 2021-01-03 06:04

So I have used the rpart package to create a tree model and I found an interesting rule and wondered if there was an easy way to see which observations in that data frame pa

2 answers
  • 2021-01-03 06:55
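    #' subset of rpart node: return logical index
    #'
    #' TRUE for every training row whose leaf lies in the subtree rooted at `node`.
    #' @param tree rpart model
    #' @param node which node/leaf?
    #' @export
    subset_rpart <- function (tree, node) {
      # node numbers in frame order; the children of node k are 2k and 2k + 1
      nodes = as.numeric(rownames(tree$frame))
      # descendants of `node` at relative depth d lie in [node * 2^d, (node + 1) * 2^d),
      # so the fractional part of their log2 falls in [lower_, upper_)
      nodes = log(nodes, 2)
      lower = log(node, 2)
      upper = log(node + 1, 2)
      a = floor(lower)
      lower_ = lower - a
      upper_  = upper - a
      nodes_ = nodes %% 1
      # nodes >= lower drops ancestors and other shallower nodes
      w = which(((nodes_ >= lower_ & nodes_ < upper_) | (nodes_ + 1 < upper_)) & nodes >= lower)
      # tree$where holds, per observation, the frame row of its leaf
      tree$where %in% w
    }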
    
    
    
    #' subset df by subset_rpart
    #' @param tree rpart model
    #' @param node node number
    #' @param df df
    #' @export
    subset.rpart = function(tree, node, df){
      df[subset_rpart(tree, node), ]
    }
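
    The function above compares fractional parts of floating-point logarithms, which may be sensitive to rounding for deep trees. A minimal sketch of the same idea in integer arithmetic follows; `rows_in_node` is a hypothetical helper name, not part of rpart. It relies only on the rpart numbering convention (children of node k are 2k and 2k + 1) and on `tree$where`, which stores the frame row of each observation's leaf.

    ```r
    library(rpart)

    # Hypothetical helper (not part of rpart): TRUE for each training row that
    # falls in `node` or in any node below it.
    rows_in_node <- function(tree, node) {
      stopifnot(inherits(tree, "rpart"), length(node) == 1L)
      ids <- as.numeric(rownames(tree$frame))  # node numbers, in frame order
      # Integer-halving a node number gives its parent, so halve every number
      # that is still larger than `node` until none is.
      anc <- ids
      while (any(anc > node)) {
        above <- anc > node
        anc[above] <- anc[above] %/% 2
      }
      in_subtree <- anc == node  # the walk reached `node` exactly
      # tree$where holds, per observation, the frame row of its leaf
      tree$where %in% which(in_subtree)
    }

    # Usage on the kyphosis data shipped with rpart
    fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
    idx <- rows_in_node(fit, 2)
    # the number of selected rows should match the count recorded for node 2
    stopifnot(sum(idx) == fit$frame["2", "n"])
    head(kyphosis[idx, ])
    ```

    Because the result is a logical index over the training rows, it can be used on the modeling data frame in the same way as `subset_rpart`.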
    
  • 2021-01-03 06:58

    I modified the code in path.rpart to return the subset of the data that falls within a particular node, rather than information about that node. It works either by clicking on the plot or by passing node numbers, just as path.rpart does. Here is the code:

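    subset.rpart <- function (tree, df, nodes) {
        if (!inherits(tree, "rpart")) 
            stop("Not a legitimate \"rpart\" object")
        # df must be the data the tree was fitted on, in the same row order
        stopifnot(nrow(df)==length(tree$where))
        frame <- tree$frame
        n <- row.names(frame)
        node <- as.numeric(n)
    
        if (missing(nodes)) {
            # interactive mode: pick a node by clicking on the plotted tree
            xy <- rpart:::rpartco(tree)
            i <- identify(xy, n = 1L, plot = FALSE)
            # identify() returns integer(0) when no point is selected
            if(length(i) > 0L) {
                 # tree$where holds the frame row of each observation's leaf
                 return( df[tree$where==i, ] )
            } else {
                return(df[0,])
            }
        }
        else {
            # node.match() maps node numbers to row indices of tree$frame
            if (length(nodes <- rpart:::node.match(nodes, node)) == 0L) 
                return(df[0,])
            return ( df[tree$where %in% as.numeric(nodes), ] )
        }
    }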
    

    I will use it on some sample data from the package:

    library(rpart)
    fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
    plot(fit)
    text(fit)
    

    [Figure: rpart tree plot]

    Then, to find the observations at a particular node, run:

    subset.rpart(fit, kyphosis)
    

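    and click on a node on the plot. After you do, all the observations at that node will be returned. Note that tree$where records only the terminal node of each observation, so rows come back for leaf nodes; an internal node yields an empty data frame. You must use the same data.frame that was used for modeling for this to work properly. Rather than clicking on a point, you can also pass in a node number that you discover with path.rpart: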

    # path.rpart(fit)  
    #  node number: 10  ---> looks interesting
    #    root
    #    Start>=8.5
    #    Start< 14.5
    #    Age< 55
    
    subset.rpart(fit, kyphosis, 10)
    #    Kyphosis Age Number Start
    # 14   absent   1      4    12
    # 20   absent  27      4     9
    # 26   absent   9      5    13
    # 37   absent   1      3     9
    # 39   absent  20      6     9
    # 42   absent  35      3    13
    # 57   absent   2      3    13
    # 59   absent  51      7     9
    # 66   absent  17      4    10
    # 69   absent  18      4    11
    # 78   absent  26      7    13
    # 81   absent  36      4    13
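
    As a cross-check, the rule that path.rpart prints for node 10 can also be applied to the data frame directly. This is a minimal sketch, assuming the thresholds shown above are exact and the predictors contain no missing values (so no surrogate splits are involved); `by_rule`, `by_tree` and `leaf_row` are names made up for this illustration.

    ```r
    library(rpart)

    fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

    # rows that satisfy the printed rule for node 10
    by_rule <- subset(kyphosis, Start >= 8.5 & Start < 14.5 & Age < 55)

    # rows the fitted tree assigned to node 10; fit$where stores the row of
    # fit$frame for each observation's leaf, so look up that row first
    leaf_row <- which(rownames(fit$frame) == "10")
    by_tree  <- kyphosis[fit$where == leaf_row, ]

    # compare the two selections by row name
    identical(rownames(by_rule), rownames(by_tree))
    ```

    Writing the rule by hand is fine for a single node, but it has to be redone whenever the tree is refitted, which is why indexing through the fitted object is usually more convenient.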
    