引言
LSTM是RNN的变种,是为了解决RNN存在的长期依赖问题而专门设计出来的。所谓长期依赖问题是,后面的单词在很长的时间序列后还依赖前面的单词,但由于梯度消失问题,导致前面的单词无法影响到后面的单词。
LSTM单元
LSTM单元(cell)在每个时间点更新单元状态,它决定了的值。LSTM有更新门、遗忘门和输出门来控制这些值。
下面来对LSTM中的元素做一些说明
遗忘门
遗忘门用来控制内存中之前的状态是否会被遗忘掉。
- 如果遗忘门的值是0,LSTM会遗忘(忽略)之前的状态
- 如果遗忘门的值是1,LSTM会记得(保持)之前的状态
- 如果是0到1之间的值,代表LSTM会记得之前的状态多大程度
公式为:
- 和是可学习的权重和偏差
- 通过sigmoid函数来保证输出的值在[0,1]之间
- 遗忘门与之前单元状态同维度,即它们能逐元素相乘
在代码中Wf
代表,bf
代表,ft
代表
候选值 c ~ ⟨ t ⟩ \tilde{\mathbf{c}}^{\langle t \rangle} c~⟨t⟩
- 候选值保存的是当前时间点可能会存入当前单元状态()的信息
- 候选值能多大程度的存入当前单元状态取决于更新门
公式为:
这里用的是tanh函数,所以取值范围为[-1,1]
cct
代表
Wc
代表
更新门(输入门)
- 更新门决定候选值(哪些维度)能多大程度的存入当前单元状态
- 如果更新门的值是0,意味着防止候选值存入单元状态
- 如果更新门的值是1,意味着完全允许候选值存入单元状态
有些文献称它为输入门,并且用"i"来表示,这里沿用这种约定
公式:
Wi
代表,bi
代表,it
代表更新门。
单元状态 c ⟨ t ⟩ c^{\langle t \rangle} c⟨t⟩
- 单元状态是时间序列间传递的"记忆"
- 新的单元状态由之前的状态和当前候选值组成
公式:
- 结合上面所有的公式,得到了单元状态的计算公式
- 前一单元状态由遗忘门控制会有多少被保存到当前单元状态中
- 候选值由更新门控制能有多少被保存到当前单元状态中
c
:所有时间点的单元状态,形状是
c_next
:当前单元状态,形状
c_prev
: 前一个单元状态,形状
输出门 Γ o \mathbf{\Gamma}_{o} Γo
- 输出门控制了当前时间点能输出什么
- 和之前所有门一样,取值范围[0,1]
公式:
W_o
代表输出门的权重,bo
代表输出门的偏差,ot
代表输出门
从三个门的公式可以看出,它们的激活函数都是sigmoid,取值都是[0,1],输入都是和,唯一的区别是可学习的权重和偏差不一样。如果取值为0,表示这个门是关闭的;取值为1,表示这个门是完全打开的;取值表示这个门是半关半开的,只允许一部分的值进入(被保存,被传递)。
隐藏状态
- 当前的隐藏状态会传递到下一个时间点的LSTM单元
- 它用于决定下个时间点的三个门
- 同时也用于当前时间点的预测(输出值)
公式:
- 隐藏状态由单元状态和输出门决定
- 单元状态传递到tanh函数得到的取值
a
: 所有的隐藏状态,形状
a_prev
: 上个时间点的隐藏状态,形状
a_next
: 当前时间点的隐藏状态,形状
预测值 y ^ ⟨ t ⟩ \hat y^{\langle t \rangle} y^⟨t⟩
- 在分类问题中的输出值使用softmax函数
y_pred
: 所有时间点的预测值,形状
yt_pred
: 当前时间点的预测值,形状
至此我们知道了LSTM单元中的所有计算公式,下面来看如何实现前向传播和反向传播。
前向传播
实现如上图所示的前向传播过程,我们需要代码化上面的公式~。
要注意的是,我们会叠加前一个隐藏状态和当前的输入到一个矩阵concat
:
反向传播
LSTM的反向传播比RNN的要复杂一点。不过遵循规则——求某个节点的梯度时,考虑该节点的所有输出节点。分别计算每个输出节点的梯度乘上输出节点对该节点的梯度,然后加起来就得到该节点的梯度,也不难。
首先列出激活函数的导数:
假设考虑的LSTM结构为多对多的,且,每个时刻都有一个输出及一个损失,全局损失函数为:
我们求对的导数,具体过程可以参考博客 Softmax与Cross-entropy的求导,得到:
根据公式,可以很容易的求出:
而求和时要分两种情况考虑:
在时刻时,
在时刻时,的后续同时有(大于时刻的误差)和(时刻的误差)两个节点。因此计算梯度时要考虑这两部分:
在这一步反向传播计算的难点在于。
因为受到上图这四部分所影响,而这四部分都和有关。所以的计算结果也由四部分组成(公式):
上面有一个公共项
在时刻时,的梯度也是由当前时刻的误差以及时刻的误差组成(由公式)得:
现在求对的梯度就简单了。
参考
- LSTM模型与前向反向传播算法
- 从零实现循环神经网路
- 吴恩达课程