I put together a VAE using Dense Neural Networks in Keras. During model.fit
I get a dimension mismatch, but not sure what is throwing the code off. Below is what my code looks like
from keras.layers import Lambda, Input, Dense from keras.models import Model from keras.datasets import mnist from keras.losses import mse, binary_crossentropy from keras.utils import plot_model from keras import backend as K import keras import numpy as np import matplotlib.pyplot as plt import argparse import os (x_train, y_train), (x_test, y_test) = mnist.load_data() image_size = x_train.shape[1] original_dim = image_size * image_size x_train = np.reshape(x_train, [-1, original_dim]) x_test = np.reshape(x_test, [-1, original_dim]) x_train = x_train.astype('float32') / 255 x_test = x_test.astype('float32') / 255 # network parameters input_shape = (original_dim, ) intermediate_dim = 512 batch_size = 128 latent_dim = 2 epochs = 50 x = Input(batch_shape=(batch_size, original_dim)) h = Dense(intermediate_dim, activation='relu')(x) z_mean = Dense(latent_dim)(h) z_log_sigma = Dense(latent_dim)(h) def sampling(args): z_mean, z_log_sigma = args #epsilon = K.random_normal(shape=(batch, dim)) epsilon = K.random_normal(shape=(batch_size, latent_dim)) return z_mean + K.exp(z_log_sigma) * epsilon # note that "output_shape" isn't necessary with the TensorFlow backend # so you could write `Lambda(sampling)([z_mean, z_log_sigma])` z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_sigma]) decoder_h = Dense(intermediate_dim, activation='relu') decoder_mean = Dense(original_dim, activation='sigmoid') h_decoded = decoder_h(z) x_decoded_mean = decoder_mean(h_decoded) print('X Decoded Mean shape: ', x_decoded_mean.shape) # end-to-end autoencoder vae = Model(x, x_decoded_mean) # encoder, from inputs to latent space encoder = Model(x, z_mean) # generator, from latent space to reconstructed inputs decoder_input = Input(shape=(latent_dim,)) _h_decoded = decoder_h(decoder_input) _x_decoded_mean = decoder_mean(_h_decoded) generator = Model(decoder_input, _x_decoded_mean) def vae_loss(x, x_decoded_mean): xent_loss = keras.metrics.binary_crossentropy(x, x_decoded_mean) kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1) return xent_loss + kl_loss vae.compile(optimizer='rmsprop', loss=vae_loss) print('X train shape: ', x_train.shape) print('X test shape: ', x_test.shape) vae.fit(x_train, x_train, shuffle=True, epochs=epochs, batch_size=batch_size, validation_data=(x_test, x_test))
Here is the stack trace that I see when model.fit
is called.
File "/home/asattar/workspace/projects/keras-examples/blogautoencoder/VariationalAutoEncoder.py", line 81, in <module> validation_data=(x_test, x_test)) File "/usr/local/lib/python2.7/dist-packages/Keras-2.2.4-py2.7.egg/keras/engine/training.py", line 1047, in fit validation_steps=validation_steps) File "/usr/local/lib/python2.7/dist-packages/Keras-2.2.4-py2.7.egg/keras/engine/training_arrays.py", line 195, in fit_loop outs = fit_function(ins_batch) File "/usr/local/lib/python2.7/dist-packages/Keras-2.2.4-py2.7.egg/keras/backend/tensorflow_backend.py", line 2897, in __call__ return self._call(inputs) File "/usr/local/lib/python2.7/dist-packages/Keras-2.2.4-py2.7.egg/keras/backend/tensorflow_backend.py", line 2855, in _call fetched = self._callable_fn(*array_vals) File "/home/asattar/.local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1439, in __call__ run_metadata_ptr) File "/home/asattar/.local/lib/python2.7/site-packages/tensorflow/python/framework/errors_impl.py", line 528, in __exit__ c_api.TF_GetCode(self.status.status)) tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [128,784] vs. [96,784] [[{{node training/RMSprop/gradients/loss/dense_5_loss/logistic_loss/mul_grad/BroadcastGradientArgs}} = BroadcastGradientArgs[T=DT_INT32, _class=["loc:@train...ad/Reshape"], _device="/job:localhost/replica:0/task:0/device:CPU:0"](training/RMSprop/gradients/loss/dense_5_loss/logistic_loss/mul_grad/Shape, training/RMSprop/gradients/loss/dense_5_loss/logistic_loss/mul_grad/Shape_1)]]
Please note the "Incompatible shapes: [128,784] vs. [96,784]" in the stack trace" towards the end of the trace.