MATLAB: 10-fold cross-validation without using existing functions

难免孤独 2020-12-30 17:30

I have a data structure (I believe MATLAB calls it a struct):

  data: [150x4 double]
labels: [150x1 double]

here is out my ma

  •  暗喜 (original poster)
    2020-12-30 17:51

    Here is my take on this cross-validation. I create dummy data using magic(10) and generate the labels randomly. The idea is as follows: take the data and labels and combine them with a column of random numbers. Consider the following dummy code, which uses the smaller magic(4):

    >> data = magic(4)
    
    data =
    
        16     2     3    13
         5    11    10     8
         9     7     6    12
         4    14    15     1
    
    >> dataRowNumber = size(data,1)
    
    dataRowNumber =
    
         4
    
    >> randomColumn = rand(dataRowNumber,1)
    
    randomColumn =
    
        0.8147
        0.9058
        0.1270
        0.9134
    
    
    >> X = [ randomColumn data]
    
    X =
    
        0.8147   16.0000    2.0000    3.0000   13.0000
        0.9058    5.0000   11.0000   10.0000    8.0000
        0.1270    9.0000    7.0000    6.0000   12.0000
        0.9134    4.0000   14.0000   15.0000    1.0000
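
    The shuffle only works if whole rows move together when sorting by the key in column 1. In MATLAB, sortrows(X,1) does that, whereas sort(X,1) sorts every column independently and breaks the pairing between data and labels. A small check, rebuilding the X above with the random key copied from the printout (byRows and byCols are illustrative names):

    ```matlab
    % Rebuild the X shown above (random key column copied from the printout).
    X = [[0.8147; 0.9058; 0.1270; 0.9134] magic(4)];

    byRows = sortrows(X,1);   % whole rows reordered by the key in column 1
    byCols = sort(X,1);       % every column sorted on its own

    disp(byRows(:,2:end))     % rows of magic(4) in order 3, 1, 2, 4
    disp(byCols(:,2:end))     % each column ascending; original rows are lost
    ```

    In byCols the first row becomes 4 2 3 1, which is not a row of magic(4) at all.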
    

    If we sort the rows of X by column 1 (sortrows(X,1)), the data rows end up in random order; this gives cross-validation its randomness. The next step is to split X according to the cross-validation percentage, which is easy enough for a single split. Say 75% of the rows are for training and 25% for testing. With 4 rows here, 3/4 = 75% and 1/4 = 25%:

    sortedX = sortrows(X,1);
    testDataset = sortedX(1,:)
    trainDataset = sortedX(2:4,:)
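
    A minimal sketch of the same single split generalized to any row count, assuming the random-key shuffle described above. testFraction and nTest are hypothetical names introduced here, not part of the original answer:

    ```matlab
    % Hold-out split sketch: shuffle rows with a random key, then cut at a
    % chosen test fraction (testFraction is a hypothetical parameter name).
    data = magic(4);
    n = size(data,1);
    testFraction = 0.25;

    X = [rand(n,1) data];
    X = sortrows(X,1);          % rows shuffled, each row kept intact
    X = X(:,2:end);             % drop the random key column

    nTest = round(testFraction*n);
    testDataset = X(1:nTest,:);
    trainDataset = X(nTest+1:end,:);
    ```

    With 4 rows and testFraction = 0.25 this yields 1 test row and 3 training rows, as in the example above.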
    

    Doing this for N folds is a bit harder, since the split has to be repeated N times, which calls for a for loop. For 5 folds over 10 rows, I get:

    1. 1st fold: rows 1 2 for test, 3:10 for train
    2. 2nd fold: rows 3 4 for test, 1 2 5:10 for train
    3. 3rd fold: rows 5 6 for test, 1:4 7:10 for train
    4. 4th fold: rows 7 8 for test, 1:6 9:10 for train
    5. 5th fold: rows 9 10 for test, 1:8 for train
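
    The fold boundaries in the list above can be generated rather than written out. This sketch only computes the row indices, with n = 10 and k = 5, and assumes n is divisible by k; testIdx and trainIdx are illustrative names:

    ```matlab
    % Index bookkeeping for k-fold cross-validation over n rows
    % (assumes n is divisible by k).
    n = 10;
    k = 5;
    rowsPerFold = n / k;
    testIdx = cell(k,1);
    trainIdx = cell(k,1);
    for fold = 1:k
        testIdx{fold} = (fold-1)*rowsPerFold+1 : fold*rowsPerFold;
        % every row before and after the test block goes to training
        trainIdx{fold} = [1:testIdx{fold}(1)-1, testIdx{fold}(end)+1:n];
        fprintf('fold %d: test %s, train %s\n', fold, ...
                mat2str(testIdx{fold}), mat2str(trainIdx{fold}));
    end
    ```

    The first printed line is `fold 1: test [1 2], train [3 4 5 6 7 8 9 10]`, matching the first item of the list.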

    The following code is an example of this process:

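    data = magic(10);
    dataRowNumber = size(data,1);
    labels = rand(dataRowNumber,1) > 0.5;    % random binary labels (dummy)
    randomColumn = rand(dataRowNumber,1);

    X = [randomColumn data labels];

    % sortrows keeps each row intact; sort(X,1) would sort every column
    % independently and scramble the data/label pairing
    SortedData = sortrows(X,1);
    SortedData = SortedData(:,2:end);        % drop the random key column

    crossValidationFolds = 5;
    % assumes dataRowNumber is divisible by crossValidationFolds
    numberOfRowsPerFold = dataRowNumber / crossValidationFolds;

    % one cell per fold, so the folds are not merged into one matrix
    crossValidationTrainData = cell(crossValidationFolds,1);
    crossValidationTestData = cell(crossValidationFolds,1);
    fold = 0;
    for startOfRow = 1:numberOfRowsPerFold:dataRowNumber
        fold = fold + 1;
        testRows = startOfRow:startOfRow+numberOfRowsPerFold-1;
        % 1:startOfRow-1 is empty for the first fold, so no special case is needed
        trainRows = [1:startOfRow-1, max(testRows)+1:dataRowNumber];
        crossValidationTrainData{fold} = SortedData(trainRows,:);
        crossValidationTestData{fold} = SortedData(testRows,:);
    end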
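
    For the setup in the question (a struct with data: [150x4 double] and labels: [150x1 double]) and 10 folds, here is a sketch that keeps the random-key idea but assigns fold ids round-robin over the shuffled order, so it also works when the row count is not divisible by the number of folds. The struct s is a dummy stand-in for the real one (assumption), and foldOf and testCounts are hypothetical names:

    ```matlab
    % 10-fold split sketch for a struct with fields data and labels.
    % Dummy stand-in for the question's struct (assumption).
    s.data = rand(150,4);
    s.labels = double(rand(150,1) > 0.5);

    k = 10;
    n = size(s.data,1);

    [~, order] = sort(rand(n,1));          % random permutation of 1..n
    foldOf = zeros(n,1);
    foldOf(order) = mod(0:n-1, k)' + 1;    % round-robin fold ids 1..k

    testCounts = zeros(k,1);
    for fold = 1:k
        isTest = (foldOf == fold);
        trainX = s.data(~isTest,:);
        trainY = s.labels(~isTest);
        testX  = s.data(isTest,:);
        testY  = s.labels(isTest);
        testCounts(fold) = nnz(isTest);
        % train a classifier on (trainX, trainY), evaluate on (testX, testY)
    end
    ```

    The classifier step is left as a comment because the question does not say which model is used. With 150 rows and 10 folds every fold gets 15 test rows; for other sizes the fold sizes differ by at most one.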
    
