XOR Neural Network in Java

一曲冷凌霜 submitted on 2019-12-19 02:06:22

Question


I'm trying to implement and train a five-neuron neural network with backpropagation for the XOR function in Java. My code (please excuse its hideousness):

public class XORBackProp {

private static final int MAX_EPOCHS = 500;

//weights
private static double w13, w23, w14, w24, w35, w45;
private static double theta3, theta4, theta5;
//neuron outputs
private static double gamma3, gamma4, gamma5;
//neuron error gradients
private static double delta3, delta4, delta5;
//weight corrections
private static double dw13, dw14, dw23, dw24, dw35, dw45, dt3, dt4, dt5;
//learning rate
private static double alpha = 0.1;
private static double error;
private static double sumSqrError;
private static int epochs = 0;
private static boolean loop = true;

private static double sigmoid(double exponent)
{
    return (1.0/(1 + Math.pow(Math.E, (-1) * exponent)));
}

private static void activateNeuron(int x1, int x2, int gd5)
{
    gamma3 = sigmoid(x1*w13 + x2*w23 - theta3);
    gamma4 = sigmoid(x1*w14 + x2*w24 - theta4);
    gamma5 = sigmoid(gamma3*w35 + gamma4*w45 - theta5);

    error = gd5 - gamma5;

    weightTraining(x1, x2);
}

private static void weightTraining(int x1, int x2)
{
    delta5 = gamma5 * (1 - gamma5) * error;
    dw35 = alpha * gamma3 * delta5;
    dw45 = alpha * gamma4 * delta5;
    dt5 = alpha * (-1) * delta5;

    delta3 = gamma3 * (1 - gamma3) * delta5 * w35;
    delta4 = gamma4 * (1 - gamma4) * delta5 * w45;

    dw13 = alpha * x1 * delta3;
    dw23 = alpha * x2 * delta3;
    dt3 = alpha * (-1) * delta3;
    dw14 = alpha * x1 * delta4;
    dw24 = alpha * x2 * delta4;
    dt4 = alpha * (-1) * delta4;

    w13 = w13 + dw13;
    w14 = w14 + dw14;
    w23 = w23 + dw23;
    w24 = w24 + dw24;
    w35 = w35 + dw35;
    w45 = w45 + dw45;
    theta3 = theta3 + dt3;
    theta4 = theta4 + dt4;
    theta5 = theta5 + dt5;
}

public static void main(String[] args)
{

    w13 = 0.5;
    w14 = 0.9;
    w23 = 0.4;
    w24 = 1.0;
    w35 = -1.2;
    w45 = 1.1;
    theta3 = 0.8;
    theta4 = -0.1;
    theta5 = 0.3;

    System.out.println("XOR Neural Network");

    while(loop)
    {
        activateNeuron(1,1,0);
        sumSqrError = error * error;
        activateNeuron(0,1,1);
        sumSqrError += error * error;
        activateNeuron(1,0,1);
        sumSqrError += error * error;
        activateNeuron(0,0,0);
        sumSqrError += error * error;

        epochs++;

        if(epochs >= MAX_EPOCHS)
        {
            System.out.println("Learning will take more than " + MAX_EPOCHS + " epochs, so program has terminated.");
            System.exit(0);
        }

        System.out.println(epochs + " " + sumSqrError);

        if (sumSqrError < 0.001)
        {
            loop = false;
        }
    }
}
}

If it helps any, here's a diagram of the network.

The initial values for all the weights and the learning rate are taken straight from an example in my textbook. The goal is to train the network until the sum of the squared errors is less than 0.001. The textbook also gives the values of all the weights after the first iteration (1,1,0); I've tested my code and its results match the textbook's perfectly. According to the book, training should take only 224 epochs to converge, but when I run it, it always reaches MAX_EPOCHS unless that is set to several thousand. What am I doing wrong?
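For anyone who wants to experiment with the setup described above, here is a minimal sketch of the same per-pattern update wrapped in a small class, so the learning rate and the epoch limit can be varied without editing static fields. The class name `XorBackPropSketch` and the methods `train` and `trainUntil` are hypothetical; the arithmetic is meant to mirror the posted code, except that `Math.exp` is assumed to be an acceptable replacement for `Math.pow(Math.E, ...)`.

```java
// Sketch only: class and method names are hypothetical, not from the post.
// The arithmetic follows the question's code (sigmoid units, thresholds
// subtracted from the weighted sum, weights updated after every pattern).
class XorBackPropSketch {
    double w13 = 0.5, w14 = 0.9, w23 = 0.4, w24 = 1.0, w35 = -1.2, w45 = 1.1;
    double theta3 = 0.8, theta4 = -0.1, theta5 = 0.3;
    final double alpha; // learning rate

    XorBackPropSketch(double alpha) {
        this.alpha = alpha;
    }

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /** Presents one pattern, updates all weights, returns desired - actual. */
    double train(int x1, int x2, int desired) {
        double y3 = sigmoid(x1 * w13 + x2 * w23 - theta3);
        double y4 = sigmoid(x1 * w14 + x2 * w24 - theta4);
        double y5 = sigmoid(y3 * w35 + y4 * w45 - theta5);
        double error = desired - y5;

        // Error gradients use the weights from before this update.
        double delta5 = y5 * (1 - y5) * error;
        double delta3 = y3 * (1 - y3) * delta5 * w35;
        double delta4 = y4 * (1 - y4) * delta5 * w45;

        w35 += alpha * y3 * delta5;
        w45 += alpha * y4 * delta5;
        theta5 -= alpha * delta5;
        w13 += alpha * x1 * delta3;
        w23 += alpha * x2 * delta3;
        theta3 -= alpha * delta3;
        w14 += alpha * x1 * delta4;
        w24 += alpha * x2 * delta4;
        theta4 -= alpha * delta4;
        return error;
    }

    /** Returns the epoch at which the SSE drops below targetSse, or -1. */
    int trainUntil(double targetSse, int maxEpochs) {
        int[][] patterns = {{1, 1, 0}, {0, 1, 1}, {1, 0, 1}, {0, 0, 0}};
        for (int epoch = 1; epoch <= maxEpochs; epoch++) {
            double sse = 0.0;
            for (int[] p : patterns) {
                double e = train(p[0], p[1], p[2]);
                sse += e * e;
            }
            if (sse < targetSse) {
                return epoch;
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        // Weights after presenting (1,1,0) once, for comparison with a reference.
        XorBackPropSketch net = new XorBackPropSketch(0.1);
        net.train(1, 1, 0);
        System.out.printf("after (1,1,0): w13=%.4f w35=%.4f theta5=%.4f%n",
                net.w13, net.w35, net.theta5);

        // Epochs needed for a few learning rates (-1 means the limit was hit).
        for (double alpha : new double[] {0.1, 0.5, 1.0}) {
            int epochs = new XorBackPropSketch(alpha).trainUntil(0.001, 100_000);
            System.out.println("alpha=" + alpha + " epochs=" + epochs);
        }
    }
}
```

Running `main` prints a few weights after the first pattern and then the epoch count for each learning rate tried, which makes it easy to see how strongly the count depends on `alpha`. The learning rates other than 0.1 are arbitrary choices for illustration.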


Answer 1:


    // Use this in place of the existing alpha declaration in the field section.
    private static double alpha = 3.8, g34 = 0.13, g5 = 0.21;

    // Use this in place of the three gamma assignments in activateNeuron.
    gamma3 = sigmoid(x1 * w13 + x2 * w23 - theta3);
    gamma4 = sigmoid(x1 * w14 + x2 * w24 - theta4);
    // Snap hidden outputs that are within g34 of 0 or 1.
    if (gamma3 > 1 - g34) { gamma3 = 1; }
    if (gamma3 < g34) { gamma3 = 0; }
    if (gamma4 > 1 - g34) { gamma4 = 1; }
    if (gamma4 < g34) { gamma4 = 0; }
    gamma5 = sigmoid(gamma3 * w35 + gamma4 * w45 - theta5);
    // Snap the final output when it is within g5 of 0 or 1.
    if (gamma5 > 1 - g5) { gamma5 = 1; }
    if (gamma5 < g5) { gamma5 = 0; }

With these settings the ANN should learn in 66 iterations, but it is on the brink of divergence.




Answer 2:


Try rounding gamma3, gamma4, and gamma5 during the activation phase, for instance:

if (gamma3 > 0.7) gamma3 = 1;
if (gamma3 < 0.3) gamma3 = 0;

and raise the learning rate (alpha) a little:

alpha = 0.2;

With these changes, learning ends in 466 epochs.

Of course, if you round more aggressively and set a higher alpha, you can achieve an even better result than 224.
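The rounding used in this answer (cut-offs 0.3 and 0.7) and in Answer 1 (margins `g34` and `g5`) follows one pattern, which can be sketched as a single helper. The class name `OutputSnap`, the method `snap`, and its `margin` parameter are hypothetical names introduced only for this illustration.

```java
// Sketch only: the helper name and the margin parameter are hypothetical.
class OutputSnap {
    /** Snaps value to 0 or 1 when it lies within margin of that end of [0, 1]. */
    static double snap(double value, double margin) {
        if (value > 1 - margin) {
            return 1.0;
        }
        if (value < margin) {
            return 0.0;
        }
        return value; // values in the middle band are left untouched
    }

    public static void main(String[] args) {
        double raw = 0.9;
        double snapped = snap(raw, 0.13);
        System.out.println("raw=" + raw + " snapped=" + snapped);
        // The sigmoid derivative term used by the question's weightTraining.
        System.out.println("gradient factor=" + snapped * (1 - snapped));
    }
}
```

One consequence worth keeping in mind: once an output has been snapped to 0 or 1, the factor `gamma * (1 - gamma)` used in `weightTraining` is exactly zero, so that neuron's delta is zero for that pattern. The squared error also becomes exactly 0 or 1 for a snapped output, so the stopping test is no longer measuring the same quantity as in the unmodified code.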




Answer 3:


The whole point of this network is to show how to deal with a situation where the grouping isn't based on "top = yes, bottom = no". Instead there is a central line (going through the points (0,1) and (1,0) in this case): if a value is close to the line, the answer is "yes", while if it is far from it, the answer is "no". You can't cluster such a system with just one layer. However, two layers are enough.
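To make the one-layer versus two-layer point concrete, here is a small sketch using hard threshold units with hand-picked weights (not learned ones). The hidden units act as OR and AND, and the output unit fires for "OR but not AND". The class name `TwoLayerXor` and its methods are hypothetical, and the grid search over a single unit is only an illustration on a coarse grid, not a proof.

```java
// Sketch only: hand-picked weights, not learned ones; names are hypothetical.
class TwoLayerXor {
    // Threshold unit: fires when the weighted sum reaches the threshold theta.
    static int step(double weightedSum, double theta) {
        return weightedSum - theta >= 0 ? 1 : 0;
    }

    // Two hidden units (OR and AND) feed one output unit: OR and not AND.
    static int xor(int x1, int x2) {
        int or = step(x1 + x2, 0.5);
        int and = step(x1 + x2, 1.5);
        return step(or - and, 0.5);
    }

    // Counts single threshold units on a coarse weight grid that reproduce XOR.
    static int countSingleUnitSolutions() {
        int count = 0;
        for (int a = -8; a <= 8; a++) {
            for (int b = -8; b <= 8; b++) {
                for (int t = -8; t <= 8; t++) {
                    double w1 = a / 4.0, w2 = b / 4.0, theta = t / 4.0;
                    boolean ok = step(0, theta) == 0        // (0,0) -> 0
                            && step(w2, theta) == 1         // (0,1) -> 1
                            && step(w1, theta) == 1         // (1,0) -> 1
                            && step(w1 + w2, theta) == 0;   // (1,1) -> 0
                    if (ok) {
                        count++;
                    }
                }
            }
        }
        return count;
    }

    public static void main(String[] args) {
        for (int x1 = 0; x1 <= 1; x1++) {
            for (int x2 = 0; x2 <= 1; x2++) {
                System.out.println(x1 + " XOR " + x2 + " = " + xor(x1, x2));
            }
        }
        System.out.println("single-unit solutions on grid: " + countSingleUnitSolutions());
    }
}
```

The grid search finds no single unit that reproduces XOR, which is consistent with the four points not being separable by one straight line, while the two-layer arrangement handles all four inputs.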



Source: https://stackoverflow.com/questions/9662746/xor-neural-network-in-java
