Using Prophet Package to Predict By Group in Dataframe in R

Submitted by ε祈祈猫儿з on 2019-11-28 18:58:35

Here is a solution that uses tidyr::nest to nest the data by group, fits a model to each group with purrr::map, and then retrieves the y-hat values as requested. I took your code but wrapped each step in a mutate call that computes a new list column with purrr::map.
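The answer does not show the input, so here is a minimal sketch of a data frame with the shape the pipeline expects: a `ds` date column and a numeric `y` column (the names prophet requires), plus a `group` column with two levels of 30 rows each. The dates and values are made up for illustration and are not the asker's data.

```r
# Hypothetical toy input: 2 groups x 30 daily observations.
# The start date and the y values are assumptions, not the original data.
set.seed(42)
dates <- seq(as.Date("2016-11-01"), by = "day", length.out = 30)

df <- data.frame(
  ds    = rep(dates, times = 2),               # date column, named as prophet expects
  group = factor(rep(c("A", "B"), each = 30)), # grouping variable
  y     = round(runif(60, min = 50, max = 350)) # value to forecast
)

str(df)
```

Any data frame with these three columns should work the same way; the forecasts will of course differ from the ones printed in this answer.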

library(prophet)
library(dplyr)
library(purrr)
library(tidyr)

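d1 <- df %>% 
  # one row per group; the remaining columns are nested into `data`
  nest(data = -group) %>% 
  # fit one prophet model per group
  mutate(m = map(data, prophet)) %>% 
  # extend each model's history by 7 daily periods
  mutate(future = map(m, make_future_dataframe, periods = 7)) %>% 
  # predict on each group's future data frame
  mutate(forecast = map2(m, future, predict))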

Here is the output at this point:

d1
# A tibble: 2 × 5
   group              data          m                future
  <fctr>            <list>     <list>                <list>
1      A <tibble [30 × 2]> <S3: list> <data.frame [36 × 1]>
2      B <tibble [30 × 2]> <S3: list> <data.frame [36 × 1]>
# ... with 1 more variables: forecast <list>

Then I use unnest() to retrieve the data from the forecast column and select the y-hat value as requested.

d <- d1 %>% 
  unnest(forecast) %>% 
  select(ds, group, yhat)

And here is the output for the newly forecasted values:

d %>% group_by(group) %>% 
  top_n(7, ds)
Source: local data frame [14 x 3]
Groups: group [2]

           ds  group      yhat
       <date> <fctr>     <dbl>
1  2016-11-30      A 180.53422
2  2016-12-01      A 349.30277
3  2016-12-02      A 288.68215
4  2016-12-03      A 222.33501
5  2016-12-04      A 342.96654
6  2016-12-05      A 203.64625
7  2016-12-06      A 185.37395
8  2016-11-30      B 131.07827
9  2016-12-01      B 222.83703
10 2016-12-02      B 236.33555
11 2016-12-03      B 145.41001
12 2016-12-04      B 228.59687
13 2016-12-05      B 162.49244
14 2016-12-06      B  68.44477

I was looking for a solution for the same problem. I came up with the following code, which is a bit simpler than the accepted answer.

library(tidyr)
library(dplyr)
library(prophet)

data <- df %>%
  group_by(group) %>%
  do({
    # fit the model once per group and reuse it for the future frame and the forecast
    m <- prophet(.)
    predict(m, make_future_dataframe(m, periods = 7))
  }) %>%
  select(ds, group, yhat)

And here are the predicted values:

data %>% group_by(group) %>% 
         top_n(7, ds)

# A tibble: 14 x 3
# Groups:   group [2]
           ds  group     yhat
       <date> <fctr>    <dbl>
 1 2016-12-01      A 316.9709
 2 2016-12-02      A 258.2153
 3 2016-12-03      A 196.6835
 4 2016-12-04      A 346.2338
 5 2016-12-05      A 208.9083
 6 2016-12-06      A 216.5847
 7 2016-12-07      A 206.3642
 8 2016-12-01      B 230.0424
 9 2016-12-02      B 268.5359
10 2016-12-03      B 190.2903
11 2016-12-04      B 312.9019
12 2016-12-05      B 266.5584
13 2016-12-06      B 189.3556
14 2016-12-07      B 168.9791