How to use keras for XOR

Submitted by ╄→尐↘猪︶ㄣ on 2019-11-30 05:43:27

Question


I want to practice Keras by coding XOR, but the result is not right. My code is below; thanks to everybody for helping.

from keras.models import Sequential
from keras.layers.core import Dense,Activation
from keras.optimizers import SGD
import numpy as np

model = Sequential()# two layers
model.add(Dense(input_dim=2,output_dim=4,init="glorot_uniform"))
model.add(Activation("sigmoid"))
model.add(Dense(input_dim=4,output_dim=1,init="glorot_uniform"))
model.add(Activation("sigmoid"))
sgd = SGD(l2=0.0,lr=0.05, decay=1e-6, momentum=0.11, nesterov=True)
model.compile(loss='mean_absolute_error', optimizer=sgd)
print "begin to train"
list1 = [1,1]
label1 = [0]
list2 = [1,0]
label2 = [1]
list3 = [0,0]
label3 = [0]
list4 = [0,1]
label4 = [1] 
train_data = np.array((list1,list2,list3,list4)) #four samples for epoch = 1000
label = np.array((label1,label2,label3,label4))

model.fit(train_data,label,nb_epoch = 1000,batch_size = 4,verbose = 1,shuffle=True,show_accuracy = True)
list_test = [0,1]
test = np.array((list_test,list1))
classes = model.predict(test)
print classes

Output

[[ 0.31851079] [ 0.34130159]] [[ 0.49635666] [0.51274764]] 

Answer 1:


If I increase the number of epochs in your code to 50000, it often converges to the right answer for me; it just takes a little while :)

It does often get stuck, though. I get better convergence properties if I change your loss function to 'mean_squared_error', which is a smoother function.
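One way to see why the smoother loss can help, as a rough sketch rather than anything taken from the original answer: for a single prediction and target, the derivative of the absolute error has magnitude 1 no matter how close the prediction is, while the derivative of the squared error shrinks as the prediction approaches the target. The helper names `mae_grad` and `mse_grad` are made up for illustration, and the sketch ignores the sigmoid that sits in front of the loss in the real network.

```python
def mae_grad(pred, target):
    """Derivative of |pred - target| with respect to pred.

    The derivative is undefined at pred == target; 0.0 is returned there
    as a convention for this sketch.
    """
    diff = pred - target
    if diff > 0:
        return 1.0
    if diff < 0:
        return -1.0
    return 0.0


def mse_grad(pred, target):
    """Derivative of (pred - target) ** 2 with respect to pred."""
    return 2.0 * (pred - target)


# Compare the two gradients as the prediction approaches the target.
target = 1.0
for pred in (0.1, 0.5, 0.9, 0.99):
    print(pred, mae_grad(pred, target), round(mse_grad(pred, target), 4))
```

With the absolute error, the update pushes just as hard when the prediction is nearly right as when it is far off, which is one plausible reason the optimizer stalls or oscillates more easily.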

I get even faster convergence if I use the Adam or RMSProp optimizers. My final compile line, which works:

model.compile(loss='mse', optimizer='adam')
...
model.fit(train_data, label, nb_epoch = 10000,batch_size = 4,verbose = 1,shuffle=True,show_accuracy = True)
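The snippets in this thread use the Keras API as it was when the question was asked (`nb_epoch`, `show_accuracy`, `output_dim`, `init`, and the Python 2 `print` statement). A rough present-day equivalent, assuming TensorFlow 2.x with its bundled `tf.keras`, might look like the sketch below. The function name `train_xor` and the default epoch count are illustrative choices, not from the answer, and convergence is not guaranteed on every run.

```python
# XOR truth table: the same four samples as in the question.
XOR_INPUTS = [[1, 1], [1, 0], [0, 0], [0, 1]]
XOR_LABELS = [[0], [1], [0], [1]]


def train_xor(epochs=2000):
    """Train a 2-4-1 sigmoid network on XOR with MSE loss and Adam.

    Returns the model's predictions for the four training inputs.
    """
    # Imported lazily so the data above can be inspected without TensorFlow.
    import numpy as np
    from tensorflow import keras

    x = np.array(XOR_INPUTS, dtype="float32")
    y = np.array(XOR_LABELS, dtype="float32")

    model = keras.Sequential([
        keras.Input(shape=(2,)),
        keras.layers.Dense(4, activation="sigmoid"),  # hidden layer
        keras.layers.Dense(1, activation="sigmoid"),  # output layer
    ])
    model.compile(loss="mse", optimizer="adam")

    # `epochs` replaces the old `nb_epoch`; `show_accuracy` no longer exists.
    model.fit(x, y, epochs=epochs, batch_size=4, verbose=0, shuffle=True)
    return model.predict(x, verbose=0)
```

Calling `print(train_xor())` should print one value per input row; values near 0, 1, 0, 1 (in that order) indicate the network has learned XOR.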



Answer 2:


I used a single hidden layer with 4 hidden nodes, and it almost always converges to the right answer within 500 epochs. I used sigmoid activations.
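For intuition on why four sigmoid hidden nodes are more than enough for XOR, here is a small forward-pass sketch in plain Python. The weights are hand-picked for illustration; they are not what training produced in this answer, and a trained network will generally settle on different values. One hidden unit approximates OR, one approximates AND, and the remaining two are switched off.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


# Hand-picked (hypothetical) parameters for a 2-4-1 network.
# Each hidden row is (weight for x1, weight for x2, bias).
HIDDEN = [
    (20.0, 20.0, -10.0),  # approximately OR(x1, x2)
    (20.0, 20.0, -30.0),  # approximately AND(x1, x2)
    (0.0, 0.0, 0.0),      # unused unit, always outputs 0.5
    (0.0, 0.0, 0.0),      # unused unit, always outputs 0.5
]
OUT_WEIGHTS = [20.0, -20.0, 0.0, 0.0]  # roughly "OR and not AND"
OUT_BIAS = -10.0


def forward(x1, x2):
    """Forward pass: sigmoid hidden layer followed by a sigmoid output."""
    hidden = [sigmoid(w1 * x1 + w2 * x2 + b) for w1, w2, b in HIDDEN]
    z = sum(w * h for w, h in zip(OUT_WEIGHTS, hidden)) + OUT_BIAS
    return sigmoid(z)


for x1, x2 in [(0, 0), (0, 1), (1, 0), (1, 1)]:
    print((x1, x2), round(forward(x1, x2), 3))
```

Since two of the four units are idle here, the extra capacity mainly gives gradient descent more routes to a solution, which may be part of why this setup converges reliably.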



Source: https://stackoverflow.com/questions/31556268/how-to-use-keras-for-xor
