大家好,我是半虹,这篇文章来讲长短期记忆网络 (Long Short-Term Memory, LSTM)

文章行文思路如下:

  1. 首先通过循环神经网络引出为啥需要长短期记忆网络
  2. 然后介绍长短期记忆网络的核心思想与运作方式
  3. 最后通过简短的代码深入理解长短期记忆网络的运作方式

长短期记忆网络可以看作是循环神经网络的改进版本,想要理解长短期记忆网络,首先要了解循环神经网络

由于我们之前已详细介绍过循环神经网络,所以这里我们只会做一个简单的回顾



对比前馈神经网络,循环神经网络通过增加隐状态实现对隐藏层信息的传递,以此达到记住历史输入的目的

网络在每个时间步里读取上一隐藏层输出作为当前隐藏层输入,并保存当前隐藏层输出作为下一隐藏层输入

其结构简图如下:


lstm神经网络光伏功率预测 lstm神经网络图_nlp

其中 lstm神经网络光伏功率预测 lstm神经网络图_lstm神经网络光伏功率预测_02 是输入 ,lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_03

下面我们把隐藏层中的细节也画出来,方便后面与长短期记忆网络来对比


lstm神经网络光伏功率预测 lstm神经网络图_lstm神经网络光伏功率预测_04

其中 lstm神经网络光伏功率预测 lstm神经网络图_lstm神经网络光伏功率预测_02 是输入 ,lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_03 是隐藏层的输出,图中的灰色矩形同样代表隐藏层,lstm神经网络光伏功率预测 lstm神经网络图_长短期记忆网络_07

对应的公式表达如下:
lstm神经网络光伏功率预测 lstm神经网络图_长短期记忆网络_08
其中 lstm神经网络光伏功率预测 lstm神经网络图_长短期记忆网络_09 是当前输入,lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_10 是当前隐藏层输出,lstm神经网络光伏功率预测 lstm神经网络图_lstm_11 是先前隐藏层输出,lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_12lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_13lstm神经网络光伏功率预测 lstm神经网络图_lstm神经网络光伏功率预测_14



理论上,上述介绍的循环神经网络能处理任意长的序列,但实际上却并非如此

在实际应用循环神经网络处理长序列时通常会出现梯度爆炸或梯度消失的情况,导致网络难以捕捉长期依赖

这是为什么呢?通过简单分析一下梯度计算公式就能发现端倪

为了阐述方便,我们暂且假定所有的参数都是一维的,用字母 lstm神经网络光伏功率预测 lstm神经网络图_lstm神经网络光伏功率预测_15 表示,对参数求导并按时间展开后如下所示
lstm神经网络光伏功率预测 lstm神经网络图_长短期记忆网络_16
不难发现,当前梯度 lstm神经网络光伏功率预测 lstm神经网络图_长短期记忆网络_17 由当前梯度值 lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_18 以及先前梯度 lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_19 决定,对于先前梯度权重 lstm神经网络光伏功率预测 lstm神经网络图_长短期记忆网络_20

  • lstm神经网络光伏功率预测 lstm神经网络图_lstm神经网络光伏功率预测_21
  • lstm神经网络光伏功率预测 lstm神经网络图_长短期记忆网络_22


由推导式可以看出,梯度爆炸和梯度消失更容易出现在与当前时间步距离更远的梯度

这是因为这些梯度的权重连乘项更多,举例来说,对于时间步 lstm神经网络光伏功率预测 lstm神经网络图_lstm_23,其梯度 lstm神经网络光伏功率预测 lstm神经网络图_长短期记忆网络_17

  • 时间步 lstm神经网络光伏功率预测 lstm神经网络图_nlp_25 的梯度 lstm神经网络光伏功率预测 lstm神经网络图_nlp_26,与时间步 lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_27 的距离为 lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_28,其权重为 lstm神经网络光伏功率预测 lstm神经网络图_nlp_29
  • 时间步 lstm神经网络光伏功率预测 lstm神经网络图_nlp_30 的梯度 lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_31,与时间步 lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_27 的距离为 lstm神经网络光伏功率预测 lstm神经网络图_lstm神经网络光伏功率预测_33,其权重为 lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_34
  • 时间步 lstm神经网络光伏功率预测 lstm神经网络图_lstm神经网络光伏功率预测_35 的梯度 lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_31,与时间步 lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_27 的距离为 lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_38,其权重为 lstm神经网络光伏功率预测 lstm神经网络图_lstm_39
  • ……

这说明了什么?这说明了对于当前输入,距其更远的输入的梯度更容易出现梯度爆炸或梯度消失

从而导致长距离的梯度反馈失效,这就是循环神经网络难以捕捉长期依赖的实际含义



最后提醒大家注意一个细节,对于时间步 lstm神经网络光伏功率预测 lstm神经网络图_lstm_23 的梯度 lstm神经网络光伏功率预测 lstm神经网络图_长短期记忆网络_17

  • 假设有且仅有最后一项梯度爆炸,那么就会导致整个梯度爆炸,因为 lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_42
  • 假设有且仅有最后一项梯度消失,这并不会导致整个梯度消失,因为 lstm神经网络光伏功率预测 lstm神经网络图_nlp_43


