目录
- 直观感受梯度消失和爆炸(特例)
- 数学感受梯度消失和梯度爆炸
- 简单回忆 反向传播(BP) 的流程:
- 简单回忆 SimpleRNN 模型:
- 开始BPTT
- 干掉它:)
- 简单回忆LSTM
- LSTM中的BPTT
- 缓解梯度消失/爆炸问题
LSTM现在都已经成为一个标准RNN,大家说RNN多半指的是LSTM,而最开始的RNN多称为
Simple RNN。所以本文主要是对于SimpleRNN为什么会存在
梯度消失/爆炸问题进行说明。
RNN 随着epoch增加,通常情况下是 loss 越来越小(下图蓝色),但是有的时候 loss 抖动的特别厉害,然后爆掉(下图绿色),这就是发生了梯度消失/爆炸。

直观感受梯度消失和爆炸(特例)
下方是一个极其简单的RNN,1000 个输入(时间步),除了第一个输入为 1 以外,其余均为 0。中间RNN函数为线性函数 y = x,没有bias,输入和输出的参数都为 1,输入给下一个 RNN 的参数为 w。那么根据公式
可以得到最后 的输出为
。

当我们尝试去改变参数 的时候,我们会发现,只要参数有一点变化,对于长句子(长输入)来说,最后一向的输出会有很大的变化!
比如 从 1 变成 1.01,那么 最后的输出会从 1 变成约 20000,当
从 1 变成小于 1 的数字时,最后输出直接约等于 0。
梯度,我们也可以理解为 直观地看到参数 的变化会引起最后输出多大的变化。

因此上图中绿色部分
数学感受梯度消失和梯度爆炸
简单回忆 反向传播(BP) 的流程:
这是一个简单的神经网络,其中 代表第 2 层参数链接了前面第 2 个神经元和后面第 1 个神经元的参数;
分别表示第一层第一个激活函数的输入值和输出值。此处所有激活函数为sigmoid函数。

- 误差函数
- 参数迭代公式
- 以
为例,最后一层参数更新公式为:
推导:
其中 - 以
其它层的参数更新公式为:
推导:
这个推导没有出现一个神经元会有两个共同带来的误差。如果有多个的话,这要把每个
题外话:基本上所有的均方误差损失函数(MSE) 都是以下形式表现:
但后来突然发现还有将变为
作为分母使用的(参考微软github)
题外话结束
简单回忆 SimpleRNN 模型:

好的,回忆完了,RNN会有多个输出,这里如果要使用BP的话一定是要考虑到时间的,因此RNN的BP叫做时间反向传播-BPTT(Back Propagation Through Time)
开始BPTT
(参考链接)
再从最简单的例子开始,下图是只有三个输入的RNN,没有任何激活函数:

那在 的时刻,误差函数为
。
一次训练下所有的误差值为单个误差之和:
而我们的目标是去更新所有的参数 ,所以需要计算误差项的梯度。
RNN中误差项的梯度并更新参数:
此处,我们对 时刻入手开始更新:
如果从 t=3 时刻开始,那么就需要每一次向后传递时,分一部分给再分一部分错误给后面。

对上述偏导公式进行总结,得出所有时刻的梯度之和:
同上。
因为上述式子是假设没有任何激活函数,下式是任意时刻的梯度传递到时间步1时候的公式:
因此,
在没有任何激活函数的情况下 是
个
相乘。那么
的大小就会影响梯度爆炸还是消失。
若有激活函数 :
求偏导则先求 的导,再求
的导:

