引言

LSTM是RNN的变种,是为了解决RNN存在的长期依赖问题而专门设计出来的。所谓长期依赖问题是,后面的单词在很长的时间序列后还依赖前面的单词,但由于梯度消失问题,导致前面的单词无法影响到后面的单词。

LSTM单元

LSTM背后的数学原理_长短期记忆网络

LSTM单元(cell)在每个时间点更新单元状态LSTM背后的数学原理_长短期记忆网络_02,它决定了LSTM背后的数学原理_权重_03的值。LSTM有更新门、遗忘门和输出门来控制这些值。

下面来对LSTM中的元素做一些说明

遗忘门

遗忘门用来控制内存中之前的状态是否会被遗忘掉。

  • 如果遗忘门的值是0,LSTM会遗忘(忽略)之前的状态
  • 如果遗忘门的值是1,LSTM会记得(保持)之前的状态
  • 如果是0到1之间的值,代表LSTM会记得之前的状态多大程度

公式为:

LSTM背后的数学原理_LSTM_04

  • LSTM背后的数学原理_反向传播_05LSTM背后的数学原理_LSTM_06是可学习的权重和偏差
  • 通过sigmoid函数来保证输出的值在[0,1]之间
  • 遗忘门LSTM背后的数学原理_反向传播_07与之前单元状态LSTM背后的数学原理_LSTM_08同维度,即它们能逐元素相乘

在代码中​​Wf​​​代表LSTM背后的数学原理_取值_09,​​​bf​​​代表LSTM背后的数学原理_取值_10,​​​ft​​​代表LSTM背后的数学原理_LSTM_11

候选值 c ~ ⟨ t ⟩ \tilde{\mathbf{c}}^{\langle t \rangle} c~⟨t⟩

  • 候选值保存的是当前时间点可能会存入当前单元状态(LSTM背后的数学原理_LSTM_08)的信息
  • 候选值能多大程度的存入当前单元状态取决于更新门

公式为:

LSTM背后的数学原理_长短期记忆网络_13

这里用的是tanh函数,所以取值范围为[-1,1]

​cct​​​代表LSTM背后的数学原理_反向传播_14
​​​Wc​​​代表LSTM背后的数学原理_长短期记忆网络_15

更新门(输入门)

  • 更新门决定候选值(哪些维度)能多大程度的存入当前单元状态
  • 如果更新门的值是0,意味着防止候选值存入单元状态
  • 如果更新门的值是1,意味着完全允许候选值存入单元状态

有些文献称它为输入门,并且用"i"来表示,这里沿用这种约定

公式:

LSTM背后的数学原理_长短期记忆网络_16

​Wi​​​代表LSTM背后的数学原理_权重_17,​​​bi​​​代表LSTM背后的数学原理_取值_18,​​​it​​​代表更新门LSTM背后的数学原理_LSTM_19

单元状态 c ⟨ t ⟩ c^{\langle t \rangle} c⟨t⟩

  • 单元状态是时间序列间传递的"记忆"
  • 新的单元状态由之前的状态和当前候选值组成

公式:

LSTM背后的数学原理_权重_20

  • 结合上面所有的公式,得到了单元状态的计算公式
  • 前一单元状态由遗忘门控制会有多少被保存到当前单元状态中
  • 候选值由更新门控制能有多少被保存到当前单元状态中

​c​​​:所有时间点的单元状态LSTM背后的数学原理_反向传播_21,形状是LSTM背后的数学原理_长短期记忆网络_22

​c_next​​​:当前单元状态LSTM背后的数学原理_长短期记忆网络_02,形状LSTM背后的数学原理_长短期记忆网络_24

​c_prev​​​: 前一个单元状态LSTM背后的数学原理_长短期记忆网络_25,形状LSTM背后的数学原理_长短期记忆网络_24

输出门 Γ o \mathbf{\Gamma}_{o} Γo​

  • 输出门控制了当前时间点能输出什么
  • 和之前所有门一样,取值范围[0,1]

公式:
LSTM背后的数学原理_长短期记忆网络_27

​W_o​​​代表输出门的权重LSTM背后的数学原理_LSTM_28,​​​bo​​​代表输出门的偏差LSTM背后的数学原理_取值_29,​​​ot​​​代表输出门LSTM背后的数学原理_权重_30

从三个门的公式可以看出,它们的激活函数都是sigmoid,取值都是[0,1],输入都是LSTM背后的数学原理_权重_31LSTM背后的数学原理_LSTM_32,唯一的区别是可学习的权重和偏差不一样。如果取值为0,表示这个门是关闭的;取值为1,表示这个门是完全打开的;取值LSTM背后的数学原理_LSTM_33表示这个门是半关半开的,只允许一部分的值进入(被保存,被传递)。

隐藏状态

  • 当前的隐藏状态会传递到下一个时间点的LSTM单元
  • 它用于决定下个时间点的三个门
  • 同时也用于当前时间点的预测(输出值LSTM背后的数学原理_反向传播_34)

公式:
LSTM背后的数学原理_权重_35

  • 隐藏状态由单元状态和输出门决定
  • 单元状态传递到tanh函数得到LSTM背后的数学原理_权重_36的取值

