目录

  • 直观感受梯度消失和爆炸(特例)
  • 数学感受梯度消失和梯度爆炸
  • 简单回忆 反向传播(BP) 的流程:
  • 简单回忆 SimpleRNN 模型:
  • 开始BPTT
  • 干掉它:)
  • 简单回忆LSTM
  • LSTM中的BPTT
  • 缓解梯度消失/爆炸问题




LSTM现在都已经成为一个标准RNN,大家说RNN多半指的是LSTM,而最开始的RNN多称为

Simple RNN。所以本文主要是对于SimpleRNN为什么会存在

梯度消失/爆炸问题进行说明。

RNN 随着epoch增加,通常情况下是 loss 越来越小(下图蓝色),但是有的时候 loss 抖动的特别厉害,然后爆掉(下图绿色),这就是发生了梯度消失/爆炸。

RNN梯度计算 rnn梯度消失_RNN梯度计算

直观感受梯度消失和爆炸(特例)

下方是一个极其简单的RNN,1000 个输入(时间步),除了第一个输入为 1 以外,其余均为 0。中间RNN函数为线性函数 y = x,没有bias,输入和输出的参数都为 1,输入给下一个 RNN 的参数为 w。那么根据公式

RNN梯度计算 rnn梯度消失_神经网络_02

可以得到最后 RNN梯度计算 rnn梯度消失_自然语言处理_03 的输出为 RNN梯度计算 rnn梯度消失_自然语言处理_04

RNN梯度计算 rnn梯度消失_神经网络_05

当我们尝试去改变参数 RNN梯度计算 rnn梯度消失_rnn_06 的时候,我们会发现,只要参数有一点变化,对于长句子(长输入)来说,最后一向的输出会有很大的变化!

比如 RNN梯度计算 rnn梯度消失_rnn_06 从 1 变成 1.01,那么 最后的输出会从 1 变成约 20000,当 RNN梯度计算 rnn梯度消失_rnn_06 从 1 变成小于 1 的数字时,最后输出直接约等于 0。

梯度,我们也可以理解为 直观地看到参数 RNN梯度计算 rnn梯度消失_rnn_06 的变化会引起最后输出多大的变化。

RNN梯度计算 rnn梯度消失_自然语言处理_10


因此上图中绿色部分 RNN梯度计算 rnn梯度消失_自然语言处理_11

数学感受梯度消失和梯度爆炸

简单回忆 反向传播(BP) 的流程:

这是一个简单的神经网络,其中 RNN梯度计算 rnn梯度消失_rnn_12 代表第 2 层参数链接了前面第 2 个神经元和后面第 1 个神经元的参数;RNN梯度计算 rnn梯度消失_神经网络_13分别表示第一层第一个激活函数的输入值和输出值。此处所有激活函数为sigmoid函数。

RNN梯度计算 rnn梯度消失_自然语言处理_14

  • 误差函数 RNN梯度计算 rnn梯度消失_rnn_15
  • 参数迭代公式
    RNN梯度计算 rnn梯度消失_反向传播_16
  • RNN梯度计算 rnn梯度消失_自然语言处理_17 为例,最后一层参数更新公式为:
    RNN梯度计算 rnn梯度消失_自然语言处理_18
    推导:RNN梯度计算 rnn梯度消失_自然语言处理_19
    其中RNN梯度计算 rnn梯度消失_反向传播_20
  • RNN梯度计算 rnn梯度消失_RNN梯度计算_21 其它层的参数更新公式为:
    RNN梯度计算 rnn梯度消失_RNN梯度计算_22
    推导:RNN梯度计算 rnn梯度消失_神经网络_23
    这个推导没有出现一个神经元会有两个 RNN梯度计算 rnn梯度消失_rnn_24 共同带来的误差。如果有多个的话,这要把每个 RNN梯度计算 rnn梯度消失_rnn_24