总结一下,梯度反向传播时发生的异常,主要可以分为两种,一是梯度爆炸,二是梯度消失

梯度爆炸比较容易处理,一个简单但有效的做法是设置一个梯度阈值,当梯度超过这个阈值时直接截断

梯度消失更难处理一些,而现在流行的做法正是将循环神经网络替换成长短期记忆网络

注意,长短期记忆网络能缓解梯度消失的问题,但并不能缓解梯度爆炸的问题



上面我们从反向传播的角度解释了什么是梯度消失

如果我们从前向计算的角度来看,则梯度消失可以理解成隐状态对短期记忆敏感,对长期记忆作用有限

为了维持长期记忆,长短期记忆网络引入记忆元存放长期记忆,并通过门机制控制记忆元中的信息流动

从直觉上来说,先前重要的记忆会保留在记忆元,不重要的记忆会被过滤,以此来达到长期记忆的目的



这里有两个概念需要解释,一是记忆元,二是门机制,这两个就是长短期记忆网络的核心

先说记忆元,可以理解成另一种隐状态,都是用来记录附加信息的,简称为单元,英文为 lstm神经网络光伏功率预测 lstm神经网络图_lstm_44

再说门机制,这是用来控制记忆元中信息流动的机制,具体来说包括三个控制门:

  • 输入门:控制是否将信息写入记忆元,英文为 lstm神经网络光伏功率预测 lstm神经网络图_lstm神经网络光伏功率预测_45
  • 遗忘门:控制是否从记忆元丢弃信息,英文为 lstm神经网络光伏功率预测 lstm神经网络图_长短期记忆网络_46
  • 输出门:控制是否从记忆元读出信息,英文为 lstm神经网络光伏功率预测 lstm神经网络图_lstm_47

本质上来说,上述三个控制门都是由一个线性层加一个激活函数组成的,这里激活函数用的是 lstm神经网络光伏功率预测 lstm神经网络图_lstm神经网络光伏功率预测_48

因为这样能将输出限制在零到一之间,以表示门的打开程度,控制信息流动的程度



相比循环神经网络只有一个传输状态,即隐状态,长短期记忆网络有两个传输状态,即隐状态和记忆元

二者的输入输出对比图如下:


lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_49

其中 lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_03 表示隐状态,lstm神经网络光伏功率预测 lstm神经网络图_lstm_51

首先,根据当前输入 lstm神经网络光伏功率预测 lstm神经网络图_长短期记忆网络_09 和先前隐状态 lstm神经网络光伏功率预测 lstm神经网络图_lstm_11,计算得到输入门 lstm神经网络光伏功率预测 lstm神经网络图_nlp_54、遗忘门 lstm神经网络光伏功率预测 lstm神经网络图_nlp_55、输出门 lstm神经网络光伏功率预测 lstm神经网络图_lstm神经网络光伏功率预测_56

其中,lstm神经网络光伏功率预测 lstm神经网络图_长短期记忆网络_57lstm神经网络光伏功率预测 lstm神经网络图_nlp_58lstm神经网络光伏功率预测 lstm神经网络图_nlp_59lstm神经网络光伏功率预测 lstm神经网络图_长短期记忆网络_60lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_61lstm神经网络光伏功率预测 lstm神经网络图_lstm神经网络光伏功率预测_62lstm神经网络光伏功率预测 lstm神经网络图_长短期记忆网络_63lstm神经网络光伏功率预测 lstm神经网络图_lstm神经网络光伏功率预测_64lstm神经网络光伏功率预测 lstm神经网络图_nlp_65 都是网络参数,lstm神经网络光伏功率预测 lstm神经网络图_长短期记忆网络_07lstm神经网络光伏功率预测 lstm神经网络图_lstm神经网络光伏功率预测_48 激活函数
lstm神经网络光伏功率预测 lstm神经网络图_nlp_68

然后,根据当前输入 lstm神经网络光伏功率预测 lstm神经网络图_长短期记忆网络_09 和先前隐状态 lstm神经网络光伏功率预测 lstm神经网络图_lstm_11,计算得到候选记忆元 lstm神经网络光伏功率预测 lstm神经网络图_lstm神经网络光伏功率预测_71

其中,lstm神经网络光伏功率预测 lstm神经网络图_长短期记忆网络_72lstm神经网络光伏功率预测 lstm神经网络图_lstm_73lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_74 都是网络参数,lstm神经网络光伏功率预测 lstm神经网络图_lstm神经网络光伏功率预测_75lstm神经网络光伏功率预测 lstm神经网络图_lstm神经网络光伏功率预测_75 激活函数
lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_77
接着,输入门 lstm神经网络光伏功率预测 lstm神经网络图_nlp_54 控制采用多少来自 lstm神经网络光伏功率预测 lstm神经网络图_lstm神经网络光伏功率预测_71 的新信息,遗忘门 lstm神经网络光伏功率预测 lstm神经网络图_nlp_55 控制保留多少来自 lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_81 的旧信息,计算得 lstm神经网络光伏功率预测 lstm神经网络图_nlp_82

