matlab实现梯度下降法(Gradient Descent)的一个例子

大兔子大兔子 提交于 2020-04-06 11:00:43

  在此记录使用matlab作梯度下降法(GD)求函数极值的一个例子:

  问题设定: 

  1. 我们有一个$n$个数据点,每个数据点是一个$d$维的向量,向量组成一个data矩阵$\mathbf{X}\in \mathbb{R}^{n\times d}$,这是我们的输入特征矩阵。

  2. 我们有一个响应的响应向量$\mathbf{y}\in \mathbb{R}^n$。

  3. 我们将使用线性模型来fit上述数据。因此我们将优化问题形式化成如下形式:$$\arg\min_{\mathbf{w}}f(\mathbf{w})=\frac{1}{n}\|\mathbf{y}-\mathbf{\overline{X}}\mathbf{w}\|_2^2$$

  其中$\mathbf{\overline{X}}=(\mathbf{1,X})\in \mathbb{R}^{n\times (d+1)}$ and $\mathbf{w}=(w_0,w_1,...,w_d)^\top\in \mathbb{R}^{d+1}$

  显然这是一个回归问题,我们的目标从通俗意义上讲就是寻找合适的权重向量$\mathbf{w}$,使得线性模型能够拟合的更好。

  处理:

  1. 按列对数据矩阵进行最大最小归一化,该操作能够加快梯度下降的速度,同时保证了输入的数值都在0和1之间。$\mathbf{x}_i$为$\mathbf{X}$的第i列。 $$z_{ij}\leftarrow \frac{x_{ij}-\min(\mathbf{x}_i)}{\max(\mathbf{x}_i)-\min(\mathbf{x}_i)}$$

  这样我们的优化问题得到了转化:$$\arg\min_{\mathbf{u}}g(\mathbf{w})=\frac{1}{n}\|\mathbf{y}-\mathbf{\overline{Z}}\mathbf{u}\|_2^2$$

  2. 考虑对目标函数的Lipschitz constants进行估计。因为我们使用线性回归模型,Lipschitz constants可以方便求得,这样便于我们在梯度下降法是选择合适的步长。假如非线性模型,可能要用其他方法进行估计(可选)。

  问题解决:

  使用梯度下降法进行问题解决,算法如下:

   我们可以看到,这里涉及到求目标函数$f$对$\mathbf{x}_k$的梯度。显然在这里,因为是线性模型,梯度的求解十分的简单:$$\nabla f(\mathbf{x}_k)-\frac{2}{n}\mathbf{\overline{X}}^\top(\mathbf{y}-\mathbf{\overline{X}}\mathbf{u}_k)$$

  进行思考,还有没有其他办法可以把这个梯度给弄出来?假如使用Tensorflow,Pytorch这样可以自动保存计算图的东东,那么梯度是可以由机器自动求出来的。当然在这里我是用matlab实现,暂时没有发现这样的利器,所以我认为假如在这里想求出梯度,那么我们必须要把梯度的闭式解搞出来,不然没法继续进行。

  下面是一段matlab的代码:  

function [g_result,u_result] = GD(N_Z,y,alpha,u0)
%GD 梯度下降法
%   Detailed explanation goes here
[n,~] = size(N_Z);
u = u0;
k = 0;
t = y-N_Z*u;
disp("g(u):");
while(合理的终止条件)
    k = k + 1;
    u = u - alpha * (-2/n)*N_Z'*t;
    t = y-N_Z*u;
    if(mod(k,10)==0)
        disp(t'*t/n);
    end
end
g_result = (y-N_Z * u)' * (y-N_Z * u)/n;
u_result = u;
end

  当然假如初始化的时候$u_0$选择不当,而且因为没有正则项,以上的算法将会有很大的问题:梯度消失,导致优化到最后的时候非常慢。我花了好多个小时才将loss讲到0.19左右,而闭式解法能够使得loss为0.06几,运行时间也不会难以忍受。

  问题推广:

  在这里,我们的问题是线性模型,回归问题。能否有更广的应用?思考后认为,只要需要优化的目标是标量,且该目标函数对输入向量的梯度容易求得即可。只是因为该算法简单朴素,可能在实际应用的时候会碰见恼人的梯度消失问题。

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!