Dimension mismatch in Keras during model.fit

自闭症患者 2020-12-20 07:11

I put together a VAE using dense neural networks in Keras. During model.fit I get a dimension mismatch, but I'm not sure what is throwing the code off. Below is wha

2 Answers
  • 2020-12-20 07:23

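    According to "Keras: What if the size of data is not divisible by batch_size?", it is better to use model.fit_generator than model.fit here: with plain model.fit, the last batch of an epoch is smaller than batch_size whenever the number of samples is not a multiple of batch_size.

    To use model.fit_generator, you need to define your own generator object. Here is an example: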

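    from keras.utils import Sequence
    import math
    import numpy as np
    
    class Generator(Sequence):
        # Dataset wrapper that yields only full batches of exactly batch_size rows,
        # so a model built with a fixed batch_shape never sees a smaller final batch
        def __init__(self, x_set, y_set, batch_size=256):
            self.x, self.y = x_set, y_set
            self.batch_size = batch_size
            self.indices = np.arange(self.x.shape[0])
    
        def __len__(self):
            # Number of full batches; the remaining n % batch_size rows are dropped
            return math.floor(self.x.shape[0] / self.batch_size)
    
        def __getitem__(self, idx):
            inds = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
            batch_x = self.x[inds]
            batch_y = self.y[inds]
            return batch_x, batch_y
    
        def on_epoch_end(self):
            # Reshuffle so batch composition (and the dropped rows) change between epochs
            np.random.shuffle(self.indices)
    
    # A VAE reconstructs its input, so x is passed as both input and target
    train_datagen = Generator(x_train, x_train, batch_size)
    test_datagen = Generator(x_test, x_test, batch_size)
    
    # fit_generator is deprecated in TF 2.x (removed in Keras 3); there,
    # vae.fit(train_datagen, validation_data=test_datagen, ...) accepts a Sequence directly
    vae.fit_generator(train_datagen,
        steps_per_epoch=len(x_train)//batch_size,
        validation_data=test_datagen,
        validation_steps=len(x_test)//batch_size,
        epochs=epochs)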
    

    Code adapted from "How to shuffle after each epoch using a custom generator?".
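    Why the generator helps: plain model.fit on arrays still feeds the leftover rows as a final, smaller batch, which a model built with a fixed batch_shape cannot accept, whereas the Generator's __len__ uses math.floor and so drops that remainder. The plain-Python sketch below counts the batch sizes of one epoch under both policies; the function name batch_sizes and the sample counts are hypothetical, chosen only for illustration.

```python
import math


def batch_sizes(n_samples, batch_size, drop_remainder):
    """Return the size of every batch one epoch would yield over n_samples rows."""
    if drop_remainder:
        # Like the Generator's __len__: math.floor keeps only full batches
        n_batches = math.floor(n_samples / batch_size)
    else:
        # Like plain model.fit on arrays: leftover rows form a last, smaller batch
        n_batches = math.ceil(n_samples / batch_size)
    return [min(batch_size, n_samples - i * batch_size) for i in range(n_batches)]


# Hypothetical sizes: 1000 samples with the Generator's default batch_size of 256
print(batch_sizes(1000, 256, drop_remainder=False))  # [256, 256, 256, 232]
print(batch_sizes(1000, 256, drop_remainder=True))   # [256, 256, 256]
```

    With drop_remainder=True every batch matches a fixed batch_shape of 256, at the cost of skipping 1000 % 256 = 232 rows per epoch; because on_epoch_end reshuffles the indices, which rows are skipped changes from one epoch to the next.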

  • 2020-12-20 07:32

    I just tried to replicate this and found that when you define

    x = Input(batch_shape=(batch_size, original_dim))

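    you're fixing the batch size as part of the input shape, so every batch fed to the model must contain exactly batch_size samples. That causes the mismatch once validation starts, most likely because the validation set's size is not a multiple of batch_size and its last batch is smaller. Change it to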

    x = Input(shape=(original_dim,))  # per-sample shape; the batch dimension stays None
    

    and you should be all set.
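    A sketch of the fix end to end in tf.keras, under stated assumptions: the question's code is truncated, so the layer sizes, the random data, and the Sampling layer are hypothetical and illustrative only. Besides using Input(shape=...), it draws the sampling noise with tf.shape(z_mean) so the batch size is read at run time; if the question's sampling function passes batch_size to the noise shape (as some older Keras VAE examples did), that would reintroduce the same mismatch. The KL-divergence term is omitted, so this is only the skeleton of a VAE. The 1000 training rows and 300 validation rows are deliberately not multiples of batch_size = 128.

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)

# Hypothetical sizes; the question's actual values are not shown
original_dim, intermediate_dim, latent_dim, batch_size = 784, 64, 2, 128


class Sampling(tf.keras.layers.Layer):
    """Hypothetical reparameterization layer: z = mean + exp(log_var / 2) * eps."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        # tf.shape reads the batch size at run time instead of hard-coding batch_size
        eps = tf.random.normal(shape=tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * eps


# shape= describes one sample; the batch dimension stays None (any size)
inputs = tf.keras.Input(shape=(original_dim,))
h = tf.keras.layers.Dense(intermediate_dim, activation="relu")(inputs)
z_mean = tf.keras.layers.Dense(latent_dim)(h)
z_log_var = tf.keras.layers.Dense(latent_dim)(h)
z = Sampling()([z_mean, z_log_var])
h_dec = tf.keras.layers.Dense(intermediate_dim, activation="relu")(z)
outputs = tf.keras.layers.Dense(original_dim, activation="sigmoid")(h_dec)

vae = tf.keras.Model(inputs, outputs)
# Reconstruction loss only; a real VAE would also add the KL-divergence term
vae.compile(optimizer="adam", loss="binary_crossentropy")

rng = np.random.default_rng(0)
x_train = rng.random((1000, original_dim), dtype=np.float32)  # 1000 % 128 = 104
x_test = rng.random((300, original_dim), dtype=np.float32)    # 300 % 128 = 44

# Partial final batches (104 and 44 rows) are accepted because the batch dim is None
history = vae.fit(x_train, x_train, batch_size=batch_size, epochs=1,
                  validation_data=(x_test, x_test), verbose=0)
print(tuple(inputs.shape))  # (None, 784)
```

    Swapping in Input(batch_shape=(batch_size, original_dim)) should reproduce the kind of mismatch from the question, since the final 104-row training batch and 44-row validation batch no longer fit the declared shape.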
