1 简介
长短期记忆网络(Long Short-Term Memory)作为RNN的进阶架构,在序列建模领域具有里程碑意义。其核心突破在于通过智能门控系统,有效捕获跨越数百个时间步的语义关联,成功缓解了传统RNN存在的梯度消失/爆炸难题,在语音识别、金融预测等需要长程记忆的场景中表现卓越。
结构更复杂,核心结构可分四部分:
2 LSTM内部结构图

结构解释图:

2.1 遗忘门:智能记忆过滤器
结构图和计算公式

结构分析
类似传统RNN内部结构计算:
- 先将当前时间步输入x(t)与上一个时间步隐含状态h(t-1)拼接,得到[x(t), h(t-1)]
- 再通过一个全连接层做变换,最后通过sigmoid函数进行激活得到f(t),可将f(t)看作门值,好比一扇门开合的大小程度,门值都将作用在通过该扇门的张量。遗忘门门值将作用的上一层的细胞状态上,代表遗忘过去的多少信息,又因为遗忘门门值是由x(t), h(t-1)计算得来,因此整个公式意味着根据当前时间步输入和上一个时间步隐含状态h(t-1)来决定遗忘多少上一层的细胞状态所携带的过往信息
动态决定历史信息的保留比例,通过sigmoid函数输出0-1之间的遗忘系数。实际应用场景如在语言模型中自动遗忘不相关的主语信息。
过程演示

激活函数sigmiod
帮助调节流经网络的值,sigmoid函数将值压缩在0和1之间。

2.2 输入门:新知融合系统
结构图与计算公式

结构分析
输入门的计算公式有两个:
- 产生输入门门值的公式, 它和遗忘门公式几乎相同, 区别只是在于它们之后要作用的目标上. 这个公式意味着输入信息有多少需要进行过滤
- 与传统RNN的内部结构计算相同. 对于LSTM来讲, 它得到的是当前的细胞状态, 而不是像经典RNN一样得到的是隐含状态
可实现新信息的选择性记忆,如在股票预测中精准捕捉突发市场信号。
过程演示

2.3 细胞状态更新
结构图和计算公式

结构分析
没有全连接层,只是将刚得到的遗忘门门值与上一个时间步得到的C(t-1)相乘,再加上输入门门值与当前时间步得到的未更新C(t)相乘的结果。最终得到更新后的C(t)作为下一个时间步输入的一部分。整个细胞状态更新过程就是对遗忘门和输入门的应用。
可构建动态记忆高速公路,如医疗诊断场景中持续更新患者病史特征。
过程演示

2.4 输出门:信息蒸馏器
结构图和计算公式

结构分析
公式也两个:
- 第一个计算输出门的门值,和遗忘门、输入门计算方式相同
- 第二个是用这个门值产生隐含状态h(t),他将作用在更新后的细胞状态C(t)上,并做tanh激活,最终得到h(t)作为下一时间步输入的一部分。整个输出门的过程,就是为产生隐含状态h(t)
可智能生成当前时刻的特征表达,如在机器翻译中精准输出目标语言词汇。
过程演示

3 Bi-LSTM
双向LSTM,未改变LSTM本身任何的内部结构,只是将LSTM应用两次且方向不同,再将两次得到的LSTM结果进行拼接作为最终输出。

3.1 结构分析
图中对"我爱中国"这句话或叫这个输入序列,进行从左到右、从右到左两次LSTM处理,将得到的结果张量拼接作为最终输出。
这种结构能捕捉语言语法中一些特定的前置或后置特征,增强语义关联,但模型参数和计算复杂度也随之增加一倍,一般需对语料和计算资源进行评估后,决定是否使用该结构。
可使用于:
- 正向LSTM捕捉历史依赖
- 反向LSTM捕获未来特征
如医疗文本分析中同时考虑症状描述和诊断结果。
3.2 单向LSTM V.S 双向LSTM
特性 | 单向LSTM | 双向LSTM |
参数数量 | 1x | 2x |
上下文感知 | 前向 | 全向 |
计算效率 | 高 | 中等 |
4 Pytorch中LSTM工具的使用
torch.nn工具包之中, 通过torch.nn.LSTM可调用。
4.1 nn.LSTM类初始化参数
- input_size: 输入张量x中特征维度的大小
- hidden_size: 隐层张量h中特征维度的大小
- num_layers: 隐含层的数量
- bidirectional: 是否选择使用双向LSTM, 如果为True, 则使用; 默认不使用
4.2 nn.LSTM类实例化对象参数
- input: 输入张量x
- h0: 初始化的隐层张量h
- c0: 初始化的细胞状态张量c
4.3 nn.LSTM使用示例
# 定义LSTM的参数含义: (input_size, hidden_size, num_layers)
# 定义输入张量的参数含义: (sequence_length, batch_size, input_size)
# 定义隐藏层初始张量和细胞初始状态张量的参数含义:
# (num_layers * num_directions, batch_size, hidden_size)
>>> import torch.nn as nn
>>> import torch
>>> rnn = nn.LSTM(5, 6, 2)
>>> input = torch.randn(1, 3, 5)
>>> h0 = torch.randn(2, 3, 6)
>>> c0 = torch.randn(2, 3, 6)
>>> output, (hn, cn) = rnn(input, (h0, c0))
>>> output
tensor([[[ 0.0447, -0.0335, 0.1454, 0.0438, 0.0865, 0.0416],
[ 0.0105, 0.1923, 0.5507, -0.1742, 0.1569, -0.0548],
[-0.1186, 0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]],
grad_fn=<StackBackward>)
>>> hn
tensor([[[ 0.4647, -0.2364, 0.0645, -0.3996, -0.0500, -0.0152],
[ 0.3852, 0.0704, 0.2103, -0.2524, 0.0243, 0.0477],
[ 0.2571, 0.0608, 0.2322, 0.1815, -0.0513, -0.0291]],
[[ 0.0447, -0.0335, 0.1454, 0.0438, 0.0865, 0.0416],
[ 0.0105, 0.1923, 0.5507, -0.1742, 0.1569, -0.0548],
[-0.1186, 0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]],
grad_fn=<StackBackward>)
>>> cn
tensor([[[ 0.8083, -0.5500, 0.1009, -0.5806, -0.0668, -0.1161],
[ 0.7438, 0.0957, 0.5509, -0.7725, 0.0824, 0.0626],
[ 0.3131, 0.0920, 0.8359, 0.9187, -0.4826, -0.0717]],
[[ 0.1240, -0.0526, 0.3035, 0.1099, 0.5915, 0.0828],
[ 0.0203, 0.8367, 0.9832, -0.4454, 0.3917, -0.1983],
[-0.2976, 0.7764, -0.0074, -0.1965, -0.1343, -0.6683]]],
grad_fn=<StackBackward>)5 LSTM评价
5.1 优势
LSTM的门结构能够有效减缓长序列问题中可能出现的梯度消失或爆炸, 虽然并不能杜绝这种现象, 但在更长的序列问题上表现优于传统RNN.
5.2 缺点
由于内部结构相对较复杂, 因此训练效率在同等算力下较传统RNN低很多。
本文已收录在Github,关注我,紧跟本系列专栏文章,咱们下篇再续!
- 🚀 魔都架构师 | 全网30W+技术追随者
- 🔧 大厂分布式系统/数据中台实战专家
- 🏆 主导交易系统亿级流量调优 & 车联网平台架构
- 🧠 AIGC应用开发先行者 | 区块链落地实践者
- 🌍 以技术驱动创新,我们的征途是改变世界!
- 👉 实战干货:编程严选网
















