Plot learning curves with caret package and R

Posted by 我与影子孤独终老i on 2019-11-30 04:00:59
  1. Caret will iteratively fit many cross-validated models for you if you define the resampling scheme with trainControl() and list the candidate tuning parameters (e.g. mtry) in a tuneGrid data frame. Both are then passed to train(), through its trControl and tuneGrid arguments. The parameters that can go in the grid differ by model type (for method = "rf" only mtry is tuned; ntree is passed straight to train() instead).

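  2. Yes, the fitted object returned by train() (trainFit here) contains the error metric you specified: averaged over the CV folds for each tuning value in trainFit$results, and per fold in trainFit$resample (by default only for the selected tuning value).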
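A minimal sketch of both points, assuming the randomForest package is installed so that method = "rf" is available. The mtcars data, the grid values and ntree = 100 are illustrative choices, not something taken from the question.

```r
library(caret)

set.seed(1)

# resampling scheme: plain 10-fold cross-validation;
# returnResamp = "all" keeps the per-fold results for every grid row
ctrl <- trainControl(method = "cv", number = 10, returnResamp = "all")

# candidate tuning values; for method = "rf" the only grid parameter is mtry
grid <- expand.grid(mtry = 1:10)

# ntree is not a tuning parameter here: it is passed through to randomForest()
fit <- train(mpg ~ ., data = mtcars,
             method = "rf",
             metric = "RMSE",
             trControl = ctrl,
             tuneGrid = grid,
             ntree = 100)

fit$results         # one row per mtry value, metrics averaged over the folds
head(fit$resample)  # per-fold metrics (every grid row, since returnResamp = "all")
fit$bestTune        # mtry value with the lowest cross-validated RMSE
```

With the default returnResamp = "final", fit$resample would only hold the folds for the selected mtry.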

So you could specify, for example, 10-fold CV over a grid of 10 mtry values, which amounts to 100 model fits. You might want to go get a cup of tea, or possibly lunch, while it runs.
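The arithmetic can be sketched in a few lines of base R. count_fits is a hypothetical helper written for this illustration, not a caret function; the extra fit it adds is the final model that train() refits on the full training set, and for some model types caret can get away with fewer fits than this count.

```r
# Hypothetical helper (not part of caret): number of model fits implied by
# a resampling scheme and a tuning grid.
count_fits <- function(folds, grid_rows, repeats = 1, final_refit = TRUE) {
  folds * repeats * grid_rows + as.integer(final_refit)
}

grid <- expand.grid(mtry = 1:10)

# 10-fold CV over 10 mtry values, resampling fits only
count_fits(folds = 10, grid_rows = nrow(grid), final_refit = FALSE)

# the same grid with 3 repeats, plus the final refit
count_fits(folds = 10, grid_rows = nrow(grid), repeats = 3)
```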

If this sounds complicated, there is a very good example here (caret is one of the best-documented packages around).

Here's how I approached plotting a learning curve in R while using the caret package to train the model. I use the Motor Trend Car Road Tests data (mtcars) for illustration. To begin, I shuffle the rows of mtcars and split them into training and test sets: about 21 records for training and the rest of the 32 records for testing. The response is mpg in this example.

# load caret (createDataPartition, train, trainControl, postResample)
library(caret)

# set seed for reproducibility
set.seed(7)

# randomize mtcars
mtcars <- mtcars[sample(nrow(mtcars)),]

# split mtcars data into training and test sets
mtcarsIndex <- createDataPartition(mtcars$mpg, p = .625, list = FALSE)
mtcarsTrain <- mtcars[mtcarsIndex,]
mtcarsTest <- mtcars[-mtcarsIndex,]

# pre-allocate the results; rows 1-2 stay NA because the loop starts at 3
learnCurve <- data.frame(m = rep(NA_integer_, 21),
                     trainRMSE = rep(NA_real_, 21),
                     cvRMSE = rep(NA_real_, 21))

# test data response feature
testY <- mtcarsTest$mpg

# Resample with 10-fold cross validation, repeated 3 times
ctrl <- trainControl(method="repeatedcv", number=10, repeats=3)
metric <- "RMSE"

# loop over training set sizes
for (i in 3:21) {
    learnCurve$m[i] <- i

    # train on the first i training rows; below 11 rows (the number of
    # coefficients in mpg~.) the lm fit is rank-deficient and R will warn
    fit.lm <- train(mpg~., data=mtcarsTrain[1:i,], method="lm", metric=metric,
             preProc=c("center", "scale"), trControl=ctrl)
    # note: this is the resampled (repeated CV) RMSE on those i rows
    learnCurve$trainRMSE[i] <- fit.lm$results$RMSE

    # use the fitted model to predict on the held-out test data
    prediction <- predict(fit.lm, newdata = mtcarsTest[,-1])
    rmse <- postResample(prediction, testY)
    # note: despite the column name, this is the test-set RMSE
    learnCurve$cvRMSE[i] <- rmse["RMSE"]
}

pdf("LinearRegressionLearningCurve.pdf", width = 7, height = 7, pointsize=12)

# plot learning curves: log(RMSE) against training set size
# for the resampled (CV) estimate and the held-out test set
yRange <- range(log(c(learnCurve$trainRMSE, learnCurve$cvRMSE)), finite = TRUE)
plot(log(learnCurve$trainRMSE),type = "o",col = "red", xlab = "Training set size",
          ylab = "log(RMSE)", ylim = yRange, main = "Linear Model Learning Curve")
lines(log(learnCurve$cvRMSE), type = "o", col = "blue")
legend('topright', c("Resampled (CV) error", "Test error"), lty = c(1,1), lwd = c(2.5, 2.5),
       col = c("red", "blue"))

dev.off()

The resulting plot is written to LinearRegressionLearningCurve.pdf (the image is not reproduced here).
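For comparison, a learning curve in the textbook sense plots the error on the rows the model was fitted to against the error on held-out rows. The sketch below is one possible way to get both with caret. It is a variant, not the code above: it assumes a reduced formula (mpg ~ wt + hp) so that small subsets stay estimable, it skips resampling with trainControl(method = "none"), and the names sizes and lc are made up for the illustration.

```r
library(caret)

set.seed(7)

# stratified split on mpg, using the same proportion as the code above
idx      <- createDataPartition(mtcars$mpg, p = 0.625, list = FALSE)
trainSet <- mtcars[idx, ]
testSet  <- mtcars[-idx, ]

# shuffle the training rows so that the first i rows are a random subset
trainSet <- trainSet[sample(nrow(trainSet)), ]

# assumption: start at 5 rows so the 3-coefficient model can be estimated
sizes <- 5:nrow(trainSet)
lc <- data.frame(m = sizes, trainRMSE = NA_real_, testRMSE = NA_real_)

# method = "none": fit a single model per subset, no resampling
ctrl <- trainControl(method = "none")

for (k in seq_along(sizes)) {
  subset_k <- trainSet[seq_len(sizes[k]), ]
  fit <- train(mpg ~ wt + hp, data = subset_k, method = "lm", trControl = ctrl)

  # error on the rows used for fitting
  lc$trainRMSE[k] <- postResample(predict(fit, subset_k), subset_k$mpg)["RMSE"]
  # error on the held-out test rows
  lc$testRMSE[k]  <- postResample(predict(fit, testSet), testSet$mpg)["RMSE"]
}

# both curves on one set of axes
matplot(lc$m, lc[, c("trainRMSE", "testRMSE")], type = "o", pch = 1,
        lty = 1, col = c("red", "blue"),
        xlab = "Training set size", ylab = "RMSE")
legend("topright", c("Training error", "Test error"), lty = 1,
       col = c("red", "blue"))
```

Keeping the resubstitution error separate from the resampled estimate makes it easier to see whether the gap between training and test error closes as more rows are added.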

At some point, probably after this question was asked, the caret package added the learning_curve_dat function (first released under the misspelled name learing_curve_dat), which helps assess model performance across a range of training set sizes.

Here is the example from the function documentation:

library(caret)
set.seed(1412)
class_dat <- twoClassSim(1000)

set.seed(29510)
# NOTE: learing_curve_dat below is not a typo here; it is the (misspelled)
# name the function was first released under
lda_data <- learing_curve_dat(dat = class_dat, 
                              outcome = "Class",
                              test_prop = 1/4, 
                              ## `train` arguments:
                              method = "lda", 
                              metric = "ROC",
                              trControl = trainControl(classProbs = TRUE, 
                                                       summaryFunction = twoClassSummary))

ggplot(lda_data, aes(x = Training_Size, y = ROC, color = Data)) + 
  geom_smooth(method = loess, span = .8)

The performance metric(s) are found for each Training_Size and saved in lda_data along with the Data variable ("Resampling", "Training", and optionally "Testing").
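Because lda_data is an ordinary data frame, the curve can also be summarised numerically. The sketch below assumes the column names mentioned above (Training_Size, ROC, Data) and that the MASS and pROC packages needed by "lda" and twoClassSummary are installed. lc_fun and roc_summary are names introduced here, and the exists() check is only a guard for the two spellings of the function name.

```r
library(caret)

# pick whichever spelling this caret version provides
lc_fun <- if (exists("learning_curve_dat")) learning_curve_dat else learing_curve_dat

set.seed(1412)
class_dat <- twoClassSim(1000)

set.seed(29510)
lda_data <- lc_fun(dat = class_dat, outcome = "Class", test_prop = 1/4,
                   method = "lda", metric = "ROC",
                   trControl = trainControl(classProbs = TRUE,
                                            summaryFunction = twoClassSummary))

# mean ROC for each training-set size and data source
roc_summary <- aggregate(ROC ~ Training_Size + Data, data = lda_data, FUN = mean)
roc_summary[order(roc_summary$Data, roc_summary$Training_Size), ]
```

A table like this is handy when the smoothed ggplot curves overlap and the exact values at each training size matter.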

Here is a link to the function documentation: https://rdrr.io/cran/caret/man/learing_curve_dat.html

To be clear, this answers the first part of the question but not the second part.
