Error when checking model input: expected convolution2d_input_1 to have 4 dimensions, but got array with shape (32, 32, 3)

徘徊边缘 提交于 2019-11-28 22:04:28

问题


I want to train a deep network starting with the following layer:

model = Sequential()
model.add(Conv2D(32, 3, 3, input_shape=(32, 32, 3)))

using

history = model.fit_generator(get_training_data(),
                samples_per_epoch=1, nb_epoch=1,nb_val_samples=5,
                verbose=1,validation_data=get_validation_data()

with the following generator:

def get_training_data(self):
     while 1:
        for i in range(1,5):
            image = self.X_train[i]
            label = self.Y_train[i]
            yield (image,label)

(validation generator looks similar).

During training, I get the error:

Error when checking model input: expected convolution2d_input_1 to have 4 
dimensions, but got array with shape (32, 32, 3)

How can that be, with a first layer

 model.add(Conv2D(32, 3, 3, input_shape=(32, 32, 3)))

?


回答1:


The input shape you have defined is the shape of a single sample. The model itself expects some array of samples as input (even if its an array of length 1).

Your output really should be 4-d, with the 1st dimension to enumerate the samples. i.e. for a single image you should return a shape of (1, 32, 32, 3).

You can find more information here under "Convolution2D"/"Input shape"




回答2:


It is as simple as to Add one dimension, so I was going through the tutorial taught by Siraj Rawal on CNN Code Deployment tutorial, it was working on his terminal, but the same code was not working on my terminal, so I did some research about it and solved, I don't know if that works for you all. Here I have come up with solution;

Unsolved code lines which gives you problem:

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    print(x_train.shape)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols)
    input_shape = (img_rows, img_cols, 1)

Solved Code:

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    print(x_train.shape)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

Please share the feedback here if that worked for you.




回答3:


it depends on how you actually order your data,if its on a channel first basis then you should reshape your data: x_train=x_train.reshape(x_train.shape[0],channel,width,height)

if its channel last: x_train=s_train.reshape(x_train.shape[0],width,height,channel)




回答4:


x_train = x_train.reshape(-1,28, 28, 1)   #Reshape for CNN -  should work!!
x_test = x_test.reshape(-1,28, 28, 1)
history_cnn = cnn.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))

Output:

Train on 60000 samples, validate on 10000 samples Epoch 1/5 60000/60000 [==============================] - 157s 3ms/step - loss: 0.0981 - acc: 0.9692 - val_loss: 0.0468 - val_acc: 0.9861 Epoch 2/5 60000/60000 [==============================] - 157s 3ms/step - loss: 0.0352 - acc: 0.9892 - val_loss: 0.0408 - val_acc: 0.9879 Epoch 3/5 60000/60000 [==============================] - 159s 3ms/step - loss: 0.0242 - acc: 0.9924 - val_loss: 0.0291 - val_acc: 0.9913 Epoch 4/5 60000/60000 [==============================] - 165s 3ms/step - loss: 0.0181 - acc: 0.9945 - val_loss: 0.0361 - val_acc: 0.9888 Epoch 5/5 60000/60000 [==============================] - 168s 3ms/step - loss: 0.0142 - acc: 0.9958 - val_loss: 0.0354 - val_acc: 0.9906



来源:https://stackoverflow.com/questions/41563720/error-when-checking-model-input-expected-convolution2d-input-1-to-have-4-dimens

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!