Retraining after Cross Validation with libsvm

一生所求 2020-11-29 23:20

I know that cross-validation is used for selecting good parameters. After finding them, I need to retrain on the whole data set without the -v option.

But the problem i fa

2 answers
  •  夕颜 (original poster)
    2020-11-30 00:04

    The -v option here is really meant as a way to guard against overfitting: instead of using the whole data set for training, it performs N-fold cross-validation, training on N-1 folds and testing on the remaining fold, one fold at a time, and then reports the average accuracy. Thus it returns only the cross-validation accuracy (assuming a classification problem; for regression it returns the mean squared error) as a scalar number instead of an actual SVM model.
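
    As a minimal sketch of that difference (the values -c 1 -g 0.07 are placeholders chosen only for illustration, not tuned parameters), the same call returns a scalar with -v and a model struct without it:

    %# with -v: the return value is the 5-fold cross-validation accuracy (a scalar)
    [labels,data] = libsvmread('./heart_scale');
    acc = svmtrain(labels, data, '-c 1 -g 0.07 -v 5');

    %# without -v: the return value is a model struct that svmpredict can use
    model = svmtrain(labels, data, '-c 1 -g 0.07');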

    If you want to perform model selection, you have to implement a grid search using cross-validation (similar to the grid.py helper Python script) to find the best values of C and gamma.

    This shouldn't be hard to implement: create a grid of values using MESHGRID, iterate over all pairs (C,gamma), training an SVM model with, say, 5-fold cross-validation, and choose the values with the best CV accuracy...

    Example:

    %# read some training data
    [labels,data] = libsvmread('./heart_scale');
    
    %# grid of parameters
    folds = 5;
    [C,gamma] = meshgrid(-5:2:15, -15:2:3);
    
    %# grid search, and cross-validation
    cv_acc = zeros(numel(C),1);
    for i=1:numel(C)
        cv_acc(i) = svmtrain(labels, data, ...
                        sprintf('-c %f -g %f -v %d', 2^C(i), 2^gamma(i), folds));
    end
    
    %# pair (C,gamma) with best accuracy
    [~,idx] = max(cv_acc);
    
    %# contour plot of parameter selection
    contour(C, gamma, reshape(cv_acc,size(C))), colorbar
    hold on
    plot(C(idx), gamma(idx), 'rx')
    text(C(idx), gamma(idx), sprintf('Acc = %.2f %%',cv_acc(idx)), ...
        'HorizontalAlignment','left', 'VerticalAlignment','top')
    hold off
    xlabel('log_2(C)'), ylabel('log_2(\gamma)'), title('Cross-Validation Accuracy')
    
    %# now you can train your model using best_C and best_gamma
    best_C = 2^C(idx);
    best_gamma = 2^gamma(idx);
    %# ...
    

    [Figure: contour plot of cross-validation accuracy over log_2(C) and log_2(gamma)]
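
    The retraining step the question asks about could then look roughly like this. It is a sketch that assumes labels, data, best_C and best_gamma are still in the workspace from the example above; the svmpredict call reuses the training data only to show the call signature, so the accuracy it reports is a training accuracy and will be optimistic:

    %# retrain on the whole training set with the selected parameters (no -v)
    model = svmtrain(labels, data, ...
                sprintf('-c %f -g %f', best_C, best_gamma));

    %# use the model for prediction; replace labels/data with held-out test data
    %# (accuracy is a vector: accuracy, mean squared error, squared correlation)
    [predicted, accuracy, ~] = svmpredict(labels, data, model);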
