Advice on usage of dplyr::do vs. purrr::map with tidyr::nest for predictions

一世执手 提交于 2019-12-03 21:29:55

I'm really interested in finding out the differences between the do approach and the nest + map approach. Maybe people who have tried both can comment on which is faster when dealing with much bigger datasets or many more models.

So far I've been using the do approach as follows:

library(tidyverse)

# Fix the RNG seed so the shuffle below gives the same order on every run.
set.seed(47)

# Shuffle the rows by sampling 100% of them.
mtcars2 <- sample_frac(mtcars, size = 1)

# Rows 1-20 of the shuffled data are the training set, rows 21-32 the test set.
mtcars_train <- mtcars2[seq_len(20), ]
mtcars_test <- mtcars2[21:32, ]

# Fit the two candidate models separately within every cyl group. With named
# arguments, do() stores each fitted model in a column of that name.
dt_models <- mtcars_train %>%
  group_by(cyl) %>%
  do(
    model1 = lm(disp ~ hp, data = .),
    model2 = lm(disp ~ mpg, data = .)
  ) %>%
  ungroup() %>%
  print()

# Reshape to long format (one row per cyl / model name) so that a single model
# can later be picked by filtering on `cyl` and `name`.
dt_models <- dt_models %>%
  gather(key = "name", value = "model", -cyl) %>%
  print()

# Pick the fitted model stored for one (cyl, model name) pair and use it to
# predict new data.
#
# Arguments:
#   input_cyl   cyl value identifying the group whose model should be used
#   model_name  value of the `name` column to select, e.g. "model1"
#   dd          new data handed to predict.lm() as `newdata`
#   models      long-format table of fitted models with columns `cyl`, `name`
#               and a list-column `model`. Defaults to the `dt_models` object
#               found from the function's defining environment, so existing
#               three-argument calls behave as before.
#
# Returns whatever predict.lm() returns for `dd`.
# Stops with an explicit message when no model is stored for the requested
# pair, instead of failing later with "subscript out of bounds".
GetModelAndPredict <- function(input_cyl, model_name, dd, models = dt_models) {

  # which() drops NA comparisons, so rows with a missing key never match
  hit <- which(models$cyl == input_cyl & models$name == model_name)

  if (length(hit) == 0L) {
    stop(
      "No model found for cyl = ", input_cyl,
      " and name = '", model_name, "'",
      call. = FALSE
    )
  }

  # If several rows match, the first one is used
  m <- models$model[[hit[1L]]]

  predict.lm(m, newdata = dd)

}

# Score the test set one row at a time: each row is passed to
# GetModelAndPredict() together with its own cyl value, once per model name,
# and the two predictions are appended to that row as pred1 / pred2.
mtcars_test %>%
  rowwise() %>%
  do(
    data.frame(
      .,
      pred1 = GetModelAndPredict(.$cyl, "model1", .),
      pred2 = GetModelAndPredict(.$cyl, "model2", .)
    )
  ) %>%
  ungroup()


# # A tibble: 12 × 13
#      mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb     pred1     pred2
# *  <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>     <dbl>     <dbl>
# 1   22.8     4 108.0    93  3.85 2.320 18.61     1     1     4     1 103.11501 115.24903
# 2   17.3     8 275.8   180  3.07 3.730 17.60     0     0     3     3 356.19839 316.20091
# 3   18.1     6 225.0   105  2.76 3.460 20.22     1     0     3     1 200.10912 151.56750
# 4   21.0     6 160.0   110  3.90 2.875 17.02     0     1     4     4 195.69767 198.89904
# 5   32.4     4  78.7    66  4.08 2.200 19.47     1     1     4     1  87.99347  77.54320
# 6   26.0     4 120.3    91  4.43 2.140 16.70     0     1     5     2 101.99490 102.68042
# 7   15.8     8 351.0   264  4.22 3.170 14.50     0     1     5     4 365.97745 339.57501
# 8   24.4     4 146.7    62  3.69 3.190 20.00     1     0     4     2  85.75324 108.96473
# 9   27.3     4  79.0    66  4.08 1.935 18.90     1     1     4     1  87.99347  97.57442
# 10  33.9     4  71.1    65  4.22 1.835 19.90     1     1     4     1  87.43341  71.65166
# 11  22.8     4 140.8    95  3.92 3.150 22.90     1     0     4     2 104.23513 115.24903
# 12  18.7     8 360.0   175  3.15 3.440 17.02     0     0     3     2 355.61630 294.38507

