Caret obtain train & cv predictions from model to plot

梦想的初衷 submitted on 2019-12-10 10:31:55

Question


I've trained a simple model:

library(caret)

# train is my data frame; tg is a tuning grid of nnet size/decay values
mySim <- train(Event ~ .,
               method = 'nnet',
               data = train,
               tuneGrid = tg)

I'm optimising the two nnet parameters: the weight decay (decay) and the size of the hidden layer (size). I'm new to caret, so what I would usually do is plot the training error and the CV error for each model built. To do this, I'd need the predicted values from my training and validation passes.

This is the first time I've used cross-validation, so I'm a little unsure how to go about getting the predictions from the training and hold-out sets at each tuneGrid iteration.

If I have a grid search of length 3 (3 models to build) and 5-fold cross-validation, I assume I'm going to have 15 sets of train & hold-out predictions (5 for each model).
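The fold arithmetic in the previous paragraph can be checked with caret's createFolds(). This is a minimal sketch with a made-up 100-row outcome and a hypothetical 3-row grid; it shows that 5-fold CV yields 5 hold-out index sets, that every row is held out exactly once, and that 3 candidate models give 3 x 5 = 15 fits:

```r
library(caret)  # createFolds()

# made-up two-class outcome with 100 rows (illustration only)
set.seed(1)
y <- factor(rep(c("yes", "no"), each = 50))

# list of 5 integer vectors: the hold-out rows of each fold
folds <- createFolds(y, k = 5)

# hypothetical tuning grid with 3 candidate models
grid <- data.frame(size = c(1, 3, 5))

# each row appears in exactly one hold-out fold; its training set is the other 4 folds
all_rows_once <- identical(sort(unlist(folds, use.names = FALSE)), 1:100)

# one fit, and one set of hold-out predictions, per model per fold
n_fits <- nrow(grid) * length(folds)
print(n_fits)
```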

The plot I'm essentially looking to build has a performance metric on the y-axis (let's say entropy loss, since this is classification with nnet) and the size grid-search values on the x-axis, increasing from 0 to the maximum.
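For a two-class outcome, the entropy (cross-entropy, or log) loss is usually defined as below, where n is the number of predictions, y_i is 1 if observation i belongs to the positive class and 0 otherwise, and p_i is the predicted probability of the positive class:

```latex
\mathcal{L} = -\frac{1}{n} \sum_{i=1}^{n} \left[ y_i \log p_i + (1 - y_i) \log (1 - p_i) \right]
```

Lower is better, so the curve for the training folds and the curve for the hold-out folds would both be expected to fall as the model fits better, with the gap between them indicating overfitting.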

Is there a way to extract the predicted values for the training / hold-out sets during trainControl cross-validation?

I've looked through some of the components train returns, but I'm not sure whether I'm missing something.

I know I lack code in this question but hopefully I've explained myself.

Update: am I correct in assuming that setting the following parameters in trainControl will return the predictions that allow me to create this plot?

  • returnResamp
  • savePredictions
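As far as I can tell from the caret documentation, savePredictions = "all" keeps the hold-out predictions for every tuning-parameter combination in fit$pred, and returnResamp = "all" keeps the per-fold performance for every combination in fit$resample; neither stores predictions on the training part of each fold. A minimal sketch on the Sonar data (the grid, seed and fold count are arbitrary choices for illustration):

```r
library(caret)
library(mlbench)  # Sonar data
data(Sonar)

ctrl <- trainControl(method = "cv", number = 5,
                     classProbs = TRUE,
                     savePredictions = "all",  # hold-out predictions for every grid row
                     returnResamp = "all")     # per-fold metrics for every grid row
grid <- expand.grid(size = c(1, 3, 5), decay = 0.1)

set.seed(42)
fit <- train(Sonar[, 1:60], Sonar$Class,
             method = "nnet",
             tuneGrid = grid,
             trControl = ctrl,
             trace = FALSE)  # passed on to nnet() to silence its iteration log

# columns include pred, obs, the class probabilities M and R,
# rowIndex, size, decay and Resample (the fold id)
head(fit$pred)

# one row per fold per grid row: Accuracy, Kappa, size, decay, Resample
head(fit$resample)
```

With 5-fold CV every row of Sonar is held out once per grid row, so fit$pred should have nrow(Sonar) x 3 rows and fit$resample 5 x 3 rows.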

Answer 1:


Here is an example of how to do this with the mlr library:

library(mlr)
library(mlbench) #for the data set

I will use the Sonar data set:

data(Sonar)

Create a task:

task <- makeClassifTask(data = Sonar, target = "Class")

Create a learner:

lrn <- makeLearner("classif.nnet", predict.type = "prob")

Get all tunable parameters of the learner:

getParamSet("classif.nnet")

Set which ones you would like to tune, and their ranges:

ps <- makeParamSet(
  makeIntegerParam("size", lower = 3, upper = 5),
  makeNumericParam("decay", lower = 0.1, upper = 0.2))

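Define the resampling (predict = "both" makes mlr also predict on the training part of each fold, which the train.mean and train.sd aggregations below rely on):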

cross_val <- makeResampleDesc("RepCV",
                              reps = 2, folds = 5, stratify = TRUE, predict = "both")
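To look at the per-fold training and test predictions themselves, rather than only aggregated measures, mlr's resample() can be run with the same kind of resampling description for one fixed hyperparameter setting. A minimal sketch; the size/decay values, plain 5-fold CV and the seed are arbitrary assumptions, not tuned values:

```r
library(mlr)
library(mlbench)  # Sonar data
data(Sonar)

task <- makeClassifTask(data = Sonar, target = "Class")
# hypothetical fixed hyperparameters (not tuned)
lrn <- makeLearner("classif.nnet", predict.type = "prob", size = 3L, decay = 0.1)
# predict = "both": predict on the training part and the test part of each fold
cv <- makeResampleDesc("CV", iters = 5, stratify = TRUE, predict = "both")

set.seed(42)
r <- resample(lrn, task, cv, measures = list(auc), show.info = FALSE)

# one row per prediction; iter is the fold and set is "train" or "test"
preds <- r$pred$data
print(table(preds$set, preds$iter))

print(r$measures.train)  # AUC on the training part of each fold
print(r$measures.test)   # AUC on the held-out part of each fold
```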

Define how the search will be performed (a grid search in this case):

ctrl <- mlr::makeTuneControlGrid(resolution = 4L)

Put everything together:

res.mbo <- tuneParams(lrn, task, cross_val, par.set = ps, control = ctrl,
                      show.info = FALSE,
                      measures = list(auc,
                                      setAggregation(auc, test.sd),
                                      setAggregation(auc, train.mean),
                                      setAggregation(auc, train.sd)))

You can define many measures in a list; the first one is used to select the hyperparameters, and the others are only reported.

Extract the results:

res <- mlr::generateHyperParsEffectData(res.mbo)$data

Plot:

library(tidyverse)

res %>%
  # columns 3 and 5 are the mean test and mean train AUC (auc.test.mean, auc.train.mean)
  gather(key, value, c(3, 5)) %>%
  mutate(key = as.factor(key)) %>%
  ggplot() +
  geom_point(aes(x = size, y = value, color = key)) +
  geom_smooth(aes(x = size, y = value, color = key)) +
  facet_wrap(~decay)

This gives a bunch of warnings from geom_smooth, since there are only 3 points per fit.

And here is an example of how to do it in caret, using only the hold-out samples:

library(caret)

Create a trainControl object:

ctrl <- trainControl(
  method = "repeatedcv",
  number = 5,
  repeats = 2, 
  classProbs = TRUE,
  savePredictions = "all",
  returnResamp = "all",
  summaryFunction = twoClassSummary
)
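The control above uses twoClassSummary, which reports ROC, sensitivity and specificity. To tune and plot on the cross-entropy loss the question asks about instead, caret's mnLogLoss summary function can be swapped in; it reports a metric named logLoss and needs classProbs = TRUE. A minimal self-contained sketch (plain 5-fold CV and the smaller grid are arbitrary choices to keep it fast); it draws lines rather than geom_smooth, which avoids the smoothing warnings mentioned above:

```r
library(caret)
library(mlbench)  # Sonar data
library(ggplot2)
data(Sonar)

ctrl_ll <- trainControl(
  method = "cv",
  number = 5,
  classProbs = TRUE,           # mnLogLoss needs class probabilities
  summaryFunction = mnLogLoss  # reports a metric called "logLoss"
)
grid_ll <- expand.grid(size = c(4, 5, 6), decay = c(0.1, 0.2))

set.seed(42)
fit_ll <- train(Sonar[, 1:60], Sonar$Class,
                method = "nnet",
                tuneGrid = grid_ll,
                metric = "logLoss",
                maximize = FALSE,  # lower log loss is better
                trControl = ctrl_ll,
                trace = FALSE)

# mean hold-out log loss against hidden-layer size, one panel per decay value
p_ll <- ggplot(fit_ll$results, aes(x = size, y = logLoss)) +
  geom_point() +
  geom_line() +
  facet_wrap(~decay)
print(p_ll)
```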

Create a grid of hyperparameters:

grid <- expand.grid(size = c(4, 5, 6), decay = seq(from = 0.1, to =  0.2, length.out = 4))

Tune:

fit <- caret::train(Sonar[, 1:60], Sonar$Class,
                    method = 'nnet',
                    tuneGrid = grid,
                    metric = 'ROC',
                    trControl = ctrl,
                    trace = FALSE)  # passed on to nnet() to silence its iteration log

Plot:

fit$results %>%
  ggplot()+
  geom_point(aes(x = size, y = ROC))+
  geom_smooth(aes(x = size, y = ROC))+
  facet_wrap(~decay)
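As far as I know, caret does not keep predictions on the training part of each fold, so the train-vs-CV curve from the question is not available from fit$pred. One workaround, swapped in here and not part of caret::train, is to run the cross-validation loop by hand with caret::createFolds() and nnet::nnet(), computing the log loss on both parts of every fold. A minimal sketch; the helper log_loss, the grid, maxit and the seed are all hypothetical choices, and caret's own preprocessing and fitting details are not reproduced:

```r
library(caret)    # createFolds()
library(mlbench)  # Sonar data
library(nnet)     # nnet()
library(ggplot2)
data(Sonar)

# hypothetical helper: binary cross-entropy; y is 0/1, p = predicted P(y == 1)
log_loss <- function(y, p, eps = 1e-15) {
  p <- pmin(pmax(p, eps), 1 - eps)  # clip to avoid log(0)
  -mean(y * log(p) + (1 - y) * log(1 - p))
}

x <- Sonar[, 1:60]
y <- as.numeric(Sonar$Class == "R")  # 1 = "R", 0 = "M"
grid <- expand.grid(size = c(1, 3, 5), decay = 0.1)

set.seed(42)
folds <- createFolds(Sonar$Class, k = 5)  # hold-out row indices per fold

rows <- list()
for (g in seq_len(nrow(grid))) {
  for (f in seq_along(folds)) {
    test_idx  <- folds[[f]]
    train_idx <- setdiff(seq_len(nrow(x)), test_idx)
    # single logistic output unit fitted by maximum likelihood (entropy = TRUE)
    m <- nnet(x[train_idx, ], y[train_idx],
              size = grid$size[g], decay = grid$decay[g],
              entropy = TRUE, maxit = 200, trace = FALSE)
    rows[[length(rows) + 1]] <- data.frame(
      size  = grid$size[g], decay = grid$decay[g], fold = f,
      train = log_loss(y[train_idx], predict(m, x[train_idx, ])[, 1]),
      test  = log_loss(y[test_idx],  predict(m, x[test_idx, ])[, 1]))
  }
}
per_fold <- do.call(rbind, rows)

# average over folds, then stack train/test rows for plotting
avg <- aggregate(cbind(train, test) ~ size + decay, data = per_fold, FUN = mean)
long <- rbind(data.frame(avg[, c("size", "decay")], set = "train", loss = avg$train),
              data.frame(avg[, c("size", "decay")], set = "test",  loss = avg$test))

p <- ggplot(long, aes(x = size, y = loss, colour = set)) +
  geom_point() +
  geom_line()
print(p)
```

per_fold holds one row per model per fold (15 here), which matches the count discussed in the question; the plot shows the mean training and hold-out loss per hidden-layer size.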



Source: https://stackoverflow.com/questions/48754886/caret-obtain-train-cv-predictions-from-model-to-plot
