10-fold cross-validation in one-against-all SVM (using LibSVM)

孤街浪徒 2020-12-01 07:12

I want to do a 10-fold cross-validation in my one-against-all support vector machine classification in MATLAB.

I tried to somehow mix two related answers I found, but so far without success.

2 Answers
  •  眼角桃花
    2020-12-01 07:37

    Mainly there are two reasons we do cross-validation:

    • as a testing method which gives us a nearly unbiased estimate of the generalization power of our model (by avoiding overfitting)
    • as a way of doing model selection (e.g. finding the best C and gamma parameters over the training data; a minimal sketch follows this list)
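
    A minimal sketch of such a parameter search, using libsvm's built-in cross-validation option (-v) over illustrative grid ranges; labels and data are as defined in the demo further below, and note that -v relies on libsvm's default one-vs-one handling of multi-class problems:

    bestAcc = 0;  bestC = NaN;  bestG = NaN;
    for c = 2.^(-1:2:7)
        for g = 2.^(-7:2:1)
            cmd = sprintf('-s 0 -t 2 -c %g -g %g -v 5 -q', c, g);
            cv = libsvmtrain(labels, data, cmd);   %# with -v, libsvmtrain returns the CV accuracy
            if cv > bestAcc
                bestAcc = cv;  bestC = c;  bestG = g;
            end
        end
    end
    fprintf('best C = %g, gamma = %g (CV accuracy = %.2f%%)\n', bestC, bestG, bestAcc);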

    For the first case, which is what we are interested in here, the process involves training one model per fold (k models in total) and then training one final model over the entire training set. We report the average accuracy over the k folds.

    Now since we are using one-vs-all approach to handle the multi-class problem, each model consists of N support vector machines (one for each class).


    The following are wrapper functions implementing the one-vs-all approach:

    function mdl = libsvmtrain_ova(y, X, opts)
        if nargin < 3, opts = ''; end
    
        %# classes
        labels = unique(y);
        numLabels = numel(labels);
    
        %# train one-against-all models
        models = cell(numLabels,1);
        for k=1:numLabels
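            %# relabel: samples of the current class -> 1, all other classes -> 0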
            models{k} = libsvmtrain(double(y==labels(k)), X, strcat(opts,' -b 1 -q'));
        end
        mdl = struct('models',{models}, 'labels',labels);
    end
    
    function [pred,acc,prob] = libsvmpredict_ova(y, X, mdl)
        %# classes
        labels = mdl.labels;
        numLabels = numel(labels);
    
        %# get probability estimates of test instances using each 1-vs-all model
        prob = zeros(size(X,1), numLabels);
        for k=1:numLabels
            [~,~,p] = libsvmpredict(double(y==labels(k)), X, mdl.models{k}, '-b 1 -q');
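            %# keep the probability column of the positive class (label 1);
            %# libsvm orders the columns of p according to mdl.models{k}.Label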
            prob(:,k) = p(:, mdl.models{k}.Label==1);
        end
    
        %# predict the class with the highest probability
        [~,pred] = max(prob, [], 2);
        pred = labels(pred);    %# map column index back to the actual class label
        %# compute classification accuracy
        acc = mean(pred == y);
    end
    

    And here are functions to support cross-validation:

    function acc = libsvmcrossval_ova(y, X, opts, nfold, indices)
        if nargin < 3, opts = ''; end
        if nargin < 4, nfold = 10; end
        if nargin < 5, indices = crossvalidation(y, nfold); end
    
        %# N-fold cross-validation testing
        acc = zeros(nfold,1);
        for i=1:nfold
            testIdx = (indices == i); trainIdx = ~testIdx;
            mdl = libsvmtrain_ova(y(trainIdx), X(trainIdx,:), opts);
            [~,acc(i)] = libsvmpredict_ova(y(testIdx), X(testIdx,:), mdl);
        end
        acc = mean(acc);    %# average accuracy
    end
    
    function indices = crossvalidation(y, nfold)
        %# stratified n-fold cross-validation
        %#indices = crossvalind('Kfold', y, nfold);  %# Bioinformatics toolbox
        cv = cvpartition(y, 'kfold',nfold);          %# Statistics toolbox
        indices = zeros(size(y));
        for i=1:nfold
            indices(cv.test(i)) = i;
        end
    end
    

    Finally, here is a simple demo to illustrate the usage:

    %# load dataset
    S = load('fisheriris');
    data = zscore(S.meas);
    labels = grp2idx(S.species);
    
    %# cross-validate using one-vs-all approach
    opts = '-s 0 -t 2 -c 1 -g 0.25';    %# libsvm training options
    nfold = 10;
    acc = libsvmcrossval_ova(labels, data, opts, nfold);
    fprintf('Cross Validation Accuracy = %.4f%%\n', 100*mean(acc));
    
    %# compute final model over the entire dataset
    mdl = libsvmtrain_ova(labels, data, opts);
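    
    %# (sketch) classify new samples with the final one-vs-all model; here the
    %# first five z-scored rows stand in for unseen data, and the dummy labels
    %# are only needed because the wrapper expects a label vector
    newData = data(1:5,:);
    dummy = ones(size(newData,1),1);
    [pred, ~, prob] = libsvmpredict_ova(dummy, newData, mdl);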
    

    Compare that against the one-vs-one approach, which libsvm uses by default:

    acc = libsvmtrain(labels, data, sprintf('%s -v %d -q',opts,nfold));
    model = libsvmtrain(labels, data, strcat(opts,' -q'));
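    
    %# (sketch) with the one-vs-one model, prediction is done directly through
    %# libsvmpredict; the training data is reused here only for illustration
    [pred, acc, dec] = libsvmpredict(labels, data, model, '-q');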
    
