gradient descent seems to fail

忘掉有多难 2020-12-12 15:54

I implemented a gradient descent algorithm to minimize a cost function in order to obtain a hypothesis for determining whether an image has good quality. I did that in Octave.

9 Answers
  • 2020-12-12 16:26

    While not as scalable as a vectorized version, a loop-based computation of gradient descent should produce the same results. In the example above, the most probable cause of gradient descent failing to compute the correct theta is the value of alpha.

    With a verified set of cost and gradient descent functions and a set of data similar to the one described in the question, theta ends up with NaN values after just a few iterations if alpha = 0.01. However, when set to alpha = 0.000001, the gradient descent works as expected, even after 100 iterations.
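    For concreteness, here is a loop-based sketch of such an implementation in Octave (the variable names X, y, theta, alpha, and num_iters are assumptions here, not taken from the question):

    m = length(y);
    for iter = 1:num_iters
        grad1 = 0; grad2 = 0;
        for i = 1:m
            % hypothesis for the i-th example
            h = theta(1) * X(i,1) + theta(2) * X(i,2);
            grad1 = grad1 + (h - y(i)) * X(i,1);
            grad2 = grad2 + (h - y(i)) * X(i,2);
        end
        % simultaneous update; with alpha = 0.01 on data like this the
        % values blow up to NaN, with alpha = 0.000001 they stay stable
        theta(1) = theta(1) - alpha * (1/m) * grad1;
        theta(2) = theta(2) - alpha * (1/m) * grad2;
    end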

  • 2020-12-12 16:28

    I think that your computeCost function is wrong. I attended Ng's class last year and I have the following (vectorized) implementation:

    function J = computeCost(X, y, theta)
        % cost of linear regression over all m training examples
        m = length(y);
        predictions = X * theta;           % hypothesis for every example
        sqrErrors = (predictions - y).^2;  % squared residuals
        J = 1/(2*m) * sum(sqrErrors);
    end
    

    The rest of the implementation seems fine to me, although you could vectorize those parts as well.

    theta_1 = theta(1) - alpha * (1/m) * sum((X*theta-y).*X(:,1));
    theta_2 = theta(2) - alpha * (1/m) * sum((X*theta-y).*X(:,2));
    

    Afterwards you correctly assign the temporary values (here called theta_1 and theta_2) back to the "real" theta.

    Generally it is better to vectorize rather than use loops; vectorized code is easier to read and to debug.
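    For example, a fully vectorized update (a sketch only, using the same names as above) collapses both assignments into a single line:

    % X' * (X*theta - y) computes all partial derivatives at once,
    % so no temporary theta variables are needed
    theta = theta - alpha * (1/m) * (X' * (X * theta - y));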

  • 2020-12-12 16:31

    If you recall the first PDF handout on gradient descent from the Machine Learning course, it warns you to take care with the learning rate. Here is the note from that PDF.

    Implementation Note: If your learning rate is too large, J(theta) can diverge and 'blow up', resulting in values which are too large for computer calculations. In these situations, Octave/MATLAB will tend to return NaNs. NaN stands for 'not a number' and is often caused by undefined operations that involve -infinity and +infinity.
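    A simple way to catch this in practice (a sketch only; computeCost is the function from the answer above, the remaining names are assumptions) is to record J(theta) every iteration and stop as soon as it grows or turns into NaN:

    J_history = zeros(num_iters, 1);
    for iter = 1:num_iters
        theta = theta - alpha * (1/m) * (X' * (X * theta - y));
        J_history(iter) = computeCost(X, y, theta);
        % a growing or NaN cost means alpha is too large
        if isnan(J_history(iter)) || ...
           (iter > 1 && J_history(iter) > J_history(iter - 1))
            warning('J(theta) is diverging; try a smaller alpha');
            break;
        end
    end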
