1. RNN 存在的理由
循环神经网络(Recurrent Neural Network, RNN)是一种处理时序型输入的神经网络。它被广泛应用在语音识别、机器翻译、人名识别、文本生成等任务上。这些任务要处理的都是时序型的数据,一言以蔽之,这些时序型数据有着 输入是不定长度的,输入的上下文是由关联的 的特征。
经典的深度神经网络(NN)通过构建多层全连接的神经网络来对固定的多维输入进行预测;卷积神经网络(CNN)通过设计不同的卷积和池化层组合来对图像(网格化输入)进行预测。但在实际应用中,这两种神经网络均很难处理时序型数据,因为:
- 它们的网络结构的输入维度是固定的,不能处理不同长度的输入;
- 它们的网络结构忽略了输入节点间的横向联系,如上下文关系等。
这两个缺点恰好妨碍了 NN 和 CNN 对于时序型数据的处理。于是,一种更新的能够处理不同维度输入的、且能利用前后顺序关系的网络结构就问世了,它就是循环神经网络 RNN。换句话说,RNN 存在的理由就是为了处理时序型数据。
2. RNN 的基本结构
以机器翻译为例,句子中单词的翻译跟语境十分相关,也可以认为要翻译的单词含义与之前出现的单词十分相关,要想预测当前词的含义就必须 “记住” 之前出现过的词。我们以一个简单的语句 “John likes to drink the spring water.” 为例。通过上下文(如 “drink”),人类马上就知道这里的 “spring” 指的是 “泉水” 的意思(而非 “春天”),但是要让机器理解上下文则相对困难。
我们将每个单词视作一个输入,即这 7 个单词分别对应于 ,同时,记句子的输入长度为 ,即 。那么这 7 个单词的输出分别对应于 ,输出的总长度为 ,实际情况下允许 。同时,每个输入都有一个实时状态,记为 。这个实时状态其实就是 RNN 要记下来的前文知识。
上图展示了 RNN 两种不同的表达方式,图 (a) 常见于各种论文中,理解起来较难,将其横向展开就得到了(b);图 (b) 中的模式更适合大部分人的理解,它表示了 RNN 中信息的流动方式。
给定初始状态 ,RNN 将依次从左往右传播,且当前状态 只与当前输入 和前个时刻的状态 相关。如果当前 RNN 节点有输出值 ,则该输出值与当前的状态
这里值得一提的是,在 RNN 中我们没有强制规定输入维度 一定等于输出维度 ,事实上根据处理问题的不同,二者存在多种关系,如下图所示的 many-to-one(分类), one-to-many(文本生成), many-to-many (机器翻译,摘要生成)等不同模式。
2.1 RNN 的前向传播方式 - Forward
下图展示了 RNN 第 t 时刻的节点的输入输出情况。为了更好的计算当前状态值,RNN 设定全局共享的参数 和 (即所有的 RNN 节点参数值均相同),通过激活函数 来计算出第 t 时刻的状态值 ,通常这里的 函数会被设置为 ReLU 或者 Tanh 函数。另外,计算 时使用的 函数根据预测任务而定,例如在分类情况下使用 Softmax 函数。
实际的计算中为了节省计算量,在函数 中,我们将 和 做一个垂直方向的拼接,这样参数 和 就可以整合成一个参数 来计算了。举个简单的例子,假设我们的输入 都是 1000 维的,状态 都是 100 维,那么 应该是 100x100 的尺寸, 则是 100x1000 的尺寸,拼接起来的 则是 100x1100 尺寸的。先假设 的尺寸是 64x100 ,那么最终的
2.2 RNN 的反向传播方式 - Backward
RNN 与之前标准神经网络类似,依旧使用链式法则来反向更新参数。不同的是由于 RNN 每个节点都公用参数 ,因此求参数偏导时需要累加之前的偏导数的值,RNN 的更新方式也被称为 BPTT 算法。
为了书写方便,我们将上一小节的 函数暂时省略。同时确定 RNN 的损失函数
其中 表示样本在时刻 的实际输出, 表示样本在时刻 的预期输出。当
通过上述两式可知,当我们使用 BPTT 算法对 或 求导时,不但要考虑 时刻的情况,还要考虑 时刻以前的状况,因为 或 影响了前面时刻的状态 。于是乎,求这2个参数的偏导的过程是一个累加的过程。如下图的红色箭头所指。
首先,对 求偏导。由于 只与
然后,对 求偏导。由于 是关于 个输出
上式可化简为,
更加细致的推导请参考博客1,2。同理,我们也可以利用链式求导法则求出
综上可知,求解 和 的梯度是复杂的累积之后,再累加的一个过程,其中连续的累积部分()会造成“梯度爆炸”和“梯度消失”问题。因为 RNN 使用的是 Tanh 函数为 函数,因此
2.3 RNN 到底有没有梯度消失?
“RNN 有梯度消失,LSTM解决了它” 可能是对 RNN 或者 LSTM 最经典的误解3。事实上,RNN 的 “梯度消失” 和传统的 NN 的 “梯度消失” 含义不同:
- 在传统的 NN 网络中,各层之间的参数
- 在 RNN 中,参数
综上,虽然随着层数的增加,越远离输出层的梯度会减少,但是总的梯度和是不会消失的。因此 RNN 中不存在传统的梯度消失问题的!。因此,RNN 的弱点并不是梯度为趋近于0或消失,而是“健忘”,记不住较远距离对其的影响。
3. RNN 变体 - 长短记忆网络 LSTM
为了解决 RNN 训练过程中的 “健忘” 的问题,人们提出了 长短期记忆网络(Long short-term memory, LSTM4)的网络结构。LSTM 是 RNN 的一种变体,相比传统的 RNN 网络结构,LSTM 能够在更长的序列中有更好的表现。RNN 与 LSTM 的结构对比如下,可见 LSTM 肉眼可见的复杂了,每个节点有 2 个输入和 2 个输出。下图中的 指的就是 时刻的输出, 指的是 时刻的输入, 是 时刻的状态。
LSTM 结构最大的创新点在于它引入了 “门” 这个网络结构。具体来讲,“门” 结构相当于一个阈值控制机关,它控制了当前信息有多少比例可以通过。LSTM 设置有 “遗忘门”,“更新门”,“输出门” 3 个门结构,他们分别控制了之前状态,当前输入和当前状态有多少信息被保留下来。
遗忘门 Forget Gate: 遗忘门决定了前一时刻状态 有多少信息保留到当前状态 。遗忘门的输入是前一时刻输出 和当前输入 ,其输出 表示应该保留的比例, 符号表示 Sigmoid 函数,其输出是 0 至 1 间的实数,数值越大表示信息保留的越多。
输入门 Input Gate: 输入门决定当前输入 有多少信息输入到 状态中。这个步骤分为两步,首先通过 和 生成 和 ,然后与前一时刻状态 保留的信息相加,从而得到当前时刻的状态 。注意这里的 函数的输出范围是 0 到 1,而 函数的输出范围是 -1 到 1。因此, 是一个0-1的实数,而 则表示的是向量。
输出门 Output Gate: 输出门决定了当前时刻的输入 有多少比例影响当前时刻的输出 。由图可知 的值不但跟 和 有关,还与当前时刻的状态 有关。
根据上述的几个步骤,可以清晰的了解到 LSTM 节点(或细胞)中的信息流向。换个角度看,标准的 RNN 中,状态与输入是相同的,即 ,而 LSTM 则是分开计算的;其实在实际应用中,人们往往会直接使用 LSTM 而非标准的 RNN。
3.2 LSTM 如何解决梯度消失问题?
回到本小节开头部分,LSTM 的提出是为了解决传统 RNN 存在的 “健忘” 的问题。BPTT 算法中那一段累积部分 就是罪魁祸首。根据 LSTM 的网络结构,该段状态
关于输出
(注:本小节的图片均来自 Christopher Olah 的博文 Understanding LSTM Networks. )
4. RNN 变体 - 双向循环神经网络 BRNN
双向循环神经网络(Biodirectional Recurrent Neural Network,BRNN)也是 RNN 的一种变体,与传统 RNN 相比,BRNN 增加了后续时刻的输入对当前状态的影响。设想在人名识别的任务重有下属两个句子,
He said, “Teddy bears are on sale!”
He said, “Teddy Roosevelt was not a good President!”
(注:Teddy Roosevelt 西奥多·罗斯福(大罗斯福)是美国第 26 任总统。他的侄子 富兰克林·罗斯福 (小罗斯福)是美国第 32 任总统,并且是二战中重要的同盟国领袖之一,他成功连任 4 次美国总统!)
两个语句的开头都是一样的,但很明显第一个句子中的 Teddy 指的是泰迪熊而非人名,第二个句子的 Teddy 才指的是名字。由于 2 个句子开头都是 He said,故使用标准 RNN 则很难判断出来谁是人名。造成 RNN 瓶颈的原因正式 标准 RNN 不能够联系后文的内容。
为了解决 RNN 的这个瓶颈,人们于是提出了双向 RNN 网络,即 BRNN。在 BRNN 中 影响单个时刻输出的不但有之前的内容,并且也有后续的内容。一个标准 BRNN 的结构如下所示, 表示正向传播时的 时刻状态, 表示反向传播时 时刻的状态。
当我们计算 的输出时,我们首先通过输入 得到 1 和 2 时刻的状态 ,正向传播得到当前正向的状态 ;再结合输入 的输入得到 4 时刻的状态 ,反向传播得到当前正向的状态 。最终 是由
5. 小结
RNN 网络主要处理时序型数据,通过给不同时段的输入之间建立联系,RNN 生成最终的预测结果(many-to-one 或其他模式)。他具有一定的记忆功能,但同时记忆功能有待改进。它的后续改进版本 LSTM 注重于记忆更长时刻的知识,BRNN 注重于利用后续时刻的输入知识。当然 RNN 很多的改进版本,如 GRU 或 深层 RNN 网络。RNN 的提出给语音识别、机器翻译,文本生成等时序型数据的处理提供了基石方法,期待未来会在此基础上提出更厉害的网络模型。