1、前言

    神经网络的优化方法有很多,前面学习了神经网络参数的初始化方法,好的初始化方法可以让网络训练的更快,也可能让网络收敛的更好。同样,好的优化方法同样具有这样的作用。注意:谈论优化方法时候,并没有修改损失函数,只是修改了网络学习过程中参数的更新方法。

    之前经常使用梯度下降法来优化网络,今天学习了梯度下降法的几个改进版本:Momentum、RMSprop、Adam

    网络学习中,梯度下降法的经典公式为:

改进神经网络 改进神经网络的方法_迭代

    梯度下降法一般有三种常见的形式(梯度更新的公式并没有本质性地改变):

(1)随机梯度下降法(stochastic gradient descent)

    该方法是最Navie的,每次计算一个样本,然后更新参数,显然容易受噪声干扰,从而导致损失下降曲线震荡,甚至出现损失中途一度上升的现象。另外,一次计算一个样本,效率很低,因此网络学习很慢。

(2)梯度下降法(Gradient descent)

    该方法相对于随机梯度下降法,指示的是一次训练所有样本,然后再更新参数。这种形式的好处是每次计算全部样本的梯度,然后用它们的均值去更新参数,能够有效避免噪声干扰现象。由于采用了向量化技术(全部样本几乎同时计算),numpy能够大大加速计算过程,所以计算速度较快。另外,由于能够有效避免或减小噪声干扰问题,所以学习率可以设置的大一些。

    缺点:数据量大的时候无法使用,几千个样本的数据量还是可以的,但如果是几万、几十万、甚至数百万的数据呢,显然无法之前全部计算,此时网络训练速度反而很慢。

(3)小批量梯度下降法(mini-batch gradient descent)

    综合上述两种方法,面对大数据情况,每次使用一部分数据来计算,并且优化网络,而不是只用一个数据或者全部数据。这样,计算速度得到了大大提高,同时又一定程度上避免了噪声干扰问题(还有有一些)。考虑到计算机储存方式问题,每个小批量的size一般取64、128、512、1024等,这样计算会更快(未验证)。

    简单写下流程:

    a、将全部样本随机打乱

    b、按照设定的batch-size划分数据为若干个batch,最后一个batch大小可能不是设定值,但一样要参与计算

    c、以batch为单位训练网络,每个batch计算后都要更新参数w和b,所有样本都做了一边算是一代

    d、多次执行a-c以完成多个iterations

2、Momentum

    momentum是动量的意思,其实就是对dw做一个一阶平均滤波,十分简单。通过一阶平均滤波,可以平滑dw的变化,也可以让dw加入滞后因子,到达谷点后因为滞后因素,依然会向前冲,如果对面是一个较低的山峰,说不定就冲过去了,也就是它能让网络在学习过程中跳出一些局部最优点。公式如下:

改进神经网络 改进神经网络的方法_改进神经网络_02

    beta是需要tune的超参数,一般设置为0.9,这里用beta1为了避免与后面的RMSprop的beta搞混。

    显然,新的dw会以(1-beta)的系数加入Vdw,并且Vdw也只会保留beta倍,所以,如果将多次迭代的式子展开,可以看到每次计算的dw的权值承指数衰减。若干次迭代后,比如t次,dw[1]的系数变为了beta**(t-1)*(1-beta),已经很小了,可以忽略。这里不再深入。到此足够理解和使用Momentum了。

    若Vdw一开始初始化为0,则Vdw一开始并不准确,会缓慢上升到准确值。原因是,比如,第一次计算Vdw=(1-beta)dw,该值明显小于dw,所以可以使用修正方法:

Vdw(corrected)=Vdw/(1 - beta**t) ,其中t是迭代次数

3、RMSprop

    Momentum是对导数dw和db做一阶惯性滤波,而RMSprop是对dw**2和db**2做一阶惯性滤波得到Sdw和Sdb,然后用用这些滤波结果对dw和db进行标准化,最后用标准化后的导数来更新梯度。公式如下:

改进神经网络 改进神经网络的方法_迭代_03

    beta2常常是一个很接近1的scalar,比如0.99、0.999,更新很缓慢的。从公式中可以看出,sqrt(Sdw)项其实起到了动态修改学习率的作用,若权值平方和过大,学习率相应会减小。

4、Adam

    Adam可谓目前表现最好的优化方法之一,它综合了Momentum和RMSprop的优点,将Momentum的V除以RMSprop的S,就是Adam了,也就是对一阶惯性滤波后的dw进行标准化(准标准化)。公式:

改进神经网络 改进神经网络的方法_改进神经网络_04

db公式类似。

    corrected主要是为了避免一开始时候V不正确引入的。t是Adam迭代次数,随着t增加,corrected的分母趋于1.