Estimating linear regression with Gradient Descent (Steepest Descent)

Anonymous (unverified) submitted on 2019-12-03 10:24:21

Question:

Example data

X <- matrix(c(rep(1, 97), runif(97)), nrow = 97, ncol = 2)
y <- matrix(runif(97), nrow = 97, ncol = 1)

I have succeeded in creating the cost function:

COST <- function(theta, X, y) {
  ### Calculate half MSE
  sum((X %*% theta - y)^2) / (2 * length(y))
}

However, when I run the gradient descent function below, it seems to fail to converge over 100 iterations.

theta <- matrix(0, nrow = 2, ncol = 1)
num.iters <- 1500
delta = 0

GD <- function(X, y, theta, alpha, num.iters) {
  for (i in num.iters) {
    while (max(abs(delta)) < tolerance) {
      error <- X %*% theta - y
      delta <- (t(X) %*% error) / length(y)
      theta <- theta - alpha * delta
      cost_histo[i] <- COST(theta, X, y)
      theta_histo[[i]] <- theta
    }
  }
  return(list(cost_histo, theta_histo))
}

Can someone help me?

Cheers

Answer 1:

The algorithmic part of your implementation is correct. The problems lie in:

  • The loop structure in GD is not right: the for loop is redundant, and variables such as delta, tolerance, cost_histo and theta_histo lack proper initialization (a minimal corrected sketch follows this list).
  • Using a fixed alpha in a simple implementation of gradient descent is dangerous. It is usually suggested that alpha be chosen small enough that we always move down the objective function, but in practice this is hard to judge: how small is small enough? If alpha is too small, convergence is slow; if it is too large, the search may follow a 'zig-zag' path or even diverge!
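For reference, here is a minimal sketch of how the question's loop could be repaired while keeping a fixed alpha (the name GD_fixed and the tolerance/max.iters arguments are illustrative, not part of the original code); it is still subject to the step-size caveat above:

# Minimal repair of the question's loop: proper initialization and a
# single while() loop with a tolerance check; alpha is still fixed.
GD_fixed <- function(X, y, theta, alpha, tolerance = 1e-7, max.iters = 1500) {
  cost_histo  <- numeric(0)
  theta_histo <- list()
  delta <- rep(Inf, ncol(X))   # start large so the loop is entered
  i <- 0
  while (max(abs(delta)) > tolerance && i < max.iters) {
    i <- i + 1
    error <- X %*% theta - y
    delta <- (t(X) %*% error) / length(y)
    theta <- theta - alpha * delta   # fixed step: may zig-zag or diverge if alpha is too large
    cost_histo[i]    <- COST(theta, X, y)
    theta_histo[[i]] <- theta
  }
  list(cost_histo, theta_histo)
}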

Here is a robust version of gradient descent for estimating a linear regression. The improvement comes from a step-halving strategy that avoids 'zig-zag' searching and divergence; see the comments along the code. Under this strategy it is safe to use a large alpha, and convergence is guaranteed.

# theta: initial guess on regression coef
# alpha: initial step scaling factor
GD <- function(X, y, theta, alpha) {
  cost_histo <- numeric(0)
  theta_histo <- numeric(0)
  # an arbitrary initial gradient, to pass the initial while() check
  delta <- rep(1, ncol(X))
  # MSE at initial theta
  old.cost <- COST(theta, X, y)
  # main iteration loop
  while (max(abs(delta)) > 1e-7) {
    # gradient
    error <- X %*% theta - y
    delta <- crossprod(X, error) / length(y)
    # proposal step
    trial.theta <- theta - alpha * c(delta)
    trial.cost <- COST(trial.theta, X, y)
    # step halving to avoid divergence
    while (trial.cost >= old.cost) {
      trial.theta <- (theta + trial.theta) / 2
      trial.cost <- COST(trial.theta, X, y)
    }
    # accept proposal
    cost_histo <- c(cost_histo, trial.cost)
    theta_histo <- c(theta_histo, trial.theta)
    # update old.cost and theta
    old.cost <- trial.cost
    theta <- trial.theta
  }
  list(cost_histo, theta_histo = matrix(theta_histo, nrow = ncol(X)))
}

On return,

  • the length of cost_histo tells you how many iterations have been taken (excluding step halving);
  • each column of theta_histo gives theta per iteration.

Step halving in fact speeds up convergence greatly. You can get even more efficiency by using a faster computation of COST (most useful for large datasets; see https://stackoverflow.com/a/40228894/4891738):

COST <- function(theta, X, y) {
  c(crossprod(X %*% theta - y)) / (2 * length(y))
}

Now, let's run it on your example X and y.

oo <- GD(X, y, c(0,0), 5) 

After 107 iterations it converges. We can view the trace of MSE:

plot(oo[[1]]) 

Note that in the first few steps the MSE decreases very fast, but then the curve is almost flat. This reveals the fundamental drawback of the gradient descent algorithm: convergence gets slower and slower as we approach the minimum.
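To see this flattening numerically, one can inspect the per-iteration decrease of the recorded MSE values. A small sketch using the oo object above (the exact numbers depend on the simulated data):

length(oo[[1]])                        # number of iterations taken
plot(-diff(oo[[1]]), log = "y",
     ylab = "per-iteration decrease in MSE")   # the decrease shrinks near the minimum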

Now, we extract the final estimated coefficient:

oo[[2]][, 107] 

We can also compare this with direct estimation by QR factorization:

.lm.fit(X, y)$coef 

They are pretty close.
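For instance, one way to put the two estimates side by side (a small sketch; the names theta_gd and theta_qr are just for illustration):

theta_gd <- oo[[2]][, ncol(oo[[2]])]   # final gradient descent estimate
theta_qr <- .lm.fit(X, y)$coef         # QR-based estimate
cbind(GD = theta_gd, QR = c(theta_qr))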



Answer 2:

Using crossprod surprisingly makes it slower than the previous method:

  • Previous method: 14 secs mean on 50 iterations.
  • Crossprod method: 16 secs mean on 50 iterations.
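The benchmark code itself is not shown in this answer. A sketch of how such a timing comparison might be reproduced on a larger simulated dataset follows; the dataset size, seed and repetition count are assumptions, not the original benchmark setup:

# Hypothetical reproduction of the timing comparison; sizes are assumptions.
set.seed(1)
n  <- 1e6
Xb <- matrix(c(rep(1, n), runif(n)), nrow = n, ncol = 2)
yb <- matrix(runif(n), nrow = n, ncol = 1)
theta0 <- c(0, 0)

# sum()-based COST (the question's version)
COST_sum <- function(theta, X, y) sum((X %*% theta - y)^2) / (2 * length(y))
# crossprod()-based COST (the version suggested in Answer 1)
COST_cp  <- function(theta, X, y) c(crossprod(X %*% theta - y)) / (2 * length(y))

system.time(for (i in 1:50) COST_sum(theta0, Xb, yb))
system.time(for (i in 1:50) COST_cp(theta0, Xb, yb))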


