auto.arima() equivalent for python

Anonymous (unverified), submitted 2019-12-03 02:08:02

Question:

I am trying to predict weekly sales using ARMA/ARIMA models. I could not find a function in statsmodels for tuning the order (p, d, q). R has the function auto.arima(), which tunes the (p, d, q) parameters automatically.

How do I go about choosing the right order for my model? Are there any libraries available in Python for this purpose?

Answer 1:

You can implement a number of approaches:

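  1. ARIMAResults includes aic and bic. By definition (AIC = 2k - 2 ln L and BIC = k ln(n) - 2 ln L, where k is the number of estimated parameters, n the number of observations and L the maximized likelihood), these criteria penalize the number of parameters in the model, so you can use them to compare models. scipy also has optimize.brute, which does a grid search over a specified parameter space. So a workflow like this should work:

    from scipy.optimize import brute
    # Current statsmodels API; the old statsmodels.tsa.arima_model was removed in 0.13
    from statsmodels.tsa.arima.model import ARIMA

    def objfunc(order, exog, endog):
        # brute passes each grid point as a float array; ARIMA needs integer orders
        order = tuple(int(v) for v in order)
        fit = ARIMA(endog, exog=exog, order=order).fit()
        return fit.aic  # aic is an attribute, not a method

    # endog (and optionally exog) hold your data and must be defined beforehand
    grid = (slice(1, 3, 1), slice(1, 3, 1), slice(1, 3, 1))
    brute(objfunc, grid, args=(exog, endog), finish=None)

    Make sure you call brute with finish=None. Otherwise brute polishes the best grid point with a continuous optimizer (fmin by default), which makes no sense for integer orders.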

  2. ARIMAResults also exposes pvalues, so a forward-stepwise algorithm is easy to implement: at each step, increase the order of the model along the dimension whose added parameter has the lowest p-value.

  3. Use ARIMAResults.predict to cross-validate alternative models. The best approach would be to keep the tail of the time series (say most recent 5% of data) out of sample, and use these points to obtain the test error of the fitted models.
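Approach 3 can be sketched roughly as follows. This is an illustration under stated assumptions, not the answerer's code: `ar_forecast` and `holdout_mse` are hypothetical helper names, the series is synthetic, and a least-squares AR(p) fit in NumPy stands in for ARIMA so that the sketch has no statsmodels dependency. With real data you would fit each candidate ARIMA order on the training part and forecast over the held-out tail instead.

```python
import numpy as np


def ar_forecast(train, p, steps):
    """Fit AR(p) with intercept by least squares and forecast `steps` ahead.

    Stand-in for an ARIMA fit; hypothetical helper, not a statsmodels API.
    """
    train = np.asarray(train, dtype=float)
    if p == 0:
        # No lags: the forecast is just the training mean
        return np.full(steps, train.mean())
    # Design matrix: the row for time t holds [1, y[t-1], ..., y[t-p]]
    rows = [np.r_[1.0, train[t - p:t][::-1]] for t in range(p, len(train))]
    coef, *_ = np.linalg.lstsq(np.array(rows), train[p:], rcond=None)
    history = list(train)
    out = []
    for _ in range(steps):
        lags = history[-p:][::-1]  # most recent value first
        nxt = coef[0] + float(np.dot(coef[1:], lags))
        out.append(nxt)
        history.append(nxt)  # multi-step forecast feeds on its own output
    return np.array(out)


def holdout_mse(series, candidates, test_frac=0.05):
    """Return {p: test MSE}, holding out the last `test_frac` of the series."""
    series = np.asarray(series, dtype=float)
    n_test = max(1, int(round(len(series) * test_frac)))
    train, test = series[:-n_test], series[-n_test:]
    return {
        p: float(np.mean((ar_forecast(train, p, n_test) - test) ** 2))
        for p in candidates
    }


# Synthetic AR(2) series, purely for illustration
rng = np.random.default_rng(0)
y = np.zeros(400)
for t in range(2, 400):
    y[t] = 0.6 * y[t - 1] - 0.3 * y[t - 2] + rng.normal()

scores = holdout_mse(y, candidates=[0, 1, 2, 3])
best_p = min(scores, key=scores.get)
print(scores)
print("order with lowest holdout MSE:", best_p)
```

A single 5% tail is a small test set, so the ranking can be noisy; the walk-forward variant in Answer 5 refits after every held-out point and is more expensive but uses the same idea.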



Answer 2:

A possible solution:

    import itertools
    import warnings

    import pandas as pd
    import statsmodels.api as sm

    df = pd.read_csv("http://vincentarelbundock.github.io/Rdatasets/csv/datasets/AirPassengers.csv")

    # Let p, d and q each take the value 0 or 1
    p = d = q = range(0, 2)
    print(p)

    # Generate all combinations of (p, d, q) triplets
    pdq = list(itertools.product(p, d, q))
    print(pdq)

    # Generate all combinations of seasonal (P, D, Q, 12) quadruples
    seasonal_pdq = [(x[0], x[1], x[2], 12) for x in itertools.product(p, d, q)]

    print('Examples of parameter combinations for Seasonal ARIMA...')
    print('SARIMAX: {} x {}'.format(pdq[1], seasonal_pdq[1]))
    print('SARIMAX: {} x {}'.format(pdq[1], seasonal_pdq[2]))
    print('SARIMAX: {} x {}'.format(pdq[2], seasonal_pdq[3]))
    print('SARIMAX: {} x {}'.format(pdq[2], seasonal_pdq[4]))

Output:

    Examples of parameter combinations for Seasonal ARIMA...
    SARIMAX: (0, 0, 1) x (0, 0, 1, 12)
    SARIMAX: (0, 0, 1) x (0, 1, 0, 12)
    SARIMAX: (0, 1, 0) x (0, 1, 1, 12)
    SARIMAX: (0, 1, 0) x (1, 0, 0, 12)

Then fit every combination and print its AIC:

    # SARIMAX expects a univariate series, so select the passenger counts
    # (last column of the CSV); the original answer passed the whole DataFrame
    y = df.iloc[:, -1]

    # warnings.filterwarnings("ignore")  # uncomment to ignore warning messages

    for param in pdq:
        for param_seasonal in seasonal_pdq:
            try:
                mod = sm.tsa.statespace.SARIMAX(y,
                                                order=param,
                                                seasonal_order=param_seasonal,
                                                enforce_stationarity=False,
                                                enforce_invertibility=False)

                results = mod.fit()

                print('ARIMA{}x{}12 - AIC:{}'.format(param, param_seasonal, results.aic))
            except Exception:
                # skip combinations that cannot be fitted
                continue

Output posted with the original answer:

    ARIMA(0, 0, 0)x(0, 0, 1, 12)12 - AIC:3618.0303991426763
    ARIMA(0, 0, 0)x(0, 1, 1, 12)12 - AIC:2824.7439963684233
    ARIMA(0, 0, 0)x(1, 0, 0, 12)12 - AIC:2942.2733127230185
    ARIMA(0, 0, 0)x(1, 0, 1, 12)12 - AIC:2922.178151133141
    ARIMA(0, 0, 0)x(1, 1, 0, 12)12 - AIC:2767.105066400224
    ARIMA(0, 0, 0)x(1, 1, 1, 12)12 - AIC:2691.233398643673
    ARIMA(0, 0, 1)x(0, 0, 0, 12)12 - AIC:3890.816777796087
    ARIMA(0, 0, 1)x(0, 0, 1, 12)12 - AIC:3541.1171286722
    ARIMA(0, 0, 1)x(0, 1, 0, 12)12 - AIC:3028.8377323188824
    ARIMA(0, 0, 1)x(0, 1, 1, 12)12 - AIC:2746.77973129136
    ARIMA(0, 0, 1)x(1, 0, 0, 12)12 - AIC:3583.523640623017
    ARIMA(0, 0, 1)x(1, 0, 1, 12)12 - AIC:3531.2937768990187
    ARIMA(0, 0, 1)x(1, 1, 0, 12)12 - AIC:2781.198675746594
    ARIMA(0, 0, 1)x(1, 1, 1, 12)12 - AIC:2720.7023088205974
    ARIMA(0, 1, 0)x(0, 0, 1, 12)12 - AIC:3029.089945668332
    ARIMA(0, 1, 0)x(0, 1, 1, 12)12 - AIC:2568.2832251221016
    ARIMA(0, 1, 0)x(1, 0, 0, 12)12 - AIC:2841.315781459511
    ARIMA(0, 1, 0)x(1, 0, 1, 12)12 - AIC:2815.4011044132576
    ARIMA(0, 1, 0)x(1, 1, 0, 12)12 - AIC:2588.533386513587
    ARIMA(0, 1, 0)x(1, 1, 1, 12)12 - AIC:2569.9453272483315
    ARIMA(0, 1, 1)x(0, 0, 0, 12)12 - AIC:3327.5177587522303
    ARIMA(0, 1, 1)x(0, 0, 1, 12)12 - AIC:2984.716706112334
    ARIMA(0, 1, 1)x(0, 1, 0, 12)12 - AIC:2789.128542154043
    ARIMA(0, 1, 1)x(0, 1, 1, 12)12 - AIC:2537.0293659293943
    ARIMA(0, 1, 1)x(1, 0, 0, 12)12 - AIC:2984.4555708516436
    ARIMA(0, 1, 1)x(1, 0, 1, 12)12 - AIC:2939.460958374472
    ARIMA(0, 1, 1)x(1, 1, 0, 12)12 - AIC:2578.7862352774437
    ARIMA(0, 1, 1)x(1, 1, 1, 12)12 - AIC:2537.771484229265
    ARIMA(1, 0, 0)x(0, 0, 0, 12)12 - AIC:3391.5248913820797
    ARIMA(1, 0, 0)x(0, 0, 1, 12)12 - AIC:3038.142074281268
    C:\Users\Dell\Anaconda3\lib\site-packages\statsmodels\base\model.py:496: ConvergenceWarning: Maximum Likelihood optimization failed to converge. Check mle_retvals
      "Check mle_retvals", ConvergenceWarning)
    ARIMA(1, 0, 0)x(0, 1, 0, 12)12 - AIC:2839.809192263449
    ARIMA(1, 0, 0)x(0, 1, 1, 12)12 - AIC:2588.50367175184
    ARIMA(1, 0, 0)x(1, 0, 0, 12)12 - AIC:2993.4630440139595
    ARIMA(1, 0, 0)x(1, 0, 1, 12)12 - AIC:2995.049216326931
    ARIMA(1, 0, 0)x(1, 1, 0, 12)12 - AIC:2588.2463284315304
    ARIMA(1, 0, 0)x(1, 1, 1, 12)12 - AIC:2592.80110502723
    ARIMA(1, 0, 1)x(0, 0, 0, 12)12 - AIC:3352.0350133621478
    ARIMA(1, 0, 1)x(0, 0, 1, 12)12 - AIC:3006.5493366627807
    ARIMA(1, 0, 1)x(0, 1, 0, 12)12 - AIC:2810.6423724894516
    ARIMA(1, 0, 1)x(0, 1, 1, 12)12 - AIC:2559.584031948852
    ARIMA(1, 0, 1)x(1, 0, 0, 12)12 - AIC:2981.2250436794675
    ARIMA(1, 0, 1)x(1, 0, 1, 12)12 - AIC:2959.3142304724834
    ARIMA(1, 0, 1)x(1, 1, 0, 12)12 - AIC:2579.8245645892207
    ARIMA(1, 0, 1)x(1, 1, 1, 12)12 - AIC:2563.13922589258
    ARIMA(1, 1, 0)x(0, 0, 0, 12)12 - AIC:3354.7462930846423
    ARIMA(1, 1, 0)x(0, 0, 1, 12)12 - AIC:3006.702997636003
    ARIMA(1, 1, 0)x(0, 1, 0, 12)12 - AIC:2809.3844175191666
    ARIMA(1, 1, 0)x(0, 1, 1, 12)12 - AIC:2558.484602766447
    ARIMA(1, 1, 0)x(1, 0, 0, 12)12 - AIC:2959.885810636943
    ARIMA(1, 1, 0)x(1, 0, 1, 12)12 - AIC:2960.712709764296
    ARIMA(1, 1, 0)x(1, 1, 0, 12)12 - AIC:2557.945907092698
    ARIMA(1, 1, 0)x(1, 1, 1, 12)12 - AIC:2559.274166458508
    ARIMA(1, 1, 1)x(0, 0, 0, 12)12 - AIC:3326.3285511700374
    ARIMA(1, 1, 1)x(0, 0, 1, 12)12 - AIC:2985.868532151721
    ARIMA(1, 1, 1)x(0, 1, 0, 12)12 - AIC:2790.7677149967103
    ARIMA(1, 1, 1)x(0, 1, 1, 12)12 - AIC:2538.820635541546
    ARIMA(1, 1, 1)x(1, 0, 0, 12)12 - AIC:2963.2789505804294
    ARIMA(1, 1, 1)x(1, 0, 1, 12)12 - AIC:2941.2436984747465
    ARIMA(1, 1, 1)x(1, 1, 0, 12)12 - AIC:2559.8258191422606
    ARIMA(1, 1, 1)x(1, 1, 1, 12)12 - AIC:2539.712354465328

from https://www.digitalocean.com/community/tutorials/a-guide-to-time-series-forecasting-with-arima-in-python-3

also see https://github.com/decisionstats/pythonfordatascience/blob/master/time%2Bseries%20(1).ipynb
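The loop in this answer prints every AIC but leaves picking the winner to the reader. One way to do that, sketched here with only the standard library, is to parse the printed lines and take the minimum. The names `LINE` and `parse_aic_log` are illustrative, and `log` holds just a few lines copied from the posted output; in your own script it would be simpler to track the lowest `results.aic` inside the loop.

```python
import re

# A few lines copied from the output posted in this answer
log = """\
ARIMA(0, 1, 0)x(0, 1, 1, 12)12 - AIC:2568.2832251221016
ARIMA(0, 1, 1)x(0, 1, 1, 12)12 - AIC:2537.0293659293943
ARIMA(0, 1, 1)x(1, 1, 1, 12)12 - AIC:2537.771484229265
ARIMA(1, 1, 1)x(0, 1, 1, 12)12 - AIC:2538.820635541546
"""

# Matches the format string 'ARIMA{}x{}12 - AIC:{}' used by the loop
LINE = re.compile(r"ARIMA\((.*?)\)x\((.*?)\)12 - AIC:([-\d.eE+]+)")


def parse_aic_log(text):
    """Yield (order, seasonal_order, aic) for each matching log line."""
    for m in LINE.finditer(text):
        order = tuple(int(v) for v in m.group(1).split(","))
        seasonal = tuple(int(v) for v in m.group(2).split(","))
        yield order, seasonal, float(m.group(3))


best = min(parse_aic_log(log), key=lambda row: row[2])
print(best)  # ((0, 1, 1), (0, 1, 1, 12), 2537.0293659293943)
```

Across the full output posted above, ARIMA(0, 1, 1)x(0, 1, 1, 12) is also the line with the lowest AIC.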



Answer 3:

I wrote these utility functions to compute the (p, d, q) values directly. get_PDQ_parallel takes two inputs: data, a series with a timestamp (datetime) index, and n_jobs, the number of parallel processes. The output is a DataFrame with the aic and bic values, indexed by order=(p, d, q). p and q range over [0, 12], while d ranges over [0, 1].

    from functools import partial
    from multiprocessing import Pool

    import numpy as np
    import pandas as pd
    # Legacy ARIMA class; it was removed in statsmodels 0.13
    from statsmodels.tsa.arima_model import ARIMA


    def get_aic_bic(order, series):
        """Fit ARIMA(order) and return (aic, bic); (nan, nan) if the fit fails."""
        aic = np.nan
        bic = np.nan
        try:
            arima_mod = ARIMA(series, order=order, freq='H').fit(transparams=True, method='css')
            aic = arima_mod.aic
            bic = arima_mod.bic
            print(order, aic, bic)
        except Exception:
            pass
        return aic, bic


    def get_PDQ_parallel(data, n_jobs=7):
        p_val = 13
        q_val = 13
        d_vals = 2
        pdq_vals = [(p, d, q) for p in range(p_val) for d in range(d_vals) for q in range(q_val)]
        get_aic_bic_partial = partial(get_aic_bic, series=data)
        # Fit all candidate orders in parallel
        with Pool(n_jobs) as pool:
            res = pool.map(get_aic_bic_partial, pdq_vals)
        return pd.DataFrame(res, index=pdq_vals, columns=['aic', 'bic'])
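Once the table is back, choosing an order is a matter of dropping the failed fits and sorting. The snippet below is a small sketch of that step: `res` is a hypothetical stand-in for the DataFrame returned by get_PDQ_parallel, and the AIC/BIC numbers in it are made up purely for illustration.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the result of get_PDQ_parallel:
# index = (p, d, q), columns = aic / bic, NaN where the fit failed.
# The values below are invented for illustration only.
orders = [(0, 0, 1), (0, 1, 1), (1, 0, 0), (1, 1, 1)]
res = pd.DataFrame(
    [(310.2, 318.9), (295.4, 304.1), (np.nan, np.nan), (296.0, 307.6)],
    index=pd.MultiIndex.from_tuples(orders, names=["p", "d", "q"]),
    columns=["aic", "bic"],
)

ranked = res.dropna().sort_values("aic")  # failed fits are dropped first
best_order = ranked.index[0]              # (p, d, q) with the lowest AIC
print(ranked)
print("lowest-AIC order:", best_order)
```

Sorting by "bic" instead favors smaller models, since BIC penalizes extra parameters more heavily than AIC once the series has more than a handful of observations.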


Answer 4:

There is now a proper Python package for auto-ARIMA (the project has since been renamed pmdarima): https://github.com/tgsmith61591/pyramid

Example usage: https://github.com/tgsmith61591/pyramid/blob/master/examples/quick_start_example.ipynb



Answer 5:

    import warnings
    from datetime import datetime

    from sklearn.metrics import mean_squared_error
    # Current statsmodels API; the original used statsmodels.tsa.arima_model with fit(disp=0)
    from statsmodels.tsa.arima.model import ARIMA


    def evaluate_arima_model(X, arima_order):
        # prepare training dataset
        train_size = int(len(X) * 0.90)
        train, test = X[0:train_size], X[train_size:]
        history = [x for x in train]
        # make predictions (walk-forward: refit after each new observation)
        predictions = list()
        for t in range(len(test)):
            model = ARIMA(history, order=arima_order)
            model_fit = model.fit()
            yhat = model_fit.forecast()[0]
            predictions.append(yhat)
            history.append(test[t])
        # calculate out of sample error
        error = mean_squared_error(test, predictions)
        return error


    # evaluate combinations of p, d and q values for an ARIMA model
    def evaluate_models(dataset, p_values, d_values, q_values):
        dataset = dataset.astype('float32')
        best_score, best_cfg = float("inf"), None
        for p in p_values:
            for d in d_values:
                for q in q_values:
                    order = (p, d, q)
                    try:
                        mse = evaluate_arima_model(dataset, order)
                        if mse < best_score:
                            best_score, best_cfg = mse, order
                        print('ARIMA%s MSE=%.3f' % (order, mse))
                    except Exception:
                        continue
        print('Best ARIMA%s MSE=%.3f' % (best_cfg, best_score))


    # load dataset
    def parser(x):
        return datetime.strptime('190' + x, '%Y-%m')


    p_values = [4, 5, 6, 7, 8]
    d_values = [0, 1, 2]
    q_values = [2, 3, 4, 5, 6]
    warnings.filterwarnings("ignore")
    # `train` is your series as a NumPy array; loading it is not shown in the original answer
    evaluate_models(train, p_values, d_values, q_values)

This prints the (p, d, q) combination with the lowest out-of-sample MSE; use those values for your ARIMA model.



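Answer 6:

Actually, it should be:

    from scipy.optimize import brute
    from statsmodels.tsa.arima.model import ARIMA

    def objfunc(order, *params):
        endog, exog = params
        # brute passes each grid point as a float array; ARIMA needs integer orders
        p, d, q = (int(v) for v in order)
        fit = ARIMA(endog, exog=exog, order=(p, d, q)).fit()
        return fit.aic  # aic is an attribute, not a method

    params = (endog, exog)  # your data, defined beforehand
    grid = (slice(1, 3, 1), slice(1, 3, 1), slice(1, 3, 1))
    brute(objfunc, grid, args=params, finish=None)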

