Keras: What if the size of data is not divisible by batch_size?

渐次进展 · 2020-12-19 00:43

I am new to Keras and have just started working through some examples. I am dealing with the following problem: I have 4032 samples and use about 650 of them for the fit, or basical…

1 Answer
  • 2020-12-19 01:12

    The simplest solution is to use fit_generator instead of fit. I wrote a simple dataloader class that can be inherited from to do more complex stuff. It would look something like this, with get_next_batch_data redefined to do whatever your data requires, including things like augmentation, etc.

    import random

    class BatchedLoader():
        def __init__(self, num_samples, batch_size):
            self.possible_indices = list(range(num_samples))  # (say num_samples = 33)
            self.batch_size = batch_size
            self.cur_it = 0
            self.cur_epoch = 0

        def get_batch_indices(self):
            batch_indices = self.possible_indices[self.cur_it : self.cur_it + self.batch_size]
            # If len(batch_indices) < batch_size, you've reached the end of the epoch:
            # reset cur_it, increase cur_epoch, shuffle possible_indices if wanted,
            # and pad with the remaining K = batch_size - len(batch_indices) indices
            if len(batch_indices) < self.batch_size:
                self.cur_epoch += 1
                random.shuffle(self.possible_indices)
                k = self.batch_size - len(batch_indices)
                batch_indices += self.possible_indices[:k]
                self.cur_it = k
            else:
                self.cur_it += self.batch_size
            return batch_indices

        def get_next_batch_data(self):
            batch_indices = self.get_batch_indices()
            # The data points corresponding to those indices are your next batch;
            # override this method to load (and e.g. augment) your actual data
            raise NotImplementedError

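    As a sketch of how this plugs into Keras, the same idea can be written as a plain generator function and handed to fit_generator (note: in recent versions of Keras, fit itself accepts generators and fit_generator is deprecated; the model, variable names, and steps_per_epoch value below are illustrative, not from the original answer):

    ```python
    import math
    import numpy as np

    def batch_generator(x, y, batch_size):
        # Infinite generator over (x, y). The last, short batch of each epoch
        # is padded with indices from the start of a reshuffled epoch, so
        # every yielded batch has exactly batch_size samples.
        indices = np.arange(len(x))
        cur = 0
        while True:
            batch_idx = indices[cur : cur + batch_size]
            if len(batch_idx) < batch_size:
                np.random.shuffle(indices)
                k = batch_size - len(batch_idx)
                batch_idx = np.concatenate([batch_idx, indices[:k]])
                cur = k
            else:
                cur += batch_size
            yield x[batch_idx], y[batch_idx]

    # Hypothetical usage with a compiled Keras model:
    # steps = math.ceil(len(x_train) / batch_size)
    # model.fit_generator(batch_generator(x_train, y_train, batch_size),
    #                     steps_per_epoch=steps, epochs=10)
    ```

    Because the generator always yields full batches, steps_per_epoch is just ceil(num_samples / batch_size), and no sample is silently dropped.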