Add Column of Predicted Values to Data Frame with dplyr

太阳男子 2020-12-14 23:47

I have a data frame with a column of models, and I am trying to add a column of predicted values to it. A minimal example is:

exampleTable <- data.frame(x          


        
4 Answers
  • 2020-12-15 00:19

    With modelr, there is an elegant tidyverse solution.

    The inputs

    library(dplyr)
    library(purrr)
    library(tidyr)
    
    # generate the inputs like in the question
    example_table <- data.frame(x = c(1:5, 1:5),
                                y = c((1:5) + rnorm(5), 2*(5:1)),
                                groups = rep(LETTERS[1:2], each = 5))
    
    models <- example_table %>% 
      group_by(groups) %>% 
      do(model = lm(y ~ x, data = .)) %>%
      ungroup()
    example_table <- left_join(as_tibble(example_table), models, by = "groups")  # tbl_df() is deprecated
    

    The solution

    # generate the extra column
    example_table %>%
      group_by(groups) %>%
      do(modelr::add_predictions(., first(.$model)))
    

    The explanation

    add_predictions adds a new column (named pred by default) to a data frame using a given model. Unfortunately, it takes only one model as an argument. Meet do: with do, we can run add_predictions separately for each group.

    Inside do, . represents the rows of the current group, .$model is that group's model column, and first() takes its first entry (every row in a group carries the same model).
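do() still works but is marked as superseded in current dplyr. A minimal sketch of the same idea with group_modify(), assuming dplyr >= 1.1.0 (for pick()) and modelr are installed; the fixed seed is an addition so the example is reproducible:

```r
library(dplyr)
library(modelr)

set.seed(1)  # assumption: fixed seed so the example is reproducible
example_table <- data.frame(x = c(1:5, 1:5),
                            y = c((1:5) + rnorm(5), 2 * (5:1)),
                            groups = rep(LETTERS[1:2], each = 5))

# one lm per group in a list column, joined back onto every row of that group
models <- example_table %>%
  group_by(groups) %>%
  summarise(model = list(lm(y ~ x, data = pick(x, y))), .groups = "drop")
example_table <- left_join(as_tibble(example_table), models, by = "groups")

# group_modify() passes each group's rows as .x (without the grouping column);
# all rows of a group share the same model, so the first one is enough
with_preds <- example_table %>%
  group_by(groups) %>%
  group_modify(~ add_predictions(.x, .x$model[[1]])) %>%
  ungroup()
```

The result has the original columns plus pred, one row per input row.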

    Simplified

    With only one model, add_predictions works very well.

    # take one of the models (row 6 holds group B's model)
    model <- example_table$model[[6]]
    
    # generate the extra column
    example_table %>%
      modelr::add_predictions(model)
    

    Recipes

    Nowadays, the tidyverse is shifting from the modelr package to recipes, so that might be the way to go once that package matures.

  • 2020-12-15 00:19

    Using the tidyverse:

    library(dplyr)
    library(purrr)
    library(tidyr)
    library(broom)
    
    exampleTable <- data.frame(
      x = c(1:5, 1:5),
      y = c((1:5) + rnorm(5), 2*(5:1)),
      groups = rep(LETTERS[1:2], each = 5)
    )
    
    exampleTable %>% 
      group_by(groups) %>%
      nest() %>% 
      mutate(model = data %>% map(~lm(y ~ x, data = .))) %>% 
      mutate(Pred = map2(model, data, predict)) %>% 
      select(groups, Pred, data) %>%   # drop the model list column
      unnest(cols = c(Pred, data))     # tidyr >= 1.0 syntax
    
    # A tibble: 10 × 4
       groups      Pred     x          y
       <fctr>     <dbl> <int>      <dbl>
    1       A  1.284185     1  0.9305908
    2       A  1.909262     2  1.9598293
    3       A  2.534339     3  3.2812002
    4       A  3.159415     4  2.9283637
    5       A  3.784492     5  3.5717085
    6       B 10.000000     1 10.0000000
    7       B  8.000000     2  8.0000000
    8       B  6.000000     3  6.0000000
    9       B  4.000000     4  4.0000000
    10      B  2.000000     5  2.0000000
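broom is loaded above but never used. A sketch of how it could fit in, assuming tidyr >= 1.0 and a recent broom: broom::augment() returns each model's data together with a .fitted column (the same values as Pred) plus residual diagnostics. The seed is an addition for reproducibility.

```r
library(dplyr)
library(purrr)
library(tidyr)
library(broom)

set.seed(1)  # assumption: fixed seed so the example is reproducible
exampleTable <- data.frame(
  x = c(1:5, 1:5),
  y = c((1:5) + rnorm(5), 2 * (5:1)),
  groups = rep(LETTERS[1:2], each = 5)
)

augmented <- exampleTable %>%
  nest(data = c(x, y)) %>%                       # one row per group
  mutate(model = map(data, ~ lm(y ~ x, data = .x)),
         aug   = map(model, augment)) %>%         # model data + .fitted, .resid, ...
  select(groups, aug) %>%
  unnest(cols = aug)
```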
    
  • 2020-12-15 00:29

    Eh, this is only slightly better:

    answer = 
      exampleTable %>%
      group_by(groups) %>%
      do(lm(y ~ x, data = .) %>% 
           predict() %>% 
           tibble(prediction = .)) %>%   # data_frame() is deprecated
      ungroup() %>%
      select(prediction) %>%             # avoid a duplicated groups column
      bind_cols(exampleTable)
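bind_cols() pastes columns together by position, so the result only lines up because exampleTable is already sorted by groups. A base R sketch (a swapped-in technique, not part of this answer) that keeps the original row order with split()/unsplit(), tried on a hypothetical shuffled copy of the data:

```r
set.seed(1)  # assumption: fixed seed so the example is reproducible
exampleTable <- data.frame(
  x = c(1:5, 1:5),
  y = c((1:5) + rnorm(5), 2 * (5:1)),
  groups = rep(LETTERS[1:2], each = 5)
)

# hypothetical shuffled copy: the groups are no longer in contiguous blocks
shuffled <- exampleTable[sample(nrow(exampleTable)), ]

# fit one lm per group, then let unsplit() put each prediction back at the
# row position it came from, whatever the row order is
shuffled$prediction <- unsplit(
  lapply(split(shuffled, shuffled$groups),
         function(d) predict(lm(y ~ x, data = d))),
  shuffled$groups
)
```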
    

    I was hoping this would work, but it didn't.

    answer = 
      exampleTable %>%
      group_by(groups) %>%
      mutate(prediction = 
               lm( y ~ x , data = .) %>% 
               predict)
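The mutate() version fails because the . inside it is magrittr's placeholder for the whole grouped data frame coming down the pipe, not the current group, so lm() is fit to all ten rows and predict() returns ten values where mutate() expects five per group. A sketch of a version that does work, assuming dplyr >= 1.1.0 for pick():

```r
library(dplyr)

set.seed(1)  # assumption: fixed seed so the example is reproducible
exampleTable <- data.frame(
  x = c(1:5, 1:5),
  y = c((1:5) + rnorm(5), 2 * (5:1)),
  groups = rep(LETTERS[1:2], each = 5)
)

answer <- exampleTable %>%
  group_by(groups) %>%
  # pick() returns only the current group's rows, so lm() is fit once per
  # group and predict() returns one fitted value per row of that group
  mutate(prediction = predict(lm(y ~ x, data = pick(x, y)))) %>%
  ungroup()
```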
    
  • 2020-12-15 00:31

    For the record, this is painless in data.table:

    library(data.table)
    setDT(exampleTable)
    
    # actually, the more typical usage is to set the newdata
    #   argument here to .SD (especially for multivariate regressions; see:
    #   https://stackoverflow.com/a/32277135/3576984)
    exampleTable[ , estimates := predict(lm(y ~ x), data.frame(x)), by = groups]
    
    exampleTable
    #     x          y groups  estimates
    #  1: 1  0.3123549      A  0.6826629
    #  2: 2  2.7636593      A  1.8297796
    #  3: 3  1.7771181      A  2.9768963
    #  4: 4  5.2031623      A  4.1240130
    #  5: 5  4.8281869      A  5.2711297
    #  6: 1 10.0000000      B 10.0000000
    #  7: 2  8.0000000      B  8.0000000
    #  8: 3  6.0000000      B  6.0000000
    #  9: 4  4.0000000      B  4.0000000
    # 10: 5  2.0000000      B  2.0000000
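Following the comment in the code, a sketch that passes .SD both as the fitting data and as newdata; within each by group, .SD holds that group's non-grouping columns (the seed is an addition for reproducibility):

```r
library(data.table)

set.seed(1)  # assumption: fixed seed so the example is reproducible
exampleTable <- data.table(
  x = c(1:5, 1:5),
  y = c((1:5) + rnorm(5), 2 * (5:1)),
  groups = rep(LETTERS[1:2], each = 5)
)

# .SD is the current group's subset of columns (x and y here), used both to
# fit the model and as newdata for predict()
exampleTable[, estimates := predict(lm(y ~ x, data = .SD), newdata = .SD), by = groups]
```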
    

    If you're sold on data.table's clarity as I was, check out the intro vignettes!

    Also, you don't really need to group by groups. Just include it as a dummy interaction. If I recall correctly, that's the proper approach for getting correct standard errors anyway:

    exampleTable[ , estimates2 := predict(lm(y ~ x * factor(groups)), .SD)]
    exampleTable[ , all.equal(estimates, estimates2)]
    # [1] TRUE
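Why the interaction reproduces the grouped fits, sketched for the two-level groups with A as the baseline (D is a hypothetical dummy equal to 1 for group B; y ~ x * factor(groups) expands to exactly these four terms):

```latex
y = \beta_0 + \beta_1 x + \beta_2 D + \beta_3 x D + \varepsilon

\text{group A } (D = 0):\quad \hat y = \beta_0 + \beta_1 x
\text{group B } (D = 1):\quad \hat y = (\beta_0 + \beta_2) + (\beta_1 + \beta_3)\, x

\mathrm{RSS} = \sum_{i \in A} (y_i - \beta_0 - \beta_1 x_i)^2
             + \sum_{i \in B} (y_i - \gamma_0 - \gamma_1 x_i)^2,
\qquad \gamma_0 = \beta_0 + \beta_2,\ \gamma_1 = \beta_1 + \beta_3
```

Each sum has its own free intercept and slope, so minimizing the total minimizes each group's sum separately, which is what two separate lm() fits do; the point predictions therefore coincide, consistent with the all.equal() check.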
    