How can I ensure that a partition has representative observations from each level of a factor?

Asked by 上瘾入骨i, 2020-12-09 23:01

I wrote a small function to partition my dataset into training and testing sets. However, I am running into trouble when dealing with factor variables. In the model valida…
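The typical symptom is that a rare level ends up only in the test set. A model fitted on the training set has then never seen that level, and predict() stops with an error about new factor levels. A minimal way to detect this situation is sketched below. The helper unseen_levels() and the toy data are hypothetical, not part of the original question.

```r
# Hypothetical helper: which values of `col` occur in `test` but never in `train`?
unseen_levels <- function(train, test, col) {
  # Compare observed values, not levels(): subsetting a factor keeps unused levels
  setdiff(unique(as.character(test[[col]])), unique(as.character(train[[col]])))
}

# Toy data (hypothetical): level "c" has a single observation
df <- data.frame(x = 1:6, g = factor(c("a", "a", "b", "b", "b", "c")))
train <- df[1:5, ]
test  <- df[6, , drop = FALSE]

unseen_levels(train, test, "g")  # "c" appears only in the test set
```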

1 Answer
  •  醉酒成梦
    2020-12-09 23:35

    Try the caret package, in particular its createDataPartition() function, which should do exactly what you need. caret is available on CRAN, and its homepage is here:

    caret - data splitting
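When the outcome passed to createDataPartition() is a factor, caret samples within each class, so every level keeps roughly its share of rows; with caret the call would be along the lines of createDataPartition(df$y, p = 0.8, list = FALSE), which returns the training-set row indices. If caret is not available, the same idea can be sketched in base R. The helper stratified_split() below is hypothetical and is not caret's implementation: it samples a fraction p of each level for training and, for levels with at least two rows, always leaves one row for the test set.

```r
# Hypothetical base-R helper sketching a stratified train/test split:
# rows are sampled separately within each level of the factor `group`.
stratified_split <- function(df, group, p = 0.8) {
  rows_by_level <- split(seq_len(nrow(df)), df[[group]], drop = TRUE)
  train_idx <- unlist(lapply(rows_by_level, function(rows) {
    n <- length(rows)
    # Aim for a fraction p of each level, take at least one row for training,
    # and keep at least one row of every level with 2+ rows for testing
    n_train <- if (n < 2) n else min(max(round(n * p), 1), n - 1)
    rows[sample.int(n, n_train)]
  }), use.names = FALSE)
  test_idx <- setdiff(seq_len(nrow(df)), train_idx)
  list(train = df[sort(train_idx), , drop = FALSE],
       test  = df[test_idx, , drop = FALSE])
}

set.seed(42)
parts <- stratified_split(iris, "Species", p = 0.8)
table(parts$train$Species)  # 40 rows of each species
table(parts$test$Species)   # 10 rows of each species
```

Using sample.int() on positions, rather than sample() on the row numbers themselves, avoids the surprise where sample(x) with a single number x draws from 1:x.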

    Alternatively, the function below is partly based on some code I found online a while back; I modified it slightly to better handle edge cases (for example, when you ask for a sample size larger than the data set, or larger than one of its groups).

    stratified <- function(df, group, size) {
      # USE: * Specify your data frame and grouping variable (as a column
      #        number or name) as the first two arguments.
      #      * Decide on your sample size. For a sample proportional to the
      #        population, enter "size" as a decimal. For an equal number
      #        of samples from each group, enter "size" as a whole number.
      #
      # Example 1: Sample 10% of each group from a data frame named "z",
      # where the grouping variable is the fourth variable, use:
      #
      # > stratified(z, 4, .1)
      #
      # Example 2: Sample 5 observations from each group from a data frame
      # named "z"; grouping variable is the third variable:
      #
      # > stratified(z, 3, 5)
      #
      if (!requireNamespace("sampling", quietly = TRUE)) {
        stop("Package 'sampling' is required: install.packages('sampling')")
      }
      # sampling::strata() expects the data sorted by the stratum variable
      temp <- df[order(df[[group]]), ]
    
      # Don't attempt to sample more rows than the smallest group contains
      dfCounts <- table(temp[[group]])
      if (size > min(dfCounts)) {
        size <- min(dfCounts)
      }
    
      if (size < 1) {
        # Proportional allocation, rounded up so each group gets at least one row
        size <- ceiling(dfCounts * size)
      } else {
        # Equal allocation: the same number of rows from every group
        size <- rep(size, times = length(dfCounts))
      }
      strat <- sampling::strata(temp, stratanames = names(temp[group]),
                                size = size, method = "srswor")
      dsample <- sampling::getdata(temp, strat)
    
      # Keep only the original columns (getdata() also adds ID_unit, Prob, Stratum)
      dsample <- dsample[, names(df), drop = FALSE]
      dsample <- dsample[order(dsample[[1]]), , drop = FALSE]
      rownames(dsample) <- NULL
      dsample
    }
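If installing sampling is not an option, the same contract (size below 1 is a per-group proportion, size of 1 or more is a per-group count, and size is capped at the smallest group) can be sketched in base R with split() and sample.int(). The name stratified_base() is hypothetical and this is not the original code. Unlike stratified(), it returns rows in their original order rather than sorted by the first column, and it ignores factor levels that have no observations.

```r
# Hypothetical base-R variant of stratified(); a sketch, not the original code.
stratified_base <- function(df, group, size) {
  # Row positions for each observed group (empty factor levels are dropped)
  rows_by_group <- split(seq_len(nrow(df)), df[[group]], drop = TRUE)
  # Same cap as stratified(): never ask for more rows than the smallest group has
  size <- min(size, min(lengths(rows_by_group)))
  picked <- lapply(rows_by_group, function(rows) {
    # Proportion (rounded up) or fixed count, sampled without replacement
    n <- if (size < 1) ceiling(length(rows) * size) else size
    rows[sample.int(length(rows), n)]
  })
  out <- df[sort(unlist(picked, use.names = FALSE)), , drop = FALSE]
  rownames(out) <- NULL
  out
}

set.seed(1)
stratified_base(iris, 5, 0.1)                 # 5 rows per species (10% of 50)
nrow(stratified_base(iris, "Species", 100))   # capped at 50 per species: 150
```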
    
