摘要: CNN 和 RNN是当下 Deep Learning 应用领域中主流的两大结构。前篇文章中我们介绍了 CNN,本篇开始我们聊聊 RNN。RNN 跟 CNN 历史相似之处在于,都是上个世纪提出来的概念。但是由于当时计算量和数据量都比较匮乏,它 ...

人工智能学习离不开实践的验证,推荐大家可以多在FlyAI-AI竞赛服务平台多参加训练和竞赛,以此来提升自己的能力。FlyAI是为AI开发者提供数据竞赛并支持GPU离线训练的一站式服务平台。每周免费提供项目开源算法样例,支持算法能力变现以及快速的迭代算法模型。

BERT虽好,不要忘记老朋友RNN呀

写在前面

CNN(Convolution Neural Network) 和 RNN(Recurrent Neural Network)是当下 Deep Learning 应用领域中主流的两大结构。前篇文章中我们介绍了 CNN,本篇开始我们聊聊 RNN。RNN 跟 CNN 历史相似之处在于,都是上个世纪提出来的概念。但是由于当时计算量和数据量都比较匮乏,它们都被尘封,直到近几年开始大放异彩,可以说是超越时代的产物。区别在于,CNN 在2012年就开始大行其道,而 RNN 的流行却要到2015年以后了。本文会介绍 RNN 的相关概念,并具体介绍较常见的 RNN 架构。

 

原始 RNN

CNN 这种网络架构的特点之一就是网络的状态仅依赖于输入,而 RNN 的状态不仅依赖于输入,且与网络上一时刻的状态有关。因此,经常用于处理序列相关的问题。RNN 的基础结构如下

与RNN类似的算法_自然语言处理

 

可以看出,它跟 CNN、DNN 这种 Feedforward Neural Network 结构上的区别就在于:Feedforward NN 的结构是 DAG(有向无环图),而 Recurrent NN 的结构中至少有一个环。我们假设[公式]的状态转移发生在时间维度上,则上图可以展开成以下形式:

与RNN类似的算法_深度学习_02

 

与RNN类似的算法_自然语言处理_03

RNN与 BPTT

RNN 的训练跟 CNN、DNN 本质一样,依然是 BP。但它的 BP 方法名字比较高级,叫做 BPTT(Back Propagation Through Time)。

我们首先回顾一下 DNN 相关概念。DNN 的结构如下图

与RNN类似的算法_深度学习_04

 

而 DNN 的 BP 中最重要的公式如下(不再展开讲,不熟悉 BP 的同学请参考“当我们在谈论 Deep Learning:DNN 与 Backpropagation”)

与RNN类似的算法_深度学习_05

 

有了以上 DNN 的结论,接下来我们将 RNN 沿着时间展开(UNFOLD),如下图

与RNN类似的算法_与RNN类似的算法_06

 

与RNN类似的算法_深度学习_07

上述公式如果用更形象的方式来描述,可以参考下面这张李宏毅老师的 PPT

与RNN类似的算法_rnn_08

 

与RNN类似的算法_自然语言处理_09

RNN 与 Gradient Vanish / Gradient Explode

上面 RNN 的 BPTT 公式跟 DNN 的 BP 非常相似,所以毫无疑问同样会面临 Gradient Vanish 和 Gradient Explode 的问题。这里主要有两点原因:

Activation Function

与RNN类似的算法_与RNN类似的算法_10

与RNN类似的算法_自然语言处理_11

与RNN类似的算法_rnn_12

 

而解决 Gradient Vanish 和 Gradient Explode 的方法则有:

对于 Gradient Vanish,传统的方法也有效,比如换 Activation Function 等;不过一个更好的架构能更显著的缓解这个问题,比如下面会介绍的 LSTM、GRU

对于 Gradient Explode,一般处理方法就是将梯度限制在一定范围内,即 Gradient Clipping。可以是通过阈值,也可以做动态的放缩

BRNN

BRNN(Bi-directional RNN)由 Schuster 在"Bidirectional recurrent neural networks, 1997"中提出,是单向 RNN 的一种扩展形式。普通 RNN 只关注上文,而 BRNN 则同时关注上下文,能够利用更多的信息进行预测。

结构上, BRNN 由两个方向相反的 RNN 构成,这两个 RNN 连接着同一个输出层。这就达到了上述的同时关注上下文的目的。其具体结构图如下

