Extracting Information from the Decision Rules in the rpart Package

Submitted by 社会主义新天地 on 2021-02-16 20:58:13

Question


I need to extract information from the rules in a decision tree. I am using the rpart package in R, and I will use the package's demo data to explain my requirements:

library(rpart)
data(stagec)
fit <- rpart(formula = pgstat ~ age + eet + g2 + grade + gleason + ploidy,
             data = stagec, method = "class",
             control = rpart.control(cp = 0.05))
fit

Printing fit shows:

n= 146 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 146 54 0 (0.6301370 0.3698630)  
   2) grade< 2.5 61  9 0 (0.8524590 0.1475410) *
   3) grade>=2.5 85 40 1 (0.4705882 0.5294118)  
     6) g2< 13.2 40 17 0 (0.5750000 0.4250000)  
      12) ploidy=diploid,tetraploid 31 11 0 (0.6451613 0.3548387) *
      13) ploidy=aneuploid 9  3 1 (0.3333333 0.6666667) *
     7) g2>=13.2 45 17 1 (0.3777778 0.6222222)  
      14) g2>=17.91 22  8 0 (0.6363636 0.3636364) *
      15) g2< 17.91 23  3 1 (0.1304348 0.8695652) *
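The numbers in this printout are also stored in the fitted object, so they can be read programmatically instead of parsed from text. A minimal sketch, assuming a classification tree fitted with the default priors and loss matrix (so the stored class probabilities equal the class proportions in each node): fit$frame has one row per node, its row names are the node numbers, and for method = "class" its yval2 matrix holds the fitted class index, the per-class counts, the per-class probabilities and the node probability. The name leaf_summary is only illustrative.

```r
library(rpart)

data(stagec)
fit <- rpart(pgstat ~ age + eet + g2 + grade + gleason + ploidy,
             data = stagec, method = "class",
             control = rpart.control(cp = 0.05))

# One row per node; row names are the node numbers from the printout
frame  <- fit$frame
leaves <- frame[frame$var == "<leaf>", ]

# yval2 columns (classification): fitted class index, class counts,
# class probabilities, node probability
ylevels <- attr(fit, "ylevels")
nclass  <- length(ylevels)
probs   <- leaves$yval2[, (nclass + 2):(2 * nclass + 1), drop = FALSE]

# Illustrative summary: "loss" is the dev column, "confidence" is the
# probability of the fitted class in that leaf
leaf_summary <- data.frame(
  node       = as.integer(rownames(leaves)),
  n          = leaves$n,
  loss       = leaves$dev,
  class      = ylevels[leaves$yval],
  confidence = probs[cbind(seq_len(nrow(leaves)), leaves$yval)]
)
print(leaf_summary)
```

The confidence column is taken from the probability matrix rather than computed as 1 - loss/n, so it still matches the printed yprob values if non-default priors are used.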

For example, for node 12 I would like to get something like this:

If grade>=2.5 and g2< 13.2 and ploidy in (diploid,tetraploid) then class 0 is predicted with 65% confidence.

Any pointers would be very helpful.

Thanks


Answer 1:


Version 3.0 of the rpart.plot package (July 2018) has a function, rpart.rules, that generates a set of rules for a tree. For example,

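library(rpart.plot)  # rpart.plot depends on rpart, so this also loads rpart
data(stagec)
# Same predictors as in the question; pgstat ~ . would also pull in pgtime
fit <- rpart(formula = pgstat ~ age + eet + g2 + grade + gleason + ploidy,
             data = stagec, method = "class",
             control = rpart.control(cp = 0.05))
rpart.rules(fit)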

gives

pgstat                                                                   
  0.15 when grade <  3                                                   
  0.35 when grade >= 3 & g2 <  13       & ploidy is diploid or tetraploid
  0.36 when grade >= 3 & g2 >=       18                                  
  0.67 when grade >= 3 & g2 <  13       & ploidy is             aneuploid
  0.87 when grade >= 3 & g2 is 13 to 18 

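Each number is the predicted probability of the second class (pgstat = 1) in that leaf; for example, the 0.35 rule corresponds to node 12, where class 0 has probability 0.65. Setting roundint=FALSE keeps the original split points for integer-valued predictors such as grade (2.5 rather than 3), and clip.facs=TRUE drops the variable name from factor conditions: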

rpart.rules(fit, roundint=FALSE, clip.facs=TRUE)

gives

pgstat                                                           
  0.15 when grade <  2.5                                         
  0.35 when grade >= 2.5 & g2 <  13       & diploid or tetraploid
  0.36 when grade >= 2.5 & g2 >=       18                        
  0.67 when grade >= 2.5 & g2 <  13       & aneuploid
  0.87 when grade >= 2.5 & g2 is 13 to 18                        

For more examples, see Chapter 4 of the rpart.plot vignette.




Answer 2:


You can use the .list.rules.party() function from the partykit package plus a little string formatting. Here is an example using your code:

library(rpart)
library(partykit)

data(stagec)
fit <- rpart(
  formula = pgstat ~ age + eet + g2 + grade + gleason + ploidy,
  data = stagec,
  method = "class",
  control = rpart.control(cp = 0.05)
)

# Convert the rpart tree into a partykit "party" object
party_obj <- as.party(fit, data = TRUE)
# .list.rules.party() is not exported, hence the ::: accessor
decisions <- partykit:::.list.rules.party(party_obj)
cat(paste(decisions, collapse = "\n"))

You build the tree model the same way, convert it into a party object, and then call .list.rules.party() to extract the decision strings, one per terminal node. Because the function is internal (hence the :::), it is not part of partykit's public API and may change between versions. Joining the strings with newlines gives

grade < 2.5
grade >= 2.5 & g2 < 13.2 & ploidy %in% c("diploid", "tetraploid")
grade >= 2.5 & g2 < 13.2 & ploidy %in% c("aneuploid")
grade >= 2.5 & g2 >= 13.2 & g2 >= 17.91
grade >= 2.5 & g2 >= 13.2 & g2 < 17.91

as the result.
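If you would rather not rely on an unexported partykit function, a sentence in exactly the format the question asks for can be sketched with rpart alone. This is a minimal sketch, assuming a classification tree with default priors and loss; describe_leaf is a hypothetical helper, while path.rpart() (which returns the split labels from the root to each requested node) and fit$frame are part of rpart.

```r
library(rpart)

data(stagec)
fit <- rpart(pgstat ~ age + eet + g2 + grade + gleason + ploidy,
             data = stagec, method = "class",
             control = rpart.control(cp = 0.05))

# Hypothetical helper: turn one terminal node into an English rule
describe_leaf <- function(fit, node) {
  # path.rpart() returns a list of split labels per node; element 1 is "root"
  conditions <- path.rpart(fit, nodes = node, print.it = FALSE)[[1]][-1]
  node_row <- fit$frame[as.character(node), ]
  ylevels  <- attr(fit, "ylevels")
  nclass   <- length(ylevels)
  # Class probabilities sit after the fitted class and the class counts
  probs <- node_row$yval2[1, (nclass + 2):(2 * nclass + 1)]
  sprintf("If %s then class %s is predicted with %.0f%% confidence.",
          paste(conditions, collapse = " and "),
          ylevels[node_row$yval],
          100 * probs[node_row$yval])
}

# Build one rule per terminal node
leaf_nodes <- as.integer(rownames(fit$frame)[fit$frame$var == "<leaf>"])
rules <- vapply(leaf_nodes, function(nd) describe_leaf(fit, nd), character(1))
writeLines(rules)
```

The conditions keep rpart's own label syntax (for example ploidy=diploid,tetraploid), so any further rewording, such as turning "=" into "in (...)", is plain string manipulation on the conditions vector.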



Source: https://stackoverflow.com/questions/36401411/extracting-information-from-the-decision-rules-in-rpart-package
