I am trying to add a column of predictions to a data frame that has a list column containing an lm model. I adapted some of the code from this post.
I have made a toy example here:
library(dplyr)
library(purrr)
library(tidyr)
library(broom)
set.seed(1234)
exampleTable <- data.frame(
  ind = c(rep(1:5, 5)),
  dep = rnorm(25),
  groups = rep(LETTERS[1:5], each = 5)
) %>%
  group_by(groups) %>%
  nest(.key = the_data) %>%
  mutate(model = the_data %>% map(~lm(dep ~ ind, data = .))) %>%
  mutate(Pred = map2(model, the_data, predict))
exampleTable <- exampleTable %>%
  mutate(ind = row_number())
That gives me a tibble that looks like this:
# A tibble: 5 × 5
  groups         the_data    model      Pred   ind
  <fctr>           <list>   <list>    <list> <int>
1      A <tibble [5 × 2]> <S3: lm> <dbl [5]>     1
2      B <tibble [5 × 2]> <S3: lm> <dbl [5]>     2
3      C <tibble [5 × 2]> <S3: lm> <dbl [5]>     3
4      D <tibble [5 × 2]> <S3: lm> <dbl [5]>     4
5      E <tibble [5 × 2]> <S3: lm> <dbl [5]>     5
To get a predicted value from the lm model for a specific group, I can use this:
predict(exampleTable[1,]$model[[1]], slice(exampleTable, 1) %>% select(ind))
which produces this result:
> predict(exampleTable[1,]$model[[1]], slice(exampleTable, 1) %>% select(ind))
1
-0.4822045
I would like to have one new prediction for each group. I tried using purrr to get what I wanted:
exampleTable %>%
  mutate(Prediction = map2(model, ind, predict))
but that gives the following error:
Error in mutate_impl(.data, dots) : object 'ind' not found
I was able to get the result I wanted with the following monstrosity:
exampleTable$Prediction <- NA
for (loop in seq_along(exampleTable$groups)) {
  lmod <- exampleTable[loop, ]$model[[1]]
  obs <- filter(exampleTable, row_number() == loop) %>%
    select(ind)
  exampleTable[loop, ]$Prediction <- as.numeric(predict(lmod, obs))
}
That gives me a tibble that looks like this:
# A tibble: 5 × 6
  groups         the_data    model      Pred   ind Prediction
  <fctr>           <list>   <list>    <list> <int>      <dbl>
1      A <tibble [5 × 2]> <S3: lm> <dbl [5]>     1 -0.4822045
2      B <tibble [5 × 2]> <S3: lm> <dbl [5]>     2 -0.1357712
3      C <tibble [5 × 2]> <S3: lm> <dbl [5]>     3 -0.2455760
4      D <tibble [5 × 2]> <S3: lm> <dbl [5]>     4  0.4818425
5      E <tibble [5 × 2]> <S3: lm> <dbl [5]>     5 -0.3473236
There must be a 'tidy' way to do this, but I just can't crack it.
You could take advantage of the newdata argument to predict. I use map2_dbl so it returns just a single value per model rather than a list. The line below replaces the mutate(Pred = ...) step at the end of your pipeline:
mutate(Pred = map2_dbl(model, 1:5, ~predict(.x, newdata = data.frame(ind = .y))))
# A tibble: 5 x 4
  groups         the_data    model       Pred
  <fctr>           <list>   <list>      <dbl>
1      A <tibble [5 x 2]> <S3: lm> -0.4822045
2      B <tibble [5 x 2]> <S3: lm> -0.1357712
3      C <tibble [5 x 2]> <S3: lm> -0.2455760
4      D <tibble [5 x 2]> <S3: lm>  0.4818425
5      E <tibble [5 x 2]> <S3: lm> -0.3473236
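To see the two pieces in isolation, here is a minimal sketch on a single model. The names one_group, fit, single, all_fitted, as_list and as_dbl are made up for illustration and are not part of your example. It shows that newdata needs to be a data frame whose column name matches the predictor, and how map2 and map2_dbl differ in what they return.

```r
library(purrr)

set.seed(1234)
# Hypothetical stand-in for one group's data: five rows with predictor `ind`
one_group <- data.frame(ind = 1:5, dep = rnorm(5))
fit <- lm(dep ~ ind, data = one_group)

# newdata is a data frame whose column name matches the predictor in the formula
single <- predict(fit, newdata = data.frame(ind = 1))

# Without newdata, predict() returns the fitted values for the training rows
all_fitted <- predict(fit)

# map2() keeps each result as a list element;
# map2_dbl() collapses the length-one results into a plain double vector
as_list <- map2(list(fit), 1, ~ predict(.x, newdata = data.frame(ind = .y)))
as_dbl  <- map2_dbl(list(fit), 1, ~ predict(.x, newdata = data.frame(ind = .y)))
```

Because map2_dbl insists on exactly one number per element, it would fail if you asked for several new values of ind per model; in that case map2 and a list column would be the safer choice.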
If you add ind to the dataset before prediction, you can use that column instead of 1:5.
mutate(ind = 1:5) %>%
  mutate(Pred = map2_dbl(model, ind, ~predict(.x, newdata = data.frame(ind = .y))))
# A tibble: 5 x 5
  groups         the_data    model   ind       Pred
  <fctr>           <list>   <list> <int>      <dbl>
1      A <tibble [5 x 2]> <S3: lm>     1 -0.4822045
2      B <tibble [5 x 2]> <S3: lm>     2 -0.1357712
3      C <tibble [5 x 2]> <S3: lm>     3 -0.2455760
4      D <tibble [5 x 2]> <S3: lm>     4  0.4818425
5      E <tibble [5 x 2]> <S3: lm>     5 -0.3473236
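Putting it all together, the whole thing could be written as one self-contained pipeline. This is only a sketch under one assumption: that you are on tidyr 1.0.0 or later, where nest(.key = ...) is deprecated in favour of naming the new list column directly, as in nest(the_data = c(ind, dep)). The name result is mine; everything else follows your example.

```r
library(dplyr)
library(purrr)
library(tidyr)

set.seed(1234)
result <- tibble(
  ind = rep(1:5, 5),
  dep = rnorm(25),
  groups = rep(LETTERS[1:5], each = 5)
) %>%
  # Assumes tidyr >= 1.0.0: name the list column and the columns it holds
  nest(the_data = c(ind, dep)) %>%
  mutate(
    # One lm fit per group, stored in a list column
    model = map(the_data, ~ lm(dep ~ ind, data = .x)),
    # The new value of `ind` to predict at for each group (1 for A, 2 for B, ...)
    ind = row_number(),
    # One prediction per model, returned as a plain double column
    Prediction = map2_dbl(model, ind, ~ predict(.x, newdata = data.frame(ind = .y)))
  )
```

On older tidyr versions the group_by(groups) %>% nest(.key = the_data) form from your question should do the same job for that step.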
Source: https://stackoverflow.com/questions/44709870/using-lm-in-list-column-to-predict-new-values-using-purrr