与RNN类似的算法_与RNN类似的算法_13

 

BRNN 与普通 RNN 本质一样,仅在训练的步骤等细节上略有差别,这里不再详解描述。有兴趣的同学可以参考原文。

LSTM

为了解决 Gradient Vanish 的问题,Hochreiter&Schmidhuber 在论文“Long short-term memory, 1997”中提出了 LSTM(Long Short-Term Memory)。原始的 LSTM 只有 Input Gate、Output Gate。而咱们现在常说的 LSTM 还有 Forget Gate,是由 Gers 在"Learning to Forget: Continual Prediction with LSTM, 2000"中提出的改进版本。后来,在"LSTM Recurrent Networks Learn Simple Context Free and Context Sensitive Languages, 2001"中 Gers 又加入了 Peephole Connection 的概念。同时,现在常用的深度学习框架 Tensorflow、Pytorch 等在实现 LSTM 上也有一些细微的区别。以上所说的虽然本质都是 LSTM,但结构上还是有所区别,在使用时需要注意一下。

下文介绍的 LSTM 是"Traditional LSTM with Forget Gates"版本。

Traditional LSTM with Forget Gates

LSTM 其实就是将 RNN 中 Hidden Layer 的一个神经元,用一个更加复杂的结构替换,称为 Memory Block。单个 Memory Block 的结构如下(图中的虚线为 Peephole Connection,忽略即可)

与RNN类似的算法_与RNN类似的算法_14

 

与RNN类似的算法_rnn_15

 

与RNN类似的算法_与RNN类似的算法_16

 

LSTM 与 Gradient Vanish

上面说到,LSTM 是为了解决 RNN 的 Gradient Vanish 的问题所提出的。关于 RNN 为什么会出现 Gradient Vanish,上面已经介绍的比较清楚了,本质原因就是因为矩阵高次幂导致的。下面简要解释一下为什么 LSTM 能有效避免 Gradient Vanish。

与RNN类似的算法_rnn_17

LSTM 与 BPTT

最初 LSTM 被提出时,其训练的方式为“Truncated BPTT”。大致的意思为,只有 Cell 的状态会 BP 多次,而其他部分的梯度会被截断,不 BP 到上一个时刻的 Memory Block。当然,这种方法现在也不使用了,所以仅此一提。

在"Framewise phoneme classification with bidirectional LSTM and other neural network architectures, 2005"中,作者提出了 Full Gradient BPTT 来训练 LSTM,也就是标准的 BPTT。这也是如今具有自动求导功能的开源框架们使用的方法。关于 LSTM 的 Full Gradient BPTT,我并没有推导过具体公式,有兴趣的同学可以参考 RNN 中 UNFOLD 的思想来试一试,这里也不再赘述了。

 

GRU

GRU(Gated Recurrent Unit) 是由 K.Cho 在"Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation, 2014"中提出的。它是 LSTM 的简化版本,但在大多数任务中其表现与 LSTM 不相伯仲,因此也成为了常用的 RNN 算法之一。

GRU 的具体结构与对应的公式如下:

与RNN类似的算法_rnn_18

 

与RNN类似的算法_rnn_19

LSTM 有三个 Gate,而 GRU 仅两个

GRU 没有 LSTM 中的 Cell,而是直接计算输出

GRU 中的 Update Gate 类似于 LSTM 中 Input Gate 和 Forget Gate 的融合;而观察它们结构中与上一时刻相连的 Gate,就能看出 LSTM 中的 Forget Gate 其实分裂成了 GRU 中的 Update Gate 和 Reset Gate

很多实验都表明 GRU 跟 LSTM 的效果差不多,而 GRU 有更少的参数,因此相对容易训练且过拟合的问题要轻一点,在训练数据较少时可以试试。

 

尾巴

除了文中提到的几种架构,RNN 还有其它一些变化。但总体而言 RNN 架构的演进暂时要逊色于 CNN,暂时常用的主要是 LSTM 和 GRU。同样,也是由于 RNN 可讲的比 CNN 少些,本次就只用一篇文章来介绍 RNN,内容上进行了压缩。但是这不代表 RNN 简单,相反不论是理论还是应用上,使用 RNN 的难度都要比 CNN 大不少。