Question
Almost all of the machine learning packages / functions in R allow you to obtain cross-validation performance metrics while training a model.
From what I can tell, the only way to do cross-validation with xgboost is to set up an xgb.cv call like this:
clf <- xgb.cv(params = param,
              data = dtrain,
              nrounds = 1000,
              verbose = 1,
              watchlist = watchlist,
              maximize = FALSE,
              nfold = 2,
              nthread = 2,
              prediction = T)
But even with the prediction = T option, you only get the (out-of-fold) predictions for your training data. I don't see a way to use the resulting object (clf in this example) in a predict call with new data.
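The behaviour described above can be reproduced with a small sketch. It assumes a reasonably recent release of the R xgboost package (some argument names have changed between versions) and uses the built-in mtcars data purely as a stand-in for the question's dtrain; the parameter values are arbitrary.

```r
library(xgboost)

# mtcars stands in for the question's dtrain (illustrative assumption)
X <- as.matrix(mtcars[, -1])   # predictors
y <- mtcars$mpg                # regression target
dtrain <- xgb.DMatrix(X, label = y)

set.seed(1)
cv <- xgb.cv(
  params = list(objective = "reg:squarederror", eta = 0.3, max_depth = 3),
  data = dtrain,
  nrounds = 20,
  nfold = 2,
  prediction = TRUE,   # keep the out-of-fold predictions
  verbose = 0
)

# One out-of-fold prediction per training row ...
length(cv$pred) == nrow(X)
# ... plus per-round evaluation scores, but no booster to pass to predict()
head(cv$evaluation_log)
```

The returned object summarizes the cross-validation run; nothing in it is meant to be scored against new rows.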
Is my understanding accurate and is there any work-around?
Answer 1:
I believe your understanding is accurate, and that there is no setting to save the models from cross-validation.
For more control over cross-validation, you can train xgboost models with caret (see the documentation of the trainControl function for more details: http://topepo.github.io/caret/training.html).
Yet, unless I'm mistaken, caret also lacks an option to save each CV model for making predictions later on (although you can manually specify the metrics you want to evaluate them on). Depending on your reason for using the CV models to predict on new data, you could either 1) retrieve the training indices of the CV folds from the fitted object (the $control$index list within the object produced by caret's train function) and retrain that one particular model (without cross-validation, but with the same seed) on just that subset of the data:
> library(MASS) # For the Boston dataset
> library(caret)
> ctrl <- trainControl(method = "cv", number = 3, savePredictions = TRUE)
> mod <- train(medv~., data = Boston, method = "xgbLinear", trControl = ctrl)
> str(mod$control$index)
List of 3
$ Fold1: int [1:336] 2 3 4 6 8 9 13 14 17 19 ...
$ Fold2: int [1:338] 1 2 4 5 6 7 9 10 11 12 ...
$ Fold3: int [1:338] 1 3 5 7 8 10 11 12 14 15 ...
or 2) manually cross-validate with lapply or a for loop, saving every model you create. The createFolds family of functions in caret is a useful tool for choosing the cross-validation folds.
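Both options can be sketched directly with xgboost. This is a minimal sketch under stated assumptions: it reuses the Boston data and 3 folds from the example above; the fold list comes from createFolds(..., returnTrain = TRUE), which yields the same kind of training-index list that caret stores in mod$control$index for method = "cv" (the actual folds differ unless the RNG state matches); the gblinear parameter values are placeholders (with a fitted caret model you would copy them from mod$bestTune); and fit_fold, fold_models and the averaging step are hypothetical names and choices for illustration, not part of caret or xgboost.

```r
library(MASS)     # Boston data
library(caret)    # createFolds()
library(xgboost)

set.seed(42)
X <- data.matrix(Boston[, names(Boston) != "medv"])
y <- Boston$medv

# Training-row indices for 3 folds, same structure as mod$control$index
index <- createFolds(y, k = 3, returnTrain = TRUE)

# Assumed settings mirroring caret's "xgbLinear"; in practice use mod$bestTune
params <- list(booster = "gblinear", objective = "reg:squarederror",
               eta = 0.3, lambda = 0, alpha = 0, nthread = 1)
n_rounds <- 50

# Hypothetical helper: fit one model on the given training rows
fit_fold <- function(train_rows) {
  dtr <- xgb.DMatrix(X[train_rows, ], label = y[train_rows])
  xgb.train(params = params, data = dtr, nrounds = n_rounds, verbose = 0)
}

# Option 1: refit a single fold's model outside of cross-validation
fold1_model <- fit_fold(index$Fold1)

# Option 2: refit and keep every fold's model
fold_models <- lapply(index, fit_fold)

# Held-out RMSE for each saved model
fold_rmse <- mapply(function(model, train_rows) {
  test_rows <- setdiff(seq_len(nrow(X)), train_rows)
  pred <- predict(model, xgb.DMatrix(X[test_rows, ]))
  sqrt(mean((pred - y[test_rows])^2))
}, fold_models, index)

# The saved models can score new data; here their predictions are averaged
new_data <- X[1:5, ]   # stand-in for genuinely new observations
fold_preds <- sapply(fold_models, function(m) predict(m, xgb.DMatrix(new_data)))
ensemble_pred <- rowMeans(fold_preds)
```

Averaging the fold models is only one way to use them on new data; you could equally keep just the single fold model you care about, as in option 1.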
Source: https://stackoverflow.com/questions/36523319/is-it-possible-to-cross-validate-and-save-the-cross-validated-model-with-xgboost