Cross-validation in LightGBM


Question


After reading through LightGBM's documentation on cross-validation, I'm hoping this community can shed light on cross-validating results and improving our predictions using LightGBM. How are we supposed to use the dictionary output from lightgbm.cv to improve our predictions?

Here's an example - we train our cv model using the code below:

import lightgbm as lgb

# d_train is assumed to be a lgb.Dataset built from the training data
cv_mod = lgb.cv(params,
                d_train,
                num_boost_round = 500,
                nfold = 10,
                early_stopping_rounds = 25,
                stratified = True)

How can we use the parameters found in the best iteration of the above code to predict an output? In this case, cv_mod has no "predict" method like lightgbm.train, and the dictionary output from lightgbm.cv throws an error when used in lightgbm.train.predict(..., pred_parameters = cv_mod).

Am I missing an important transformation step?


Answer 1:


In general, the purpose of CV is NOT to do hyperparameter optimisation. The purpose is to evaluate the performance of the model-building procedure.

A basic train/test split is conceptually identical to a 1-fold CV (with a custom split size, in contrast to the fixed 1/K validation fraction in k-fold CV). The advantage of doing more splits (i.e. k>1 CV) is to get more information about the estimate of the generalisation error: you get not only the error itself but also its statistical uncertainty. There is an excellent discussion on CrossValidated (start with the links added to the question, which cover the same question formulated in a different way). It covers nested cross-validation and is absolutely not straightforward, but once you wrap your head around the concept in general, it will help you in various non-trivial situations. The idea to take away is: the purpose of CV is to evaluate the performance of the model-building procedure.

Keeping that idea in mind, how does one approach hyperparameter estimation in general (not only in LightGBM)?

  • You want to train a model with a set of parameters on some data and evaluate each variation of the model on an independent (validation) set. Then you pick the best parameters as those of the variant that gives the best value of your chosen evaluation metric.
  • This can be done with a simple train/test split. But the evaluated performance, and thus the choice of the optimal model parameters, might be just a fluctuation of a particular split.
  • Thus, you can evaluate each of those models in a more statistically robust way by averaging the evaluation over several train/test splits, i.e. k-fold CV.

Then you can go a step further and say that you had an additional hold-out set, separated before hyperparameter optimisation started. This way you can evaluate the chosen best model on that set to measure the final generalisation error. You can go even further: instead of a single test sample you can have an outer CV loop, which brings us to nested cross-validation.
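
For illustration, here is a minimal nested-CV sketch (it uses the sklearn API recommended below, with a made-up parameter grid; X and y are assumed to exist):

import lightgbm as lgb
from sklearn.model_selection import GridSearchCV, cross_val_score

# inner loop: hyperparameter search; outer loop: unbiased estimate of the
# generalisation error of the whole model-building procedure
inner = GridSearchCV(lgb.LGBMClassifier(),
                     param_grid = {'num_leaves': [31, 63]},
                     cv = 5)
outer_scores = cross_val_score(inner, X, y, cv = 5)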

Technically, lightgbm.cv() only allows you to evaluate performance on a k-fold split with fixed model parameters. For hyper-parameter tuning you will need to run it in a loop, providing different parameters and recording the averaged performance, so that you can choose the best parameter set once the loop is complete. This interface is different from sklearn, which provides you with complete functionality to do hyperparameter optimisation in a CV loop. Personally, I would recommend using the sklearn API of lightgbm. It is just a wrapper around the native lightgbm.train() functionality, thus it is not slower. But it allows you to use the full stack of the sklearn toolkit, which makes your life MUCH easier.
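
A minimal sketch of such a loop, assuming X_train/y_train exist, a binary objective, and the same lgb.cv signature as in the question (note that the exact keys of the returned dictionary, e.g. 'binary_logloss-mean', vary across LightGBM versions):

import lightgbm as lgb
from itertools import product

d_train = lgb.Dataset(X_train, label = y_train)

best_score, best_params, best_rounds = float('inf'), None, None
for lr, leaves in product([0.05, 0.1], [31, 63]):  # hypothetical grid
    params = {'objective': 'binary',
              'metric': 'binary_logloss',
              'learning_rate': lr,
              'num_leaves': leaves}
    cv_results = lgb.cv(params,
                        d_train,
                        num_boost_round = 500,
                        nfold = 10,
                        early_stopping_rounds = 25,
                        stratified = True)
    # lgb.cv returns per-iteration means/stds over the folds; with early
    # stopping, the list length equals the best number of boosting rounds
    mean_loss = cv_results['binary_logloss-mean']
    if min(mean_loss) < best_score:
        best_score, best_params = min(mean_loss), params
        best_rounds = len(mean_loss)

With the sklearn API the same search is a few lines (again a sketch with a made-up grid):

from sklearn.model_selection import GridSearchCV

grid = GridSearchCV(lgb.LGBMClassifier(n_estimators = 500),
                    param_grid = {'learning_rate': [0.05, 0.1],
                                  'num_leaves': [31, 63]},
                    cv = 10,
                    scoring = 'neg_log_loss')
grid.fit(X_train, y_train)
print(grid.best_params_, grid.best_score_)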




Answer 2:


If you're happy with your CV results, you just use those parameters to call the 'lightgbm.train' method. Like @pho said, CV is usually just for param tuning. You don't use the actual CV object for predictions.
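
For example, a sketch of that final step (best_params and best_rounds as selected by a CV loop such as the one in Answer 1; d_train and X_test are assumed to exist):

import lightgbm as lgb

# retrain on the full training set with the CV-selected parameters,
# then predict with the resulting Booster
final_model = lgb.train(best_params, d_train, num_boost_round = best_rounds)
y_pred = final_model.predict(X_test)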




Answer 3:


You should use CV for parameter optimization.

If your model performs well on all folds, use these parameters to train on the whole training set. Then evaluate that model on the external test set.
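
A sketch of that last evaluation, assuming a binary task, a held-out test set (X_test, y_test), and a final_model trained as above:

from sklearn.metrics import log_loss

# final generalisation estimate on data never used during tuning
print('held-out log loss:', log_loss(y_test, final_model.predict(X_test)))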



Source: https://stackoverflow.com/questions/46456381/cross-validation-in-lightgbm
