Linear Regression model building and prediction by group in R

最后都变了 - submitted on 2021-01-07 06:58:43

Question


I'm trying to build several models on subsets (groups) of my data and generate their fitted values. In other words, as my attempts below show, I want models that are country-specific. Unfortunately, my attempts only ever fit the models on the entire dataset instead of restricting each fit to the rows of one country. Could you please help me resolve this problem?

In the first case I'm doing a kind of leave-one-out cross-validation to generate the predictions. In the second case I'm not. Both attempts fail.



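library(tidyverse)   # provides %>%, group_by(), bind_cols(), map(), map_dfc()
library(modelr)
# install.packages("gapminder")   # run once if the package is not installed
library(gapminder)
data(gapminder)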

#CASE 1
model1 <- lm(lifeExp ~ pop, data = gapminder)
model2 <- lm(lifeExp ~ pop + gdpPercap, data = gapminder)

models <- list(fit_model1 = model1,fit_model2 = model2)

gapminder %>% group_by(continent, country) %>%
  bind_cols(
    map(1:nrow(gapminder), function(i) {
      map_dfc(models, function(model) {
        training <- gapminder[-i, ] 
        fit <- lm(model, data = training)
        
        validation <- gapminder[i, ]
        predict(fit, newdata = validation)
        
      })
    }) %>%
      bind_rows()
  )


#CASE 2
model1 <- lm(lifeExp ~ pop, data = gapminder)
model2 <- lm(lifeExp ~ pop + gdpPercap, data = gapminder)

models <- list(fit_model1 = model1,fit_model2 = model2)


for (m in names(models)) {
  gapminder[[m]] <- predict(models[[m]], gapminder %>% group_by(continent, country) )
  
}


Answer 1:


The tidyverse solution to modeling by group is to use:

  • tidyr::nest() to nest each group's rows into a list-column
  • dplyr::mutate() together with purrr::map() to create models by group
  • broom::tidy() or broom::augment() to generate model summaries and predictions
  • tidyr::unnest() and dplyr::filter() to get summaries and predictions by group

Here's an example. It doesn't do the same as the code in your question, but I think it will be helpful nevertheless.

This code fits the linear model lifeExp ~ pop separately for each country and returns the fitted (in-sample predicted) values from each model.

library(tidyverse)
library(broom)
library(gapminder)

gapminder_lm <- gapminder %>% 
  nest(data = c(year, lifeExp, pop, gdpPercap)) %>% 
  mutate(model = map(data, ~lm(lifeExp ~ pop, .)), 
         fitted = map(model, augment)) %>% 
  unnest(fitted)

gapminder_lm

# A tibble: 1,704 x 12
   country     continent data              model  lifeExp      pop .fitted .resid .std.resid   .hat .sigma  .cooksd
   <fct>       <fct>     <list>            <list>   <dbl>    <int>   <dbl>  <dbl>      <dbl>  <dbl>  <dbl>    <dbl>
 1 Afghanistan Asia      <tibble [12 x 4]> <lm>      28.8  8425333    33.2 -4.41     -1.54   0.182    2.92 0.262   
 2 Afghanistan Asia      <tibble [12 x 4]> <lm>      30.3  9240934    33.7 -3.35     -1.15   0.161    3.11 0.128   
 3 Afghanistan Asia      <tibble [12 x 4]> <lm>      32.0 10267083    34.3 -2.27     -0.773  0.139    3.24 0.0482  
 4 Afghanistan Asia      <tibble [12 x 4]> <lm>      34.0 11537966    35.0 -0.985    -0.331  0.116    3.32 0.00720 
 5 Afghanistan Asia      <tibble [12 x 4]> <lm>      36.1 13079460    35.9  0.193     0.0641 0.0969   3.34 0.000220
 6 Afghanistan Asia      <tibble [12 x 4]> <lm>      38.4 14880372    36.9  1.50      0.496  0.0849   3.30 0.0114  
 7 Afghanistan Asia      <tibble [12 x 4]> <lm>      39.9 12881816    35.8  4.07      1.35   0.0989   3.02 0.101   
 8 Afghanistan Asia      <tibble [12 x 4]> <lm>      40.8 13867957    36.4  4.47      1.48   0.0902   2.95 0.108   
 9 Afghanistan Asia      <tibble [12 x 4]> <lm>      41.7 16317921    37.8  3.91      1.29   0.0838   3.05 0.0759  
10 Afghanistan Asia      <tibble [12 x 4]> <lm>      41.8 22227415    41.2  0.588     0.202  0.157    3.33 0.00380 
# ... with 1,694 more rows
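
The same nested layout also works for the broom::tidy() step mentioned in the list above. The following is a minimal sketch, not part of the original answer: it swaps augment() for tidy() to get one row per country and model term. The names gapminder_coefs and coefs are chosen here purely for illustration.

```r
library(tidyverse)
library(broom)
library(gapminder)

# Fit lifeExp ~ pop per country, then extract coefficient tables instead of
# fitted values. `gapminder_coefs` and `coefs` are illustrative names.
gapminder_coefs <- gapminder %>%
  nest(data = c(year, lifeExp, pop, gdpPercap)) %>%
  mutate(model = map(data, ~ lm(lifeExp ~ pop, data = .x)),
         coefs = map(model, tidy)) %>%
  select(country, continent, coefs) %>%
  unnest(coefs)

# Keep only the slope for pop, e.g. to compare countries
gapminder_slopes <- gapminder_coefs %>%
  filter(term == "pop")
```

Each country contributes two rows to gapminder_coefs (the intercept and the pop slope), so filtering on term leaves one row per country.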

This has the advantage of keeping everything in a tidy data frame, which can be filtered for the data of interest.

For example, filter for Egypt and plot observed versus fitted values:

gapminder_lm %>% 
  filter(country == "Egypt") %>% 
  ggplot(aes(lifeExp, .fitted)) + 
  geom_point() + 
  labs(title = "Egypt")
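
To get closer to what the question seems to be after (both model formulas, fitted per country, with the predictions added as columns), one possible sketch follows. It is not the answerer's code: it uses dplyr::group_modify() instead of nest()/map(), stores the models as formulas rather than fitted lm objects, and assumes that the cross-validation in Case 1 is meant to be leave-one-out within each country. The names formulas, loo_predict and add_group_preds are hypothetical helpers introduced for this illustration.

```r
library(tidyverse)
library(gapminder)

# The two models from the question, kept as formulas so that they can be
# refitted on each country's rows (illustrative name: `formulas`)
formulas <- list(
  fit_model1 = lifeExp ~ pop,
  fit_model2 = lifeExp ~ pop + gdpPercap
)

# Leave-one-out prediction within one group:
# refit without row i, then predict row i (hypothetical helper)
loo_predict <- function(f, df) {
  map_dbl(seq_len(nrow(df)), function(i) {
    fit <- lm(f, data = df[-i, ])
    unname(predict(fit, newdata = df[i, ]))
  })
}

# Called once per country by group_modify(); `df` holds that country's rows.
# Adds in-sample fitted values (fit_*) and leave-one-out predictions (loo_*).
add_group_preds <- function(df, key) {
  in_sample <- map(formulas, function(f) unname(fitted(lm(f, data = df))))
  held_out  <- map(formulas, function(f) loo_predict(f, df))
  names(held_out) <- paste0("loo_", names(formulas))
  bind_cols(df, as_tibble(in_sample), as_tibble(held_out))
}

gapminder_fits <- gapminder %>%
  group_by(continent, country) %>%
  group_modify(add_group_preds) %>%
  ungroup()
```

Because every lm() call only sees the rows of a single country, the fit_model1 column here should match the .fitted column of gapminder_lm above. With only 12 observations per country, the leave-one-out predictions rest on 11 points each and are best read as a rough check.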



Source: https://stackoverflow.com/questions/65433570/linear-regression-model-building-and-prediction-by-group-in-r
