Dimension mismatch in Keras during model.fit

Posted by 本秂侑毒 on 2019-11-29 12:43:10

According to the question "Keras: What if the size of data is not divisible by batch_size?", it is better to use model.fit_generator rather than model.fit in this case.
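To see why divisibility matters, it may help to count the batches by hand. The sketch below is plain Python with no Keras dependency; the helper name `batch_sizes` and the sample counts are made up for illustration and do not come from the question.

```python
def batch_sizes(n_samples, batch_size, drop_remainder=False):
    """Return the size of each batch when n_samples are split in order."""
    full, rest = divmod(n_samples, batch_size)
    sizes = [batch_size] * full
    if rest and not drop_remainder:
        sizes.append(rest)  # the short final batch that can trigger a mismatch
    return sizes

# Illustrative numbers only.
print(batch_sizes(1000, 256))                       # [256, 256, 256, 232]
print(batch_sizes(1000, 256, drop_remainder=True))  # [256, 256, 256]
```

A model whose input is pinned to a fixed batch size can only accept the second kind of split, where every batch is full.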

To use model.fit_generator, you need to define your own generator object. The following example wraps the data in a Sequence subclass:

from keras.utils import Sequence
import math
import numpy as np

class Generator(Sequence):
    # Dataset wrapper that serves the data in fixed-size batches
    def __init__(self, x_set, y_set, batch_size=256):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        self.indices = np.arange(self.x.shape[0])

    def __len__(self):
        # Number of full batches per epoch; an incomplete last batch is dropped
        return math.floor(self.x.shape[0] / self.batch_size)

    def __getitem__(self, idx):
        # Pick the (possibly shuffled) indices that belong to batch number idx
        inds = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_x = self.x[inds]
        batch_y = self.y[inds]
        return batch_x, batch_y

    def on_epoch_end(self):
        # Reshuffle the sample order in place after every epoch
        np.random.shuffle(self.indices)

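# x_train, x_test, batch_size, epochs and vae are assumed to be defined already.
# The data is passed as both input and target because the model reconstructs its input.
train_datagen = Generator(x_train, x_train, batch_size)
test_datagen = Generator(x_test, x_test, batch_size)

vae.fit_generator(train_datagen,
    steps_per_epoch=len(x_train)//batch_size,
    validation_data=test_datagen,
    validation_steps=len(x_test)//batch_size,
    epochs=epochs)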

Code adapted from "How to shuffle after each epoch using a custom generator?".

dllearn

I just tried to replicate this and found that when you define

x = Input(batch_shape=(batch_size, original_dim))

you're fixing the batch size, which causes a mismatch when validation starts. Change it to

x = Input(shape=input_shape)

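and you should be all set. Here input_shape is the per-sample shape without the batch dimension, that is, (original_dim,) in this example.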
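The reason this works is that `shape=` leaves the batch dimension unspecified (Keras reports it as `None`), whereas `batch_shape=` pins it to a fixed number. As a rough illustration only, and not Keras's actual checking code, the hypothetical helper `is_compatible` below treats `None` as a wildcard when comparing a declared shape with the shape of an incoming batch. The values 100 and 784 are assumptions chosen for the example.

```python
def is_compatible(declared, actual):
    """True if `actual` fits `declared`, where None matches any size."""
    if len(declared) != len(actual):
        return False
    return all(d is None or d == a for d, a in zip(declared, actual))

batch_size, original_dim = 100, 784  # assumed values for illustration

fixed = (batch_size, original_dim)   # like Input(batch_shape=(batch_size, original_dim))
flexible = (None, original_dim)      # like Input(shape=(original_dim,))

print(is_compatible(fixed, (100, 784)))    # True: a full batch
print(is_compatible(fixed, (60, 784)))     # False: a short final batch
print(is_compatible(flexible, (60, 784)))  # True: any batch size is accepted
```

With the flexible declaration, a short last batch during validation no longer conflicts with the input definition.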