题外话:基本上所有的均方误差损失函数(MSE) 都是以下形式表现:RNN梯度计算 rnn梯度消失_反向传播_26
但后来突然发现还有将RNN梯度计算 rnn梯度消失_RNN梯度计算_27变为RNN梯度计算 rnn梯度消失_rnn_28作为分母使用的(参考微软github
题外话结束

简单回忆 SimpleRNN 模型:

RNN梯度计算 rnn梯度消失_自然语言处理_29

RNN梯度计算 rnn梯度消失_rnn_30


好的,回忆完了,RNN会有多个输出,这里如果要使用BP的话一定是要考虑到时间的,因此RNN的BP叫做时间反向传播-BPTT(Back Propagation Through Time)

开始BPTT

参考链接

再从最简单的例子开始,下图是只有三个输入的RNN,没有任何激活函数:

RNN梯度计算 rnn梯度消失_神经网络_31

RNN梯度计算 rnn梯度消失_反向传播_32


那在 RNN梯度计算 rnn梯度消失_rnn_33 的时刻,误差函数为 RNN梯度计算 rnn梯度消失_rnn_34

一次训练下所有的误差值为单个误差之和:

RNN梯度计算 rnn梯度消失_神经网络_35

而我们的目标是去更新所有的参数 RNN梯度计算 rnn梯度消失_rnn_36,所以需要计算误差项的梯度。

RNN中误差项的梯度并更新参数:

RNN梯度计算 rnn梯度消失_神经网络_37

此处,我们对 RNN梯度计算 rnn梯度消失_rnn_38 时刻入手开始更新:

RNN梯度计算 rnn梯度消失_rnn_39

如果从 t=3 时刻开始,那么就需要每一次向后传递时,分一部分给RNN梯度计算 rnn梯度消失_RNN梯度计算_40再分一部分错误给后面。

RNN梯度计算 rnn梯度消失_RNN梯度计算_41

RNN梯度计算 rnn梯度消失_rnn_42

对上述偏导公式进行总结,得出所有时刻的梯度之和

RNN梯度计算 rnn梯度消失_rnn_43

RNN梯度计算 rnn梯度消失_神经网络_44 同上。

因为上述式子是假设没有任何激活函数,下式是任意时刻的梯度传递到时间步1时候的公式

RNN梯度计算 rnn梯度消失_RNN梯度计算_45

因此,

没有任何激活函数的情况下 RNN梯度计算 rnn梯度消失_神经网络_46RNN梯度计算 rnn梯度消失_自然语言处理_47RNN梯度计算 rnn梯度消失_神经网络_44相乘。那么RNN梯度计算 rnn梯度消失_神经网络_44的大小就会影响梯度爆炸还是消失。

有激活函数 :RNN梯度计算 rnn梯度消失_反向传播_50

求偏导则先求 RNN梯度计算 rnn梯度消失_rnn_51的导,再求 RNN梯度计算 rnn梯度消失_RNN梯度计算_52 的导: RNN梯度计算 rnn梯度消失_神经网络_53

RNN梯度计算 rnn梯度消失_自然语言处理_54

RNN梯度计算 rnn梯度消失_神经网络_55 取值在[0,1],后面还是有 RNN梯度计算 rnn梯度消失_神经网络_44 !可恶。
(补充: 激活函数 RNN梯度计算 rnn梯度消失_神经网络_57)

重点!!!
所以到这里就可以发现,只要我时间步长够长,就会有越来越多的 RNN梯度计算 rnn梯度消失_神经网络_44 相乘,如果说 RNN梯度计算 rnn梯度消失_神经网络_44 正好也在 [0,1] ,配合上 RNN梯度计算 rnn梯度消失_神经网络_55,我们的梯度就消失了(所以我们可以认为是RNN梯度计算 rnn梯度消失_神经网络_55RNN梯度计算 rnn梯度消失_神经网络_44两者一起导致梯度消失)。
但当然,因为变量终究还是 RNN梯度计算 rnn梯度消失_神经网络_44 (可能会大于1,但RNN梯度计算 rnn梯度消失_神经网络_55肯定小于等于1),如果 RNN梯度计算 rnn梯度消失_神经网络_44 大的能抵消掉 那么多个RNN梯度计算 rnn梯度消失_神经网络_55相乘 的时候,就会造成梯度爆炸

如果梯度消失,那么
RNN梯度计算 rnn梯度消失_神经网络_67
这也就导致我们的参数在合理的时间内就没怎么更新了。
RNN梯度计算 rnn梯度消失_RNN梯度计算_68

干掉它:)

(参考链接)

