I want to practice Keras by coding an XOR network, but the result is not right. My code follows; thanks to everybody for helping me.
from keras.models import Seque
XOR training with Keras
Below is the minimal neural network architecture required to learn XOR, which should be a (2,2,1) network. In fact, maths shows that the (2,2,1) network can solve the XOR problem, but it doesn't show that the (2,2,1) network is easy to train. It can sometimes take a lot of epochs (iterations), or fail to converge to the global minimum. That said, I easily got good results with (2,3,1) or (2,4,1) network architectures.
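To make the "can solve" part concrete, here is a minimal NumPy sketch of a (2,2,1) sigmoid network that represents XOR. The weights are hand-picked for illustration (one hidden unit behaves like OR, the other like NAND, and the output like AND); they are an assumption of this sketch, not what Keras would learn.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)

# Hand-picked (not learned) weights, chosen only for illustration.
# Hidden unit 0 acts like OR, hidden unit 1 acts like NAND.
W1 = np.array([[20.0, -20.0],
               [20.0, -20.0]])
b1 = np.array([-10.0, 30.0])
# The output unit acts like AND of the two hidden units.
W2 = np.array([[20.0], [20.0]])
b2 = np.array([-30.0])

hidden = sigmoid(X @ W1 + b1)       # shape (4, 2)
output = sigmoid(hidden @ W2 + b2)  # shape (4, 1)
print(output.round(3).ravel())      # rounds to [0. 1. 1. 0.]
```

Such a solution exists, but nothing guarantees that gradient-based training will find it from a random starting point.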
The problem seems to be related to the existence of many local minima. See the 1998 paper «Learning XOR: exploring the space of a classic problem» by Richard Bland. Furthermore, initializing the weights with random numbers between 0.5 and 1.0 helps convergence.
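As a rough sketch of that initialization idea, the helper below draws weights uniformly from the interval 0.5 to 1.0. The name `uniform_half_to_one` is hypothetical, and the `(shape, dtype=None)` signature is chosen to match the custom initializer in the code further down, on the assumption that it would be passed as `kernel_initializer` in the same way.

```python
import numpy as np

def uniform_half_to_one(shape, dtype=None):
    # Hypothetical helper: sample every weight uniformly from [0.5, 1.0).
    # dtype is accepted only to match the (shape, dtype) callable signature.
    return np.random.uniform(low=0.5, high=1.0, size=shape)

# Example: weights for a layer with 2 inputs and 2 units
w = uniform_half_to_one((2, 2))
print(w)
```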
It works fine with Keras or TensorFlow using the 'mean_squared_error' loss function, sigmoid activation, and the Adam optimizer. Even with pretty good hyperparameters, I observed that the learned XOR model gets trapped in a local minimum about 15% of the time.
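One way to estimate how often training gets stuck is to repeat it from several random starting points and count the runs whose rounded predictions do not match XOR. The sketch below does this in plain NumPy with full-batch gradient descent on the mean squared error. It is not the Keras/Adam setup described above, so its failure rate should not be expected to match the 15% figure; the function names, learning rate, epoch count, and weight distribution are all assumptions made for illustration.

```python
import numpy as np

X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)
y = np.array([[0], [1], [1], [0]], dtype=float)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def train_once(seed, hidden=2, epochs=3000, lr=1.0):
    """Train one (2, hidden, 1) sigmoid network; return True if it learned XOR."""
    rng = np.random.default_rng(seed)
    W1 = rng.normal(0.0, 1.0, size=(2, hidden))
    b1 = np.zeros(hidden)
    W2 = rng.normal(0.0, 1.0, size=(hidden, 1))
    b2 = np.zeros(1)
    for _ in range(epochs):
        h = sigmoid(X @ W1 + b1)
        out = sigmoid(h @ W2 + b2)
        # Backpropagate the mean squared error through both sigmoid layers
        d_out = 2.0 * (out - y) / len(X) * out * (1.0 - out)
        d_h = (d_out @ W2.T) * h * (1.0 - h)
        W2 -= lr * (h.T @ d_out)
        b2 -= lr * d_out.sum(axis=0)
        W1 -= lr * (X.T @ d_h)
        b1 -= lr * d_h.sum(axis=0)
    out = sigmoid(sigmoid(X @ W1 + b1) @ W2 + b2)
    return bool(np.array_equal(out.round(), y))

def failure_rate(n_runs=10, **kwargs):
    """Fraction of runs (seeds 0..n_runs-1) that did not learn XOR."""
    failures = sum(not train_once(seed, **kwargs) for seed in range(n_runs))
    return failures / n_runs

print("failure rate, (2,2,1):", failure_rate(hidden=2))
print("failure rate, (2,4,1):", failure_rate(hidden=4))
```

The same loop could wrap the Keras model instead, rebuilding and refitting it on each run, at the cost of a much longer runtime.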
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

# XOR truth table: inputs and expected outputs
X = np.array([[0,0],[0,1],[1,0],[1,1]])
y = np.array([[0],[1],[1],[0]])

def initialize_weights(shape, dtype=None):
    # Start the hidden-layer weights close to 0.75
    return np.random.normal(loc = 0.75, scale = 1e-2, size = shape)

# (2,2,1) network: 2 inputs, 2 hidden units, 1 output
model = Sequential()
model.add(Dense(2,
                activation='sigmoid',
                kernel_initializer=initialize_weights,
                input_dim=2))
model.add(Dense(1, activation='sigmoid'))

model.compile(loss='mean_squared_error',
              optimizer='adam',
              metrics=['accuracy'])

print("*** Training... ***")
model.fit(X, y, batch_size=4, epochs=10000, verbose=0)
print("*** Training done! ***")

print("*** Model prediction on [[0,0],[0,1],[1,0],[1,1]] ***")
# predict_proba was removed from recent Keras releases; predict returns
# the same sigmoid outputs here
print(model.predict(X))
*** Training... ***
*** Training done! ***
*** Model prediction on [[0,0],[0,1],[1,0],[1,1]] ***
[[0.08662204]
 [0.9235283 ]
 [0.92356336]
 [0.06672956]]