文章目录

  • RNN
  • LSTM
  • LSTM架构
  • Learn Gate学习门
  • Forget Gate 遗忘门
  • Remember Gate记忆门
  • Use Gate 应用门
  • 其他架构
  • Gated Recurrent Unit(GRU)
  • Peephole Connections



本节我们学习recurrent neural networks (RNNs)和 long short-term memory (LSTM)

RNN

假设有一个普通的神经网络,输入一张图片,预测出来是狗

pytorch 行人检测 pytorch rl_pytorch 行人检测

但是如果这张图片真的是一只狼呢?神经网络怎么才可以知道呢?假如我们在看自然的电视节目,前面出现了熊,出现了狐狸,这个时候我们就猜测最后一张图片是一头狼,因为都是野外的动物嘛。这种情况下,就是前面的信息对后面的预测有帮助。

pytorch 行人检测 pytorch rl_激活函数_02

这基本上就是RNN的原理

LSTM

但是RNN也存在问题,假如最近出现的两张图片分别是树和松鼠,比较远的地方出现的是熊,那么最后一张图片很可能预测是狗,因为熊出现的地方太远了,这中间经过了很多次sigmoid,信息被稀释掉了,影响就很小了。而且更甚的是,一路回去做反向传播,会出现梯度消失的问题

pytorch 行人检测 pytorch rl_pytorch 行人检测_03

这就是RNN的问题,并不擅长记忆长期内容,只能记忆短期记忆

LSTM就可以解决这个问题,因为它不仅记忆短期的记忆,还有长期记忆

pytorch 行人检测 pytorch rl_pytorch 行人检测_04

比方说再刚刚的案例中,长期记忆就是“这是关于自然科学的”,“有很多森林动物”,短期记忆就是“松鼠”“树”,有一个事件是判别是“狗还是狼”,这三个东西一起再用来生成三个东西,分别是output,更新的长期记忆,更新的短期记忆(下图的紫色箭头不仅都指向短期记忆,也都指向长期记忆和output)

pytorch 行人检测 pytorch rl_神经网络_05

在LSTM中,中间的过程是有很多“门”来控制的,分别如下图四个颜色所示

pytorch 行人检测 pytorch rl_工作原理_06

其中长期记忆我们用大象表示,它会进入到forget gate,这个gate决定要忘记哪些东西。

短期记忆用金鱼表示,金鱼和狼都会进入到learn gate来学习。

pytorch 行人检测 pytorch rl_pytorch 行人检测_07

然后没有遗忘的大象,学习好的金鱼和学习好的狼就会进入到remember gate,形成新的长期记忆。

而use gate也会用这些信息形成新的短期记忆,输出output,也就是狼的预测,和新的短期记忆

pytorch 行人检测 pytorch rl_pytorch 行人检测_08

以上的过程不断地进行,就形成了下图,t代表时间

pytorch 行人检测 pytorch rl_pytorch 行人检测_09

LSTM架构

先来看一下RNN的架构,就是利用short term memory STM和Event E一起,乘上weights,加上bias,然后用激活函数得到新的memory

pytorch 行人检测 pytorch rl_神经网络_10

LSTM的结构就是类似的

pytorch 行人检测 pytorch rl_神经网络_11

Learn Gate学习门

Learn gate是接收短期记忆和event,将两者结合起来,然后忽略一部分,只保留重要的部分

pytorch 行人检测 pytorch rl_激活函数_12

这里就是忽略了“树”,只保留了动物的部分

pytorch 行人检测 pytorch rl_工作原理_13

数学表示是这样的:

①短期记忆STM和事件E进来之后,乘上weights,加上bias,然后用tanh激活函数,就产生了新的信息Nt。这部分用的就是conbine的部分。

②然后需要再忽略一部分,那就乘上遗忘因子it。it是一个向量,进行element wise乘法。it的计算依旧需要用到前面的STM和E的信息,这里又有一个小的神经网络,用到新的权重Wi和偏差bi。

如下图所示,这就是学习门的工作原理

pytorch 行人检测 pytorch rl_激活函数_14

Forget Gate 遗忘门

这里遗忘门的目的就是接受long term memory,which 包含了自然和科学,我们需要忘记科学,留下自然。

pytorch 行人检测 pytorch rl_工作原理_15

工作原理和上面差不多,接收的是长期记忆LTM,对长期记忆进行遗忘,遗忘的过程就是乘ft,ft和it一样,计算需要用到前面的STM和E的信息,经过一层小小的神经网络就得到了ft。

pytorch 行人检测 pytorch rl_pytorch 行人检测_16

Remember Gate记忆门

这个门就更简单了,就是接收长期记忆和短期记忆,当然是处理过的长期记忆和短期记忆,就是把forget gate和learn gate的内容加起来就好了

pytorch 行人检测 pytorch rl_激活函数_17

Use Gate 应用门

这个门就是为了得到输出,也是短期记忆,这两个是一个东西。

比方说在这里,就是将long term memory中的东西里面找到一只熊,然后从short term memory中找到一只松鼠,然后得出“你的图片最有可能是一只狼,当然也涉及其他动物”

pytorch 行人检测 pytorch rl_pytorch 行人检测_18

工作原理就是将遗忘门的输出结果放到一个小型神经网络里面,使用tanh激活函数;然后把短期记忆和事件放到另一个小型神经网络里面,使用sigmoid激活函数;最后一步就是把两者相乘,得到新的输出结果。

pytorch 行人检测 pytorch rl_激活函数_19

其他架构

LSTM的结构我们已经知道了,但是为什么有的地方用tanh,有的地方用sigmoid,为什么有的地方要+,有的地方要×。其实就是因为试验下来这样有用。所以用了这样的结构。其实还有很多其他可行的结构,下面就介绍一些

Gated Recurrent Unit(GRU)

GRU把遗忘门和学习门合并为更新门 update gate,更新门的结果交给合并门 combine gate来处理。他只会翻出一个工作记忆,而不是一对长期记忆和短期记忆。

pytorch 行人检测 pytorch rl_神经网络_20

下面列举了 GRUs 的一些参考文献

Peephole Connections

这里是另一种结构,叫做窥视孔连接

回忆一下遗忘门的结构,短期以及和事件一起来决定该以往什么,也就是ft的产生,那么为什么长期记忆不来决定哪些内容该遗忘呢?所以peephole连接就加入了长期记忆,也来做决策。这样ft的生成就需要更大的数据。

pytorch 行人检测 pytorch rl_工作原理_21

可以把LSTM的所有遗忘门的部分都换成窥视孔连接,就成了具备窥视孔连接的LSTM

pytorch 行人检测 pytorch rl_工作原理_22


代码:

git clone https://github.com/udacity/deep-learning-v2-pytorch.git

转到
recurrent-neural-networks > time-series