现在知道了梯度消失和爆炸的问题就在于RNN梯度计算 rnn梯度消失_反向传播_69
中的 RNN梯度计算 rnn梯度消失_神经网络_46,最直观的想法就是让它乘来乘去一直约为 1 或者 一直约为 0 ,这样就不会对整体的梯度有很大的影响。

LSTM 可以解决。(Clockwise RNN 和 SCRN 也可以,但这里不讲了)

简单回忆LSTM

(一个非常详细的LSTM介绍)

下面是三个时间步长的LSTM,RNN梯度计算 rnn梯度消失_RNN梯度计算_71时的输入是RNN梯度计算 rnn梯度消失_rnn_72 当前输入与上一个的输出相结合。三个橙色的 RNN梯度计算 rnn梯度消失_神经网络_73 函数是LSTM的三个gate,从左至右分别为遗忘门、输入门以及输出门。

RNN梯度计算 rnn梯度消失_神经网络_74

遗忘门 根据新的输入,经过激活函数之后,得到一个 0 - 1 的值,这个值决定了过去的记忆 RNN梯度计算 rnn梯度消失_神经网络_75

RNN梯度计算 rnn梯度消失_神经网络_76

RNN梯度计算 rnn梯度消失_神经网络_77

输入门 控制了新输入有多少要被加入到记忆中,但这里的输出还需要配合上神经网络觉得新输入中有用的部分。

RNN梯度计算 rnn梯度消失_反向传播_78

这里会拆成以下两个矩阵进行对应相乘:

RNN梯度计算 rnn梯度消失_RNN梯度计算_79

RNN梯度计算 rnn梯度消失_rnn_80

输出门 控制了有多少信息会被编辑到记忆细胞,作为下一个时间步的输入。

RNN梯度计算 rnn梯度消失_自然语言处理_81

RNN梯度计算 rnn梯度消失_rnn_82


记忆细胞 的内容就根据上述结果进行更新:

RNN梯度计算 rnn梯度消失_RNN梯度计算_83

题外话:看到这里你会发现为什么LSTM会用到两个 RNN梯度计算 rnn梯度消失_RNN梯度计算_84 函数?
找了半天就这个链接点赞第二多的比较有道理(点赞最多的那个我实在不懂tanh怎么就有他说的多阶导能很长时间不为0的表现)。
大概的意思就是RNN梯度计算 rnn梯度消失_RNN梯度计算_84的值在 [-1,1] 之间,但 sigmoid 在 [0,1] 之间,所以使用 RNN梯度计算 rnn梯度消失_RNN梯度计算_84 可以生成负值。我了解到在神经网络中,像 RNN梯度计算 rnn梯度消失_RNN梯度计算_84 这样的0中心激活函数可以加快收敛速度(关于0中心这篇不错),那可能也是这个原因我们在这里使用 RNN梯度计算 rnn梯度消失_RNN梯度计算_84,应该是其他激活函数也可用(但这里不是百分百保证,最好对于激活函数的区别再了解一下)。
题外话结束

LSTM中的BPTT

总结一下 LSTM 里发生的公式们(RNN梯度计算 rnn梯度消失_自然语言处理_89 表示矩阵相乘,RNN梯度计算 rnn梯度消失_rnn_90表示矩阵元素对应相乘):
RNN梯度计算 rnn梯度消失_反向传播_91

现在我们放一个时间步长为3的 LSTM :

RNN梯度计算 rnn梯度消失_反向传播_92


列出涉及到的式子(嵌套的就不打开了,太多太乱了):

RNN梯度计算 rnn梯度消失_神经网络_93

现在要去更新参数,共计四个RNN梯度计算 rnn梯度消失_自然语言处理_94

RNN梯度计算 rnn梯度消失_神经网络_95

求偏导过程中,主要看函数中有RNN梯度计算 rnn梯度消失_rnn_96的是哪部分,很快我们发现RNN梯度计算 rnn梯度消失_RNN梯度计算_97函数中直观的包含了RNN梯度计算 rnn梯度消失_反向传播_98RNN梯度计算 rnn梯度消失_rnn_96,因此RNN梯度计算 rnn梯度消失_反向传播_98RNN梯度计算 rnn梯度消失_rnn_96的偏导公式都相同。