其中,lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_83 表示按元素乘法,当 lstm神经网络光伏功率预测 lstm神经网络图_nlp_84lstm神经网络光伏功率预测 lstm神经网络图_lstm_85 时,则过去记忆元被保留并传递到当前时间步
lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_86
最后,输出门 lstm神经网络光伏功率预测 lstm神经网络图_lstm神经网络光伏功率预测_56 控制采用多少来自 lstm神经网络光伏功率预测 lstm神经网络图_lstm神经网络光伏功率预测_88 的长记忆,计算得 lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_10

其中,lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_83 表示按元素乘法,lstm神经网络光伏功率预测 lstm神经网络图_lstm神经网络光伏功率预测_75 表示 lstm神经网络光伏功率预测 lstm神经网络图_lstm神经网络光伏功率预测_75 激活函数,当 lstm神经网络光伏功率预测 lstm神经网络图_自然语言处理_93 接近 lstm神经网络光伏功率预测 lstm神经网络图_长短期记忆网络_94 时,就可以将长期记忆传递给隐状态
lstm神经网络光伏功率预测 lstm神经网络图_nlp_95
上述计算过程对应的计算图如下所示:


lstm神经网络光伏功率预测 lstm神经网络图_长短期记忆网络_96


为了帮助大家进一步理解长短期记忆网络的工作方式,下面我们举一个例子来说,并给出关键代码

假设我们用长短期记忆网络对下面这个句子进行编码:我在画画

import torch
import torch.nn as nn

# 定义输入数据
# 对于输入句子我在画画,首先用独热编码得到其向量表示

x1 = torch.tensor([1, 0, 0]).float() # 我
x2 = torch.tensor([0, 1, 0]).float() # 在
x3 = torch.tensor([0, 0, 1]).float() # 画
x4 = torch.tensor([0, 0, 1]).float() # 画

h0 = torch.zeros(5) # 初始化隐状态
c0 = torch.zeros(5) # 初始化记忆元

# 定义模型参数
# 模型的输入是三维向量,这里定义模型的输出是五维向量

W_xi = nn.Parameter(torch.randn(3, 5), requires_grad = True)
W_hi = nn.Parameter(torch.randn(5, 5), requires_grad = True)
b_i  = nn.Parameter(torch.randn(5)   , requires_grad = True)

W_xf = nn.Parameter(torch.randn(3, 5), requires_grad = True)
W_hf = nn.Parameter(torch.randn(5, 5), requires_grad = True)
b_f  = nn.Parameter(torch.randn(5)   , requires_grad = True)

W_xo = nn.Parameter(torch.randn(3, 5), requires_grad = True)
W_ho = nn.Parameter(torch.randn(5, 5), requires_grad = True)
b_o  = nn.Parameter(torch.randn(5)   , requires_grad = True)

W_xc = nn.Parameter(torch.randn(3, 5), requires_grad = True)
W_hc = nn.Parameter(torch.randn(5, 5), requires_grad = True)
b_c  = nn.Parameter(torch.randn(5)   , requires_grad = True)

# 前向传播

def forward(X, H, C):
    # 计算各种门机制
    I = torch.sigmoid(torch.matmul(X, W_xi) + torch.matmul(H, W_hi) + b_i) # 输入门
    F = torch.sigmoid(torch.matmul(X, W_xf) + torch.matmul(H, W_hf) + b_f) # 遗忘门
    O = torch.sigmoid(torch.matmul(X, W_xo) + torch.matmul(H, W_ho) + b_o) # 输出门
    # 计算候选记忆元
    C_tilde = torch.tanh(torch.matmul(X, W_xc) + torch.matmul(H, W_hc) + b_c)
    # 计算当前记忆元
    C = F * C + I * C_tilde
    # 计算当前隐状态
    H = O * C.tanh()
    # 返回结果
    return H, C

h1, c1 = forward(x1, h0, c0)
h2, c2 = forward(x2, h1, c1)
h3, c3 = forward(x3, h2, c2)
h4, c4 = forward(x4, h3, c3)

# 结果输出

print(h3) # tensor([-0.0408,  0.1785,  0.0455,  0.3802,  0.0235])
print(h4) # tensor([-0.0560,  0.1269,  0.0346,  0.3426,  0.0118])



最后提醒大家一点,如果长短期记忆网络后有接其他网络,例如后面接一个线性层做单词预测

那么通常不会用记忆元的输出,而是用隐藏层的输出



至此本文结束,要点总结如下:

  1. 循环神经网络在处理长序列时很容易会出现梯度爆炸和梯度消失的情况,导致网络难以捕捉长期依赖
    对于梯度爆炸,通常可以采用梯度裁剪解决,对于梯度消失,可以采用长短期记忆网络缓解
  2. 除了有隐状态,长短期记忆网络还增加记忆元存放长期记忆,并通过门机制控制记忆元中的信息流动