Keras the simplest NN model: error in training.py with indices

Submitted by 允我心安 on 2021-02-08 06:14:31

Question


I have read this example https://github.com/fchollet/keras/blob/master/examples/mnist_mlp.py and decided to apply the same approach to my dataset, because it is the simplest neural network example for Keras.

This is my dataset: https://drive.google.com/file/d/0B-B3QUQOzGZ7WVhzQmRsOTB0eFE/view (you can download the CSV file; it is only 83 KB).

Here is the shape of my dataset:

base.shape = (891, 23)

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import RMSprop, Adam
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split  # was sklearn.cross_validation, removed in scikit-learn 0.20
from keras.utils.vis_utils import model_to_dot
from IPython.display import SVG
from keras.utils import plot_model

base = pd.read_csv("mt.csv")

for col in base:
    if col != "Fare" and col != "Age":
        base[col] = base[col].astype(float)
X_train = base
y_train = base["Survived"]
del X_train["Survived"]

print("X_train=",X_train.shape)
print("y_train=", y_train.shape)

Out:
X_train= (891, 22)
y_train= (891,)
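One side effect in the code above: X_train = base does not copy the DataFrame, so del X_train["Survived"] also deletes the column from base. A minimal sketch that separates features and target without that side effect, assuming the target column is named "Survived" as in the question; split_features_target is a hypothetical helper, and the three-row frame is made-up data standing in for mt.csv:

```python
import pandas as pd


def split_features_target(df, target):
    """Return (X, y) as NumPy arrays without modifying df.

    drop() returns a new frame, unlike `X = df; del X[target]`, which also
    removes the column from df because both names refer to the same object.
    """
    X = df.drop(columns=[target]).to_numpy(dtype="float32")
    y = df[target].to_numpy()
    return X, y


# Tiny hypothetical frame standing in for mt.csv (values are made up).
base = pd.DataFrame({"Survived": [0, 1, 1],
                     "Age": [22.0, 38.0, 26.0],
                     "Fare": [7.25, 71.28, 7.93]})
X, y = split_features_target(base, "Survived")
print(X.shape, y.shape)            # (3, 2) (3,)
print("Survived" in base.columns)  # True: base is left intact
```

Returning NumPy arrays here also sidesteps the DataFrame indexing problem discussed in the answer.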

X_train, X_test , y_train, y_test = train_test_split(X_train, y_train, test_size=0.3, random_state=42)

batch_size = 4
num_classes = 2
epochs = 2

print(X_train.shape[1], 'train samples')
print(X_test.shape[1], 'test samples')

Out:
22 train samples
22 test samples
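The two print calls above report shape[1], which is the number of columns (features), not the number of samples; the 22 printed here is exactly the value input_shape has to match. A NumPy sketch using a dummy all-zeros array with the training set's dimensions (623 rows, 22 columns, the shape reported in the error message further down):

```python
import numpy as np

# Dummy stand-in with the same dimensions as X_train after the split.
# Only the shape matters here, so the values are all zeros.
X_train = np.zeros((623, 22))

n_samples = X_train.shape[0]   # rows -> number of samples
n_features = X_train.shape[1]  # columns -> number of features (what input_shape needs)

print(n_samples, 'train samples')
print(n_features, 'features per sample')
```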

y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = Sequential()
model.add(Dense(40, activation='relu', input_shape=(21,)))
model.add(Dropout(0.2))
#model.add(Dense(20, activation='relu'))
#model.add(Dropout(0.2))
model.add(Dense(2, activation='sigmoid'))

model.summary()

Out:

Layer (type)                 Output Shape              Param
dense_1 (Dense)              (None, 40)                880
dropout_1 (Dropout)          (None, 40)                0
dense_2 (Dense)              (None, 2)                 82
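The summary already hints at the mismatch: a Dense layer has inputs × units weights plus one bias per unit, so 880 parameters in dense_1 means it was built for 21 inputs. A small check of that arithmetic (dense_params is a hypothetical helper written for illustration, not a Keras function):

```python
def dense_params(n_inputs, n_units):
    """Trainable parameters of a fully connected layer:
    one weight per (input, unit) pair plus one bias per unit."""
    return n_inputs * n_units + n_units


print(dense_params(21, 40))  # 880: matches dense_1, so it expects 21 inputs
print(dense_params(40, 2))   # 82: matches dense_2
print(dense_params(22, 40))  # 920: what dense_1 would have with 22 input features
```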

model.compile(loss='binary_crossentropy',
              optimizer=Adam(),
              metrics=['accuracy'])

plot_model(model, to_file='model.png')

SVG(model_to_dot(model).create(prog='dot', format='svg'))

print("X_train.shape=", X_train.shape)
print("X_test=",X_test.shape)

history = model.fit(X_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(X_test, y_test))

Traceback (most recent call last):
  File "new.py", line 67, in <module>
    validation_data=(X_test, y_test))
  File "miniconda3/lib/python3.6/site-packages/keras/models.py", line 845, in fit
    initial_epoch=initial_epoch)
  File "miniconda3/lib/python3.6/site-packages/keras/engine/training.py", line 1405, in fit
    batch_size=batch_size)
  File "miniconda3/lib/python3.6/site-packages/keras/engine/training.py", line 1295, in _standardize_user_data
    exception_prefix='model input')
  File "miniconda3/lib/python3.6/site-packages/keras/engine/training.py", line 133, in _standardize_input_data
    str(array.shape))
ValueError: Error when checking model input: expected dense_1_input to have shape (None, 21) but got array with shape (623, 22)

[Finished in 5.1s with exit code 1]
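Reading the message: the None in (None, 21) is the batch dimension and accepts any number of rows, so the check fails only on the last dimension, 21 expected versus 22 present. A simplified imitation of that comparison in plain NumPy (check_input_shape is a hypothetical helper for illustration, not Keras's internal code):

```python
import numpy as np


def check_input_shape(expected, array):
    """Simplified imitation of a model-input shape check.

    None matches any size (the batch dimension); every other
    dimension must match exactly.
    """
    actual = np.shape(array)
    ok = len(actual) == len(expected) and all(
        e is None or e == a for e, a in zip(expected, actual))
    if not ok:
        raise ValueError(
            f"expected shape {expected} but got array with shape {actual}")
    return actual


X = np.zeros((623, 22))
check_input_shape((None, 22), X)      # passes: 22 features on both sides
try:
    check_input_shape((None, 21), X)  # fails like the traceback above
except ValueError as exc:
    print(exc)
```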

How can I solve this error? I tried changing the input shape, for example to (20,) or (22,), without success.

For example, with input_shape=(22,) I get this error instead:

  File "miniconda3/lib/python3.6/site-packages/pandas/core/indexing.py", line 1873, in maybe_convert_indices
    raise IndexError("indices are out-of-bounds")


Answer 1:


input_shape should equal the number of features (columns) in your data. You have 22, so use input_shape=(22,).
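Rather than hard-coding the number, it can be derived from the training data itself. A minimal sketch, assuming 2-D tabular input (rows = samples, columns = features); input_shape_from is a hypothetical helper. This only removes the shape mismatch: the DataFrame still needs converting to a NumPy array, as explained next.

```python
import numpy as np


def input_shape_from(X):
    """Return the input_shape tuple for the first Dense layer,
    derived from 2-D data X (rows = samples, columns = features).
    Works for NumPy arrays and pandas DataFrames alike."""
    shape = np.shape(X)
    if len(shape) != 2:
        raise ValueError(f"expected 2-D data, got shape {shape}")
    return (shape[1],)


# Hypothetical usage with the model from the question:
# model.add(Dense(40, activation='relu', input_shape=input_shape_from(X_train)))
print(input_shape_from(np.zeros((623, 22))))  # (22,)
```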

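The IndexError is most likely caused by indexing a pandas DataFrame as if it were a NumPy array: Keras selects each batch by integer row positions, but df[...] on a DataFrame looks up columns, not rows. Convert your DataFrame into a NumPy array first, using as_matrix() (removed in pandas 1.0; on newer versions use to_numpy() or .values):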

# as_matrix() was removed in pandas 1.0; to_numpy() (pandas >= 0.24) returns the same NumPy array
history = model.fit(X_train.to_numpy(), y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(X_test.to_numpy(), y_test))
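To see why the conversion helps, here is a small sketch with made-up data contrasting integer-position indexing on a NumPy array with df[...] on a DataFrame. In current pandas the DataFrame lookup raises KeyError; the older pandas in the question surfaced the same kind of failure as IndexError("indices are out-of-bounds"). df.iloc is the pandas way to select rows by position.

```python
import numpy as np
import pandas as pd

# Hypothetical 5-row, 3-column frame; the non-sequential row labels mimic
# the shuffled index that train_test_split leaves on a DataFrame.
df = pd.DataFrame(np.arange(15.0).reshape(5, 3),
                  columns=["a", "b", "c"],
                  index=[40, 7, 310, 2, 88])
batch_ids = np.array([0, 3])  # integer positions of the rows in one mini-batch

arr = df.to_numpy()           # plain NumPy copy of the same data
rows_np = arr[batch_ids]      # NumPy: positions select rows 0 and 3
rows_pd = df.iloc[batch_ids]  # pandas: .iloc selects rows by position

try:
    df[batch_ids]             # df[...] looks up column labels, not rows
    column_lookup_failed = False
except (KeyError, IndexError) as exc:
    column_lookup_failed = True
    print(type(exc).__name__, exc)
```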


Source: https://stackoverflow.com/questions/43293832/keras-the-simplest-nn-model-error-in-training-py-with-indices
