How to implement TensorFlow's next_batch for your own data

失恋的感觉 2020-12-14 01:00

In the TensorFlow MNIST tutorial, the mnist.train.next_batch(100) function comes in very handy. I am now trying to implement a simple classification myself. I have

6 answers
  •  一向 (OP)
    2020-12-14 02:01

    The link you posted says: 'we get a "batch" of one hundred random data points from our training set'. My example uses a standalone (global) function rather than a method like the one in your example, so the calling syntax is different.
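For comparison with the method-call syntax mnist.train.next_batch(100), here is a minimal sketch of a class that exposes a next_batch method. It is not TensorFlow's code: the class name SimpleDataSet and the seed parameter are hypothetical, and it assumes you want epoch-style batching (walk through a shuffled copy of the data, reshuffle once every sample has been used) instead of an independent random draw on every call.

```python
import numpy as np


class SimpleDataSet:
    """Hypothetical epoch-based batcher with a next_batch(n) method."""

    def __init__(self, data, labels, seed=None):
        self._data = np.asarray(data)
        self._labels = np.asarray(labels)
        if len(self._data) == 0:
            raise ValueError("data must not be empty")
        if len(self._data) != len(self._labels):
            raise ValueError("data and labels must have the same length")
        self._rng = np.random.default_rng(seed)
        # Shuffled order for the current epoch and our position in it
        self._order = self._rng.permutation(len(self._data))
        self._pos = 0

    def next_batch(self, batch_size):
        """Return the next `batch_size` (data, labels) pairs.

        Indices come from the current shuffled order; when the epoch runs
        out, a new shuffled order is drawn and the batch is filled from it.
        """
        idx = []
        while len(idx) < batch_size:
            if self._pos == len(self._order):
                # Epoch finished: reshuffle and start from the beginning
                self._order = self._rng.permutation(len(self._data))
                self._pos = 0
            take = min(batch_size - len(idx), len(self._order) - self._pos)
            idx.extend(self._order[self._pos:self._pos + take])
            self._pos += take
        idx = np.array(idx, dtype=np.intp)
        # Same indices for data and labels keep each pair aligned
        return self._data[idx], self._labels[idx]


ds = SimpleDataSet(np.arange(10), np.arange(100).reshape(10, 10), seed=0)
x, y = ds.next_batch(4)
print(x)
print(y[:, 0])  # for this toy data, column 0 of each label row is 10 * sample
```

Because batches are consecutive slices of one shuffled order, every sample is seen exactly once per epoch, which the random-sample function in this answer does not guarantee.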

    With my function, you pass the number of samples you want, together with the data array and its labels array.

    Here is the corrected code. It draws one set of random indices and uses it for both the data and the labels, so every sample keeps its correct label:

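    import numpy as np
    
    def next_batch(num, data, labels):
        '''
        Return a total of `num` random samples and labels.
        '''
        # Shuffle all indices 0..len(data)-1 and keep the first `num`
        idx = np.arange(0, len(data))
        np.random.shuffle(idx)
        idx = idx[:num]
        # Use the same indices for data and labels so each pair stays aligned
        data_shuffle = [data[i] for i in idx]
        labels_shuffle = [labels[i] for i in idx]
    
        return np.asarray(data_shuffle), np.asarray(labels_shuffle)
    
    # Toy data: samples 0..9; the label of sample i is the row [10*i, ..., 10*i + 9]
    Xtr, Ytr = np.arange(0, 10), np.arange(0, 100).reshape(10, 10)
    print(Xtr)
    print(Ytr)
    
    # Draw 5 random (sample, label) pairs
    Xtr, Ytr = next_batch(5, Xtr, Ytr)
    print('\n5 random samples')
    print(Xtr)
    print(Ytr)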
    

    And a demo run (the samples are chosen at random, so your output will differ):

    [0 1 2 3 4 5 6 7 8 9]
    [[ 0  1  2  3  4  5  6  7  8  9]
     [10 11 12 13 14 15 16 17 18 19]
     [20 21 22 23 24 25 26 27 28 29]
     [30 31 32 33 34 35 36 37 38 39]
     [40 41 42 43 44 45 46 47 48 49]
     [50 51 52 53 54 55 56 57 58 59]
     [60 61 62 63 64 65 66 67 68 69]
     [70 71 72 73 74 75 76 77 78 79]
     [80 81 82 83 84 85 86 87 88 89]
     [90 91 92 93 94 95 96 97 98 99]]
    
    5 random samples
    [9 1 5 6 7]
    [[90 91 92 93 94 95 96 97 98 99]
     [10 11 12 13 14 15 16 17 18 19]
     [50 51 52 53 54 55 56 57 58 59]
     [60 61 62 63 64 65 66 67 68 69]
     [70 71 72 73 74 75 76 77 78 79]]
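
The two list comprehensions above work whether data and labels are Python lists or NumPy arrays. If they are already arrays, the same idea can be sketched with fancy indexing. This assumes NumPy 1.17 or later for np.random.default_rng; the name next_batch_np and the rng parameter are hypothetical. Like the original, each call picks distinct samples within one batch, but separate calls are independent, so consecutive batches may overlap.

```python
import numpy as np


def next_batch_np(num, data, labels, rng=None):
    """Return `num` random (data, label) pairs via NumPy fancy indexing.

    Hypothetical variant: assumes `data` and `labels` are array-likes
    with the same first dimension.
    """
    data = np.asarray(data)
    labels = np.asarray(labels)
    if len(data) != len(labels):
        raise ValueError("data and labels must have the same length")
    rng = np.random.default_rng() if rng is None else rng
    # Pick `num` distinct row indices; the same indices select data and labels
    idx = rng.choice(len(data), size=num, replace=False)
    return data[idx], labels[idx]


Xtr, Ytr = np.arange(0, 10), np.arange(0, 100).reshape(10, 10)
xb, yb = next_batch_np(5, Xtr, Ytr, rng=np.random.default_rng(42))
print(xb)
print(yb)
```

One behavioral difference: when num is larger than the number of samples, rng.choice with replace=False raises ValueError, whereas the list-based version silently returns every sample.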
    