取值在[0,1],后面还是有
!可恶。
(补充: 激活函数 )
重点!!!
所以到这里就可以发现,只要我时间步长够长,就会有越来越多的 相乘,如果说
正好也在 [0,1] ,配合上
,我们的梯度就消失了(所以我们可以认为是
和
两者一起导致梯度消失)。
但当然,因为变量终究还是 (可能会大于1,但
肯定小于等于1),如果
大的能抵消掉 那么多个
相乘 的时候,就会造成梯度爆炸。
如果梯度消失,那么
这也就导致我们的参数在合理的时间内就没怎么更新了。
干掉它:)
(参考链接)
现在知道了梯度消失和爆炸的问题就在于
中的 ,最直观的想法就是让它乘来乘去一直约为 1 或者 一直约为 0 ,这样就不会对整体的梯度有很大的影响。
LSTM 可以解决。(Clockwise RNN 和 SCRN 也可以,但这里不讲了)
简单回忆LSTM
(一个非常详细的LSTM介绍)
下面是三个时间步长的LSTM,时的输入是
当前输入与上一个的输出相结合。三个橙色的
函数是LSTM的三个gate,从左至右分别为遗忘门、输入门以及输出门。

遗忘门 根据新的输入,经过激活函数之后,得到一个 0 - 1 的值,这个值决定了过去的记忆

输入门 控制了新输入有多少要被加入到记忆中,但这里的输出还需要配合上神经网络觉得新输入中有用的部分。
这里会拆成以下两个矩阵进行对应相乘:

输出门 控制了有多少信息会被编辑到记忆细胞,作为下一个时间步的输入。

记忆细胞 的内容就根据上述结果进行更新:
题外话:看到这里你会发现为什么LSTM会用到两个 函数?
找了半天就这个链接中点赞第二多的比较有道理(点赞最多的那个我实在不懂tanh怎么就有他说的多阶导能很长时间不为0的表现)。
大概的意思就是的值在 [-1,1] 之间,但 sigmoid 在 [0,1] 之间,所以使用
可以生成负值。我了解到在神经网络中,像
这样的0中心激活函数可以加快收敛速度(关于0中心这篇不错),那可能也是这个原因我们在这里使用
,应该是其他激活函数也可用(但这里不是百分百保证,最好对于激活函数的区别再了解一下)。
题外话结束
LSTM中的BPTT
总结一下 LSTM 里发生的公式们( 表示矩阵相乘,
表示矩阵元素对应相乘):
现在我们放一个时间步长为3的 LSTM :

列出涉及到的式子(嵌套的就不打开了,太多太乱了):
现在要去更新参数,共计四个:
求偏导过程中,主要看函数中有的是哪部分,很快我们发现
函数中直观的包含了
和
,因此
和
的偏导公式都相同。
再看参数 ,它与
是直接关系。
因此我们也可以总结出和RNN类似的公式,即任意时刻误差项公式:
同上。
接着看到 ,举个t=2的例子:
注意!我们之所以不直接把提出来,是因为
中是包含了
的!!
包含了
,而
包含了
。
作为简单的记忆,我们就把拆成了四项,除了
那一项,其他都是一个套路。
配上一个可视的BPTT

总结一下偏导公式:
如果把这四项分别用 替代的话,公式就可以变成:
把这个简洁的公式带入之前的误差公式:
缓解梯度消失/爆炸问题
有连乘,那就说明有可能造成梯度消失和爆炸。上文也讲了里面有什么,总共四项,如果看的云里雾里也没事,因为那个
你一定看的懂!因为
只有一个内容
,我们可以轻松地直观地通过他调整
接下来我们看 到底怎么能帮助我们。现在假设对某一个时间步
,我们有:
然后为了梯度不消失,我们可以再时间步 找到一个合适的
使得:
由于遗忘门的激活函数和梯度项中大家都是相加的(A,B,C,D,加性结构),所以使得 LSTM 在任何时间步都能找到这样的
这样梯度就不会消失了。
另一个重要的性质: 正如上文说到的加性结构,四个项可以相互平衡从而保证在反向传播的时候梯度值不会消失。
举个例子:假设 时间步,我们对梯度值中的四项设置一个相互平衡的值(从而保证梯度不消失):
带入连乘公式:
这时候就算是连乘,梯度也不会消失了。
所以,在 LSTM 中,遗忘门的存在,以及细胞状态梯度的加性特性,使网络能够以这样一种方式更新参数,即不同子梯度之间的平衡从而避免梯度消失。
但看到这,也就清楚了,因为我们都是正数相加,所以不能够避免梯度爆炸,当 的数值很大的时候,
















