Get decision tree rule/path pattern for every row of predicted dataset for rpart/ctree package in R

Submitted by 故事扮演 on 2019-12-28 02:57:27

Question


I have built decision tree models in R using rpart and ctree. I have also used the fitted models to predict on a new dataset and obtained the predicted probabilities and classes.

However, for every observation in the predicted dataset, I would like to extract the rule/path it followed as a single string. If I store this in tabular format, I can explain each prediction, with its reason, in an automated manner without opening R.

That is, I want to get output like the following:

ObsID   Probability   PredictedClass   PathFollowed 
    1          0.68             Safe   CarAge < 10 & Country = Germany & Type = Compact & Price < 12822.5
    2          0.76             Safe   CarAge < 10 & Country = Korea & Type = Compact & Price > 12822.5
    3          0.88           Unsafe   CarAge > 10 & Type = Van & Country = USA & Price > 15988

The kind of code I'm looking for is:

library(rpart)
fit <- rpart(Reliability~.,data=car.test.frame)

This is the part that needs to be expanded, possibly into multiple lines:

## pseudocode: predict.rpart() has no "GETPATTERNS" type
predResults <- predict(fit, newdata = newcar, type = "GETPATTERNS")

Answer 1:


The partykit package has a function .list.rules.party(), which is currently unexported but can be leveraged to do what you want. The main reason we haven't exported it yet is that its output format may change in future versions.
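Because the function is internal, code that calls it through `:::` can break silently after a partykit update. A minimal defensive sketch, assuming you prefer an explicit error over a cryptic one, is to look the function up in the partykit namespace first. The helper name `get_rule_lister()` is hypothetical and not part of partykit.

```r
library("partykit")

## Hypothetical helper (not part of partykit): fetch the unexported
## rule lister, failing with an informative message if a future
## partykit release renames or removes it.
get_rule_lister <- function() {
  ns <- asNamespace("partykit")
  if (!exists(".list.rules.party", envir = ns, inherits = FALSE)) {
    stop("partykit ", as.character(packageVersion("partykit")),
         " no longer provides .list.rules.party()")
  }
  get(".list.rules.party", envir = ns, inherits = FALSE)
}

list_rules <- get_rule_lister()

## Usage: one rule string per terminal node of a converted rpart tree
rls <- list_rules(as.party(rpart::rpart(Species ~ ., data = iris)))
print(rls)
```

The check costs nothing at run time and turns a version mismatch into a message that names the installed partykit version.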

To obtain the predictions you describe above you can do:

pathpred <- function(object, ...)
{
  ## '...' is passed on to predict(), e.g. newdata = newcar
  ## coerce to "party" object if necessary (as.party() is from partykit)
  if(!inherits(object, "party")) object <- partykit::as.party(object)

  ## get standard predictions (response/prob) and collect in data frame
  rval <- data.frame(response = predict(object, type = "response", ...))
  rval$prob <- predict(object, type = "prob", ...)

  ## get rules for each node
  rls <- partykit:::.list.rules.party(object)

  ## get predicted node and select corresponding rule
  rval$rule <- rls[as.character(predict(object, type = "node", ...))]

  return(rval)
}
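The last step of pathpred() relies on the rules vector being named by terminal node ID. A minimal sketch that inspects the two ingredients separately, using the iris data as in the illustration, shows why the node IDs are converted with as.character() before indexing: terminal node IDs are not consecutive (here they are 2, 4 and 5), so integer indexing would select by position and return wrong rules or NA.

```r
library("rpart")
library("partykit")

## Convert the rpart fit to a "party" object, as pathpred() does
tree <- as.party(rpart(Species ~ ., data = iris))

## One rule string per terminal node; the names are the node IDs
rls <- partykit:::.list.rules.party(tree)
print(rls)

## Terminal node ID reached by each observation (integers)
node_ids <- predict(tree, type = "node")

## Index by name: as.character() turns the IDs into names.
## rls[node_ids] would index by position instead.
rule_per_row <- rls[as.character(node_ids)]
head(unname(rule_per_row))
```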

Illustration using the iris data and rpart():

library("rpart")
library("partykit")
rp <- rpart(Species ~ ., data = iris)
rp_pred <- pathpred(rp)
rp_pred[c(1, 51, 101), ]
##       response prob.setosa prob.versicolor prob.virginica
## 1       setosa  1.00000000      0.00000000     0.00000000
## 51  versicolor  0.00000000      0.90740741     0.09259259
## 101  virginica  0.00000000      0.02173913     0.97826087
##                                           rule
## 1                          Petal.Length < 2.45
## 51   Petal.Length >= 2.45 & Petal.Width < 1.75
## 101 Petal.Length >= 2.45 & Petal.Width >= 1.75

(For brevity, only the first observation of each species is shown, i.e., rows 1, 51, and 101.)
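The question asks for paths on a new dataset, stored as a table. A minimal sketch of that workflow, assuming an illustrative split of iris (every 5th row held out as "new" data) and taking Probability to mean the probability of the predicted class, builds the ObsID / Probability / PredictedClass / PathFollowed layout from the question and writes it to a CSV file. The split, the column names and the temporary file path are assumptions for illustration only.

```r
library("rpart")
library("partykit")

## Illustrative split (an assumption): hold out every 5th row of iris
## as the "new" dataset and fit on the rest.
test_idx <- seq(5, nrow(iris), by = 5)
train <- iris[-test_idx, ]
newdat <- iris[test_idx, ]

tree <- as.party(rpart(Species ~ ., data = train))
rls <- partykit:::.list.rules.party(tree)

## All predictions are computed on the new data
cls <- predict(tree, newdata = newdat, type = "response")
prob <- predict(tree, newdata = newdat, type = "prob")
node <- predict(tree, newdata = newdat, type = "node")

## Probability of the predicted class: for each row, pick the column
## of 'prob' whose name matches the predicted class label
prob_cls <- prob[cbind(seq_len(nrow(prob)),
                       match(as.character(cls), colnames(prob)))]

out <- data.frame(
  ObsID = seq_len(nrow(newdat)),
  Probability = round(prob_cls, 2),
  PredictedClass = cls,
  PathFollowed = unname(rls[as.character(node)]),
  row.names = NULL
)
head(out)

## Store the table so the explanations can be read without R
csv_file <- tempfile(fileext = ".csv")
write.csv(out, csv_file, row.names = FALSE)
```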

And with ctree():

ct <- ctree(Species ~ ., data = iris)
ct_pred <- pathpred(ct)
ct_pred[c(1, 51, 101), ]
##       response prob.setosa prob.versicolor prob.virginica
## 1       setosa  1.00000000      0.00000000     0.00000000
## 51  versicolor  0.00000000      0.97826087     0.02173913
## 101  virginica  0.00000000      0.02173913     0.97826087
##                                                              rule
## 1                                             Petal.Length <= 1.9
## 51  Petal.Length > 1.9 & Petal.Width <= 1.7 & Petal.Length <= 4.8
## 101                        Petal.Length > 1.9 & Petal.Width > 1.7
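Each rule string is itself a valid R expression, which gives a cheap sanity check. A minimal sketch, assuming data without missing values (where surrogate splits could make the two disagree) and R >= 3.6.0 for str2lang(): evaluate every terminal node's rule against the data and confirm that the rows satisfying it are exactly the rows predict() sends to that node.

```r
library("partykit")

ct <- ctree(Species ~ ., data = iris)
rls <- partykit:::.list.rules.party(ct)
node_ids <- predict(ct, type = "node")

## Compare "rows matching the rule" with "rows assigned to the node"
for (id in names(rls)) {
  in_rule <- eval(str2lang(rls[[id]]), envir = iris)
  in_node <- unname(node_ids == as.integer(id))
  cat(sprintf("node %s: %d rows, rule matches: %s\n",
              id, sum(in_node), identical(in_rule, in_node)))
}
```

The same check can be run on a new dataset by passing it as `envir` and computing the node IDs with `predict(ct, newdata = ..., type = "node")`.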


Source: https://stackoverflow.com/questions/29618490/get-decision-tree-rule-path-pattern-for-every-row-of-predicted-dataset-for-rpart
