Get the decision tree rule/path for every row of a predicted dataset with rpart/ctree in R

鱼传尺愫 2020-12-03 08:35

I have built a decision tree model in R using rpart and ctree. I have also predicted a new dataset using the built model and got predicted probabilities …

1 Answer
Answered 2020-12-03 09:38

    The partykit package has a function, .list.rules.party(), that is currently unexported but can be used to do what you want. Because it is not exported, it has to be accessed with the triple-colon operator, partykit:::.list.rules.party(). The main reason we haven't exported it yet is that its output format may change in future versions.
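    To see what this helper returns on its own, the sketch below calls it directly on an rpart tree converted with as.party(). It assumes the behaviour of current partykit versions, where the result is a character vector with one rule per terminal node, named by node ID; since the function is unexported, that format is not guaranteed.

```r
library("rpart")
library("partykit")

## fit a classification tree and convert it to a "party" object
rp_party <- as.party(rpart(Species ~ ., data = iris))

## unexported helper: one rule string per terminal node
## (assumed output format; may change in future partykit versions)
rls <- partykit:::.list.rules.party(rp_party)
print(rls)

## the names should match the terminal node IDs, i.e., the values
## returned by predict(..., type = "node")
print(nodeids(rp_party, terminal = TRUE))
```

    Keying the rules by node ID is what makes the per-row lookup a simple character-vector subset.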

    To obtain the predictions you describe above, you can wrap it in a small helper function:

    pathpred <- function(object, ...)
    {
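      ## coerce to "party" object if necessary
      ## (partykit:: prefix so the conversion works even if partykit is not attached)
      if(!inherits(object, "party")) object <- partykit::as.party(object)
    
      ## get standard predictions (response/prob) and collect in data frame;
      ## further arguments in ... (e.g., newdata) are passed on to predict()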
      rval <- data.frame(response = predict(object, type = "response", ...))
      rval$prob <- predict(object, type = "prob", ...)
    
      ## get rules for each node
      rls <- partykit:::.list.rules.party(object)
    
      ## get predicted node and select corresponding rule
      rval$rule <- rls[as.character(predict(object, type = "node", ...))]
    
      return(rval)
    }
    

    Illustration using the iris data and rpart():

    library("rpart")
    library("partykit")
    rp <- rpart(Species ~ ., data = iris)
    rp_pred <- pathpred(rp)
    rp_pred[c(1, 51, 101), ]
    ##       response prob.setosa prob.versicolor prob.virginica
    ## 1       setosa  1.00000000      0.00000000     0.00000000
    ## 51  versicolor  0.00000000      0.90740741     0.09259259
    ## 101  virginica  0.00000000      0.02173913     0.97826087
    ##                                           rule
    ## 1                          Petal.Length < 2.45
    ## 51   Petal.Length >= 2.45 & Petal.Width < 1.75
    ## 101 Petal.Length >= 2.45 & Petal.Width >= 1.75
    

    (For brevity, only the first observation of each species is shown, i.e., rows 1, 51, and 101.)
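    Because each rule is an ordinary R condition, one way to sanity-check the row-to-rule mapping is to parse every rule back into an expression, evaluate it on the data, and compare the matching rows with the rows predicted into that node. This is a minimal sketch for the rpart tree on iris, assuming the rules contain only numeric splits that parse() can read (which holds here).

```r
library("rpart")
library("partykit")

## same tree as above, converted to a "party" object
rp_party <- as.party(rpart(Species ~ ., data = iris))

## terminal node of every row and the rule of every terminal node
node <- predict(rp_party, newdata = iris, type = "node")
rls <- partykit:::.list.rules.party(rp_party)  # unexported; may change

## for each terminal node: evaluate its rule with the iris columns as
## variables and check that exactly the rows routed to that node satisfy it
ok <- vapply(names(rls), function(id) {
  in_rule <- eval(parse(text = rls[[id]]), envir = iris)
  identical(unname(which(in_rule)), unname(which(node == as.integer(id))))
}, logical(1))
print(ok)
```

    Every element of ok should be TRUE; a FALSE would point to a rule string that does not describe its node.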

    And with ctree():

    ct <- ctree(Species ~ ., data = iris)
    ct_pred <- pathpred(ct)
    ct_pred[c(1, 51, 101), ]
    ##       response prob.setosa prob.versicolor prob.virginica
    ## 1       setosa  1.00000000      0.00000000     0.00000000
    ## 51  versicolor  0.00000000      0.97826087     0.02173913
    ## 101  virginica  0.00000000      0.02173913     0.97826087
    ##                                                              rule
    ## 1                                             Petal.Length <= 1.9
    ## 51  Petal.Length > 1.9 & Petal.Width <= 1.7 & Petal.Length <= 4.8
    ## 101                        Petal.Length > 1.9 & Petal.Width > 1.7
    
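    The question is about a new dataset rather than the training data. Since pathpred() forwards its ... arguments to predict(), calling it with newdata = ... should give the same columns for new rows. The standalone sketch below spells out the node-to-rule lookup for ctree() on held-out rows; the 100/50 train/test split of iris is purely illustrative and not part of the original answer.

```r
library("partykit")

## illustrative split: fit on 100 random rows, treat the other 50 as new data
set.seed(1)
train_idx <- sample(nrow(iris), 100)
ct <- ctree(Species ~ ., data = iris[train_idx, ])
new_obs <- iris[-train_idx, ]

## terminal node reached by each new row
node <- predict(ct, newdata = new_obs, type = "node")

## rule for every terminal node (unexported helper; output may change)
rls <- partykit:::.list.rules.party(ct)

## combine predicted class and path rule, one row per new observation
new_pred <- data.frame(
  response = predict(ct, newdata = new_obs, type = "response"),
  rule = unname(rls[as.character(node)])
)
head(new_pred)
```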