But I also found the nest + map approach really interesting:

library(tidyverse)

# Seed the RNG so the row shuffle is repeatable.
set.seed(47)

# Randomise the row order by sampling every row once.
mtcars2 <- sample_frac(mtcars, size = 1)

# Training set: the first 20 shuffled rows; test set: rows 21 to 32.
mtcars_train <- mtcars2[seq_len(20), ]
mtcars_test <- mtcars2[21:32, ]

# Nest the training rows per cyl value, then fit both models on every nested
# data frame with map(); the nested column is renamed to data_train afterwards.
# NOTE(review): tidyr >= 1.0 is understood to prefer the named form
# `nest(data = -cyl)` and to warn on `nest(-cyl)` - confirm against the
# installed version before changing the call.
dt_models <- mtcars_train %>%
  nest(-cyl) %>%
  mutate(
    model1 = map(data, ~lm(disp ~ hp, data = .)),
    model2 = map(data, ~lm(disp ~ mpg, data = .))
  ) %>%
  rename(data_train = data) %>%
  print()

# Nest the test rows per cyl as data_test and attach the matching row of
# dt_models (training data and fitted models) by cyl. inner_join() keeps only
# cyl values present in both tables.
dt_models_and_test_data <- mtcars_test %>%
  nest(-cyl) %>%
  rename(data_test = data) %>%
  inner_join(dt_models, by = "cyl") %>%
  print()

# For every cyl row, run each fitted model against that row's nested test data;
# map2() walks the model column and the data_test column in parallel.
dt_preds <- dt_models_and_test_data %>%
  mutate(
    pred1 = map2(model1, data_test, ~ predict.lm(.x, newdata = .y)),
    pred2 = map2(model2, data_test, ~ predict.lm(.x, newdata = .y))
  )

print(dt_preds)

# Flatten the nested test data and the two prediction columns back into an
# ordinary data frame with one row per test observation.
# NOTE(review): passing several columns as separate arguments is the older
# unnest() calling convention - check how the installed tidyr handles it.
dt_preds_upd <- unnest(dt_preds, data_test, pred1, pred2)

print(dt_preds_upd)


# # A tibble: 12 × 13
#      cyl     pred1     pred2   mpg  disp    hp  drat    wt  qsec    vs    am  gear  carb
#    <dbl>     <dbl>     <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
# 1      4 103.11501 115.24903  22.8 108.0    93  3.85 2.320 18.61     1     1     4     1
# 2      4  87.99347  77.54320  32.4  78.7    66  4.08 2.200 19.47     1     1     4     1
# 3      4 101.99490 102.68042  26.0 120.3    91  4.43 2.140 16.70     0     1     5     2
# 4      4  85.75324 108.96473  24.4 146.7    62  3.69 3.190 20.00     1     0     4     2
# 5      4  87.99347  97.57442  27.3  79.0    66  4.08 1.935 18.90     1     1     4     1
# 6      4  87.43341  71.65166  33.9  71.1    65  4.22 1.835 19.90     1     1     4     1
# 7      4 104.23513 115.24903  22.8 140.8    95  3.92 3.150 22.90     1     0     4     2
# 8      8 356.19839 316.20091  17.3 275.8   180  3.07 3.730 17.60     0     0     3     3
# 9      8 365.97745 339.57501  15.8 351.0   264  4.22 3.170 14.50     0     1     5     4
# 10     8 355.61630 294.38507  18.7 360.0   175  3.15 3.440 17.02     0     0     3     2
# 11     6 200.10912 151.56750  18.1 225.0   105  2.76 3.460 20.22     1     0     3     1
# 12     6 195.69767 198.89904  21.0 160.0   110  3.90 2.875 17.02     0     1     4     4
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!