再看参数 RNN梯度计算 rnn梯度消失_神经网络_102,它与RNN梯度计算 rnn梯度消失_自然语言处理_103是直接关系。

RNN梯度计算 rnn梯度消失_自然语言处理_104

因此我们也可以总结出和RNN类似的公式,即任意时刻误差项公式

RNN梯度计算 rnn梯度消失_反向传播_105

RNN梯度计算 rnn梯度消失_神经网络_106 同上。

接着看到 RNN梯度计算 rnn梯度消失_神经网络_107,举个t=2的例子:

RNN梯度计算 rnn梯度消失_RNN梯度计算_108

注意!我们之所以不直接把RNN梯度计算 rnn梯度消失_反向传播_109提出来,是因为 RNN梯度计算 rnn梯度消失_rnn_110 中是包含了 RNN梯度计算 rnn梯度消失_反向传播_109 的!!RNN梯度计算 rnn梯度消失_rnn_110包含了RNN梯度计算 rnn梯度消失_自然语言处理_113,而RNN梯度计算 rnn梯度消失_自然语言处理_113包含了RNN梯度计算 rnn梯度消失_反向传播_109
作为简单的记忆,我们就把RNN梯度计算 rnn梯度消失_反向传播_116拆成了四项,除了RNN梯度计算 rnn梯度消失_rnn_117那一项,其他都是一个套路。

配上一个可视的BPTT

RNN梯度计算 rnn梯度消失_反向传播_118


总结一下偏导公式:

RNN梯度计算 rnn梯度消失_rnn_119

如果把这四项分别用 RNN梯度计算 rnn梯度消失_神经网络_120替代的话,公式就可以变成:

RNN梯度计算 rnn梯度消失_反向传播_121

把这个简洁的公式带入之前的误差公式:
RNN梯度计算 rnn梯度消失_RNN梯度计算_122

缓解梯度消失/爆炸问题

有连乘,那就说明有可能造成梯度消失和爆炸。上文也讲了RNN梯度计算 rnn梯度消失_神经网络_123里面有什么,总共四项,如果看的云里雾里也没事,因为那个 RNN梯度计算 rnn梯度消失_反向传播_124 你一定看的懂!因为 RNN梯度计算 rnn梯度消失_反向传播_124 只有一个内容 RNN梯度计算 rnn梯度消失_rnn_117,我们可以轻松地直观地通过他调整 RNN梯度计算 rnn梯度消失_rnn_117

接下来我们看 RNN梯度计算 rnn梯度消失_rnn_117 到底怎么能帮助我们。现在假设对某一个时间步 RNN梯度计算 rnn梯度消失_rnn_129,我们有:
RNN梯度计算 rnn梯度消失_rnn_130
然后为了梯度不消失,我们可以再时间步 RNN梯度计算 rnn梯度消失_自然语言处理_131 找到一个合适的 RNN梯度计算 rnn梯度消失_rnn_96 使得:
RNN梯度计算 rnn梯度消失_自然语言处理_133
由于遗忘门的激活函数和梯度项中大家都是相加的(A,B,C,D,加性结构),所以使得 LSTM 在任何时间步都能找到这样的 RNN梯度计算 rnn梯度消失_RNN梯度计算_134
RNN梯度计算 rnn梯度消失_自然语言处理_135
这样梯度就不会消失了。

另一个重要的性质: 正如上文说到的加性结构,四个项可以相互平衡从而保证在反向传播的时候梯度值不会消失。

举个例子:假设 时间步RNN梯度计算 rnn梯度消失_rnn_136,我们对梯度值中的四项设置一个相互平衡的值(从而保证梯度不消失):
RNN梯度计算 rnn梯度消失_自然语言处理_137
带入连乘公式:
RNN梯度计算 rnn梯度消失_rnn_138
这时候就算是连乘,梯度也不会消失了。

所以,在 LSTM 中,遗忘门的存在,以及细胞状态梯度的加性特性,使网络能够以这样一种方式更新参数,即不同子梯度之间的平衡从而避免梯度消失
但看到这,也就清楚了,因为我们都是正数相加,所以不能够避免梯度爆炸,当 RNN梯度计算 rnn梯度消失_rnn_139的数值很大的时候,RNN梯度计算 rnn梯度消失_rnn_117