​a​​​: 所有的隐藏状态LSTM背后的数学原理_反向传播_37,形状LSTM背后的数学原理_反向传播_38

​a_prev​​​: 上个时间点的隐藏状态LSTM背后的数学原理_长短期记忆网络_39,形状LSTM背后的数学原理_长短期记忆网络_24

​a_next​​​: 当前时间点的隐藏状态LSTM背后的数学原理_权重_03,形状LSTM背后的数学原理_长短期记忆网络_24

预测值 y ^ ⟨ t ⟩ \hat y^{\langle t \rangle} y^​⟨t⟩

  • 在分类问题中的输出值使用softmax函数
    LSTM背后的数学原理_权重_43

LSTM背后的数学原理_长短期记忆网络_44

​y_pred​​​: 所有时间点的预测值LSTM背后的数学原理_取值_45,形状LSTM背后的数学原理_反向传播_46
​​​yt_pred​​​: 当前时间点的预测值LSTM背后的数学原理_取值_47,形状LSTM背后的数学原理_权重_48

至此我们知道了LSTM单元中的所有计算公式,下面来看如何实现前向传播和反向传播。

前向传播

LSTM背后的数学原理_长短期记忆网络_49

实现如上图所示的前向传播过程,我们需要代码化上面的公式LSTM背后的数学原理_权重_50~LSTM背后的数学原理_取值_51

要注意的是,我们会叠加前一个隐藏状态LSTM背后的数学原理_长短期记忆网络_39和当前的输入LSTM背后的数学原理_LSTM_53到一个矩阵​​​concat​​:

LSTM背后的数学原理_权重_54

反向传播

LSTM的反向传播比RNN的要复杂一点。不过遵循规则——求某个节点的梯度时,考虑该节点的所有输出节点。分别计算每个输出节点的梯度乘上输出节点对该节点的梯度,然后加起来就得到该节点的梯度,也不难。

LSTM背后的数学原理_LSTM_55

首先列出激活函数的导数:
LSTM背后的数学原理_长短期记忆网络_56
LSTM背后的数学原理_长短期记忆网络_57

假设考虑的LSTM结构为多对多的,且LSTM背后的数学原理_反向传播_58,每个时刻LSTM背后的数学原理_长短期记忆网络_59都有一个输出及一个损失LSTM背后的数学原理_取值_60,全局损失函数为:
LSTM背后的数学原理_取值_61

我们求LSTM背后的数学原理_长短期记忆网络_62LSTM背后的数学原理_权重_63的导数LSTM背后的数学原理_LSTM_64,具体过程可以参考博客 ​​​Softmax与Cross-entropy的求导​​,得到:

LSTM背后的数学原理_反向传播_65

根据公式LSTM背后的数学原理_取值_51,可以很容易的求出:

LSTM背后的数学原理_反向传播_67
LSTM背后的数学原理_长短期记忆网络_68

而求LSTM背后的数学原理_反向传播_69LSTM背后的数学原理_反向传播_70时要分两种情况考虑:

在时刻LSTM背后的数学原理_取值_71时,
LSTM背后的数学原理_LSTM_72

LSTM背后的数学原理_权重_73

在时刻LSTM背后的数学原理_取值_74时,LSTM背后的数学原理_权重_03的后续同时有LSTM背后的数学原理_LSTM_76(大于LSTM背后的数学原理_长短期记忆网络_59时刻的误差)和LSTM背后的数学原理_LSTM_78(LSTM背后的数学原理_长短期记忆网络_59时刻的误差)两个节点。因此计算梯度时要考虑这两部分:

LSTM背后的数学原理_LSTM_80

在这一步反向传播计算的难点在于LSTM背后的数学原理_LSTM_81

LSTM背后的数学原理_LSTM_82

因为LSTM背后的数学原理_权重_03受到上图这四部分所影响,而这四部分都和LSTM背后的数学原理_长短期记忆网络_39有关。所以LSTM背后的数学原理_LSTM_81的计算结果也由四部分组成(公式LSTM背后的数学原理_取值_86):

LSTM背后的数学原理_权重_87

上面有一个公共项LSTM背后的数学原理_长短期记忆网络_88

在时刻LSTM背后的数学原理_取值_74时,LSTM背后的数学原理_长短期记忆网络_02的梯度也是由当前时刻的误差以及LSTM背后的数学原理_LSTM_91时刻的误差组成(由公式LSTM背后的数学原理_取值_92)得:
LSTM背后的数学原理_反向传播_93


现在求对LSTM背后的数学原理_权重_94的梯度就简单了。
LSTM背后的数学原理_权重_95
LSTM背后的数学原理_长短期记忆网络_96

LSTM背后的数学原理_取值_97

LSTM背后的数学原理_LSTM_98

LSTM背后的数学原理_长短期记忆网络_99
LSTM背后的数学原理_长短期记忆网络_100

LSTM背后的数学原理_取值_101

LSTM背后的数学原理_权重_102

参考

  1. ​LSTM模型与前向反向传播算法​
  2. ​从零实现循环神经网路​
  3. 吴恩达课程