深度学习之图像分类(十七)Transformer中Self-Attention以及Multi-Head Attention详解


目录

  • 深度学习之图像分类(十七)Transformer中Self-Attention以及Multi-Head Attention详解
  • 1. 前言
  • 2. Self-Attention
  • 3. Multi-head Self-Attention
  • 3. Positional Encoding


终于来到了 Transformer,从 2013 年分类网络学习到如今最火的 Transformer,真的不容易。本节学习 Transformer 中 Self-Attention 以及 Multi-Head Attention详解(注意不是 Version Transformer)。

简单的图像分类项目 上手 github 图像分类attention_深度学习

1. 前言

Transformer 是 Google 在 2017 年发表于 Computation and Language 上的,其原始论文为 Attention Is All You Need。Transformer 一开始的提出是针对 NLP 领域的。在此之前主要用 RNN 和 LSTM 等时序网络,这些时序网络他们的问题在于,RNN 的记忆长度是有限的,比较短。此外,他们无法并行化,必须先计算 简单的图像分类项目 上手 github 图像分类attention_Self_02 时刻再计算 简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_03 时刻,是串行的关系,所以训练效率低。基于这些问题,Google 便提出了 Transformer 来解决这一系列问题。Transformer 在理论上不受硬件限制的话,记忆长度可以是无限长的;其次他是可以做并行化的。在这篇文章中作者提出了 Self-Attention 的概念,然后在此基础上提出 Multi-Head Attention。本节主要是对 Transformer 中的 Self-Attention 以及 Multi-Head Attention 进行讲解。

2. Self-Attention

过去我们经常看到这三张图以及对应的公式,但是还是很难理解是什么意思。李宏毅老师对此曾说:”看不懂的人,你再怎么看,还是看不懂“。

简单的图像分类项目 上手 github 图像分类attention_pytorch_04

对此我们来进一步细讲它的理论。假如我们输入的时序数据是 简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_05,例如这里的 简单的图像分类项目 上手 github 图像分类attention_pytorch_06简单的图像分类项目 上手 github 图像分类attention_transformer_07。首先我们会把他们通过 Embedding 层映射到更高的维度上得到对应的 简单的图像分类项目 上手 github 图像分类attention_Self_08简单的图像分类项目 上手 github 图像分类attention_pytorch_09。紧接着将 简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_10 分别通过 简单的图像分类项目 上手 github 图像分类attention_transformer_11 三个参数矩阵生成对应的 简单的图像分类项目 上手 github 图像分类attention_Self_12。在网络中 简单的图像分类项目 上手 github 图像分类attention_transformer_11 三个参数矩阵是共享的。在源码中, 简单的图像分类项目 上手 github 图像分类attention_transformer_11 其实直接通过全连接层来实现的,是可训练的参数。在这里讲的时候忽略偏置方便理解。假设 简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_15,然后再假设 简单的图像分类项目 上手 github 图像分类attention_pytorch_16 矩阵为 [[1, 1], [0, 1]]。根绝公式就可以得到 简单的图像分类项目 上手 github 图像分类attention_Self_17,同理 简单的图像分类项目 上手 github 图像分类attention_深度学习_18
简单的图像分类项目 上手 github 图像分类attention_深度学习_19
这里的 简单的图像分类项目 上手 github 图像分类attention_pytorch_20 表达的含义是 query。也就是接下来他(query)会去匹配每一个 key。这里的 key 也是 简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_10简单的图像分类项目 上手 github 图像分类attention_pytorch_22 进行相乘得到的。简单的图像分类项目 上手 github 图像分类attention_深度学习_23 则是从 简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_10 中提取得到的信息,他是 简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_10简单的图像分类项目 上手 github 图像分类attention_Self_26 进行相乘得到的,也可理解为网络认为的从 简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_10 中提取到的有用的信息。由于在 Transformer 中是可并行化的,所以可以按照矩阵乘法的形式进行书写。例如 简单的图像分类项目 上手 github 图像分类attention_Self_08简单的图像分类项目 上手 github 图像分类attention_pytorch_09 可以拼接到一起得到 [[1, 1], [1, 0]]。将 简单的图像分类项目 上手 github 图像分类attention_pytorch_20 全部放在一起就是 Attention 公式中的 Q,同理将 简单的图像分类项目 上手 github 图像分类attention_Self_31简单的图像分类项目 上手 github 图像分类attention_深度学习_23 分别放在一起就是公式中的 简单的图像分类项目 上手 github 图像分类attention_pytorch_33简单的图像分类项目 上手 github 图像分类attention_深度学习_34
简单的图像分类项目 上手 github 图像分类attention_transformer_35

简单的图像分类项目 上手 github 图像分类attention_Self_36

当得到 简单的图像分类项目 上手 github 图像分类attention_Self_37 之后,就需要将 简单的图像分类项目 上手 github 图像分类attention_pytorch_20 与每一个 简单的图像分类项目 上手 github 图像分类attention_Self_31 进行 match 匹配,简单的图像分类项目 上手 github 图像分类attention_transformer_40,其中 简单的图像分类项目 上手 github 图像分类attention_Self_41简单的图像分类项目 上手 github 图像分类attention_Self_31 的 dimension (简单的图像分类项目 上手 github 图像分类attention_Self_31 求出来其实是一个向量,所以其 dimension 就是向量中元素的个数,即向量的长度,在下图中为 2)。最终经过 Softmax 之后,得到的权重越大,我们就会关注对应的 简单的图像分类项目 上手 github 图像分类attention_深度学习_23
简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_45

简单的图像分类项目 上手 github 图像分类attention_深度学习_46

同样的我们也会拿 简单的图像分类项目 上手 github 图像分类attention_pytorch_47 和每个 key 进行匹配,同样可以得到 简单的图像分类项目 上手 github 图像分类attention_Self_48简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_49。经过 Softmax 就可以得到 简单的图像分类项目 上手 github 图像分类attention_深度学习_50简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_51。这个过程也是可以用矩阵乘法的形式来进行书写的,即 简单的图像分类项目 上手 github 图像分类attention_transformer_52
简单的图像分类项目 上手 github 图像分类attention_pytorch_53

简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_54

简单的图像分类项目 上手 github 图像分类attention_深度学习_55 其实就是针对每一个 简单的图像分类项目 上手 github 图像分类attention_深度学习_23 的权重大小。所以接下来使用 简单的图像分类项目 上手 github 图像分类attention_深度学习_55简单的图像分类项目 上手 github 图像分类attention_深度学习_23 进行进一步操作。即拿 简单的图像分类项目 上手 github 图像分类attention_transformer_59简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_60 相乘加上拿 简单的图像分类项目 上手 github 图像分类attention_深度学习_61简单的图像分类项目 上手 github 图像分类attention_Self_62 相乘得 简单的图像分类项目 上手 github 图像分类attention_Self_63,即拿 简单的图像分类项目 上手 github 图像分类attention_深度学习_50简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_60 相乘加上拿 简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_51简单的图像分类项目 上手 github 图像分类attention_Self_62 相乘得 简单的图像分类项目 上手 github 图像分类attention_transformer_68。这个过程也是可以用矩阵乘法的形式来进行书写的。
简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_69

简单的图像分类项目 上手 github 图像分类attention_深度学习_70

此时对于 Self-Attention 这个公式基本讲解完了。总结下来就是论文中的一个公式:
简单的图像分类项目 上手 github 图像分类attention_深度学习_71
如果将其抽象为一个模块的话,可如下所示:

简单的图像分类项目 上手 github 图像分类attention_transformer_72

3. Multi-head Self-Attention

在 Transformer 使用过程中使用更多的其实还是 Multi-head Self-Attention。原论文中说使用多头注意力机制能够联合来自不同 head 部分学习到的信息。(Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions.)

简单的图像分类项目 上手 github 图像分类attention_pytorch_73

Multi-head Self-Attention 其实也非常简单,首先还是拿 简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_10简单的图像分类项目 上手 github 图像分类attention_transformer_11 相乘得到 简单的图像分类项目 上手 github 图像分类attention_Self_12。然后我们根据 Head 对数据进行拆分。例如 简单的图像分类项目 上手 github 图像分类attention_transformer_77,则将它拆分后得到简单的图像分类项目 上手 github 图像分类attention_Self_78简单的图像分类项目 上手 github 图像分类attention_transformer_79。在源码中就是将 简单的图像分类项目 上手 github 图像分类attention_pytorch_80 均分给每个 Head。在论文中作者说通过线性映射得到的,其实可以直接理解为按照 head 的个数直接均分即可。
$$
\text { head }{i}=\text { Attention }\left(Q W{i}^{Q}, K W_{i}^{K}, V W_{i}^{V}\right) \

W_1^Q = W_1^K = W_1^V = \left(\begin{array}{l}
1,0 \
0,1 \
0,0 \
0,0 \
\end{array}\right)
\quad
W_2^Q = W_2^K = W_2^V = \left(\begin{array}{l}
0,0 \
0,0 \
1,0 \
0,1 \
\end{array}\right)
$$

简单的图像分类项目 上手 github 图像分类attention_深度学习_81

同理将所有的 简单的图像分类项目 上手 github 图像分类attention_Self_12 进行拆分,将第二个下标为 简单的图像分类项目 上手 github 图像分类attention_深度学习_83简单的图像分类项目 上手 github 图像分类attention_Self_12 (即 简单的图像分类项目 上手 github 图像分类attention_深度学习_85 ) 分配给 head 1,将第二个下标为 简单的图像分类项目 上手 github 图像分类attention_transformer_86简单的图像分类项目 上手 github 图像分类attention_Self_12 (即 简单的图像分类项目 上手 github 图像分类attention_深度学习_88

简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_89

接下来对于每一个 Head 执行之前描述的 Self-Attention 中的一系列过程。
简单的图像分类项目 上手 github 图像分类attention_transformer_90

简单的图像分类项目 上手 github 图像分类attention_深度学习_91

然后将计算结果进行拼接即可。简单的图像分类项目 上手 github 图像分类attention_pytorch_92 (head1 得到的 简单的图像分类项目 上手 github 图像分类attention_Self_63) 和 简单的图像分类项目 上手 github 图像分类attention_深度学习_94 (head2 得到的 简单的图像分类项目 上手 github 图像分类attention_Self_63) 拼接在一起。简单的图像分类项目 上手 github 图像分类attention_Self_96 (head1 得到的 简单的图像分类项目 上手 github 图像分类attention_transformer_68) 和 简单的图像分类项目 上手 github 图像分类attention_深度学习_98 (head2 得到的 简单的图像分类项目 上手 github 图像分类attention_transformer_68) 拼接在一起。

简单的图像分类项目 上手 github 图像分类attention_Self_100

拼接后还需要通过 简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_101 将拼接后的数据进行融合得到最终 MultiHead 的输出。为了保证输入输出 multi-head attention 的向量长度保持不变,简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_101 的维度是 简单的图像分类项目 上手 github 图像分类attention_transformer_103简单的图像分类项目 上手 github 图像分类attention_Self_104 其实也等于 简单的图像分类项目 上手 github 图像分类attention_transformer_105

简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_106

multi-head attention 其实和 group conv 很像啊

总结下来就是论文中的两个公式:
简单的图像分类项目 上手 github 图像分类attention_pytorch_107
如果将其抽象为一个模块的话,可如下所示:

简单的图像分类项目 上手 github 图像分类attention_深度学习_108

原论文章节3.2.2中最后有说 Self-Attention 和 Multi-Head Self-Attention 的计算量其实差不多。Due to the reduced dimension of each head, the total computational cost is similar to that of single-head attention with full dimensionality.

3. Positional Encoding

假设我们输入 简单的图像分类项目 上手 github 图像分类attention_Self_109 得到对应的 简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_110 ;如果将 简单的图像分类项目 上手 github 图像分类attention_深度学习_111 的顺序进行调换,对于 简单的图像分类项目 上手 github 图像分类attention_Self_63 而言是没有任何影响的。无论后面顺序如何是不影响 简单的图像分类项目 上手 github 图像分类attention_Self_63

简单的图像分类项目 上手 github 图像分类attention_pytorch_114

下面是使用 Pytorch 做的一个实验,首先使用 nn.MultiheadAttention 创建一个 Self-Attention 模块(num_heads=1),注意这里在正向传播过程中直接传入 简单的图像分类项目 上手 github 图像分类attention_Self_37,接着创建两个顺序不同的 简单的图像分类项目 上手 github 图像分类attention_Self_37 变量 t1 和 t2(主要是将 简单的图像分类项目 上手 github 图像分类attention_Self_117简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_118 的顺序换了下)。对比结果可以发现,对于 简单的图像分类项目 上手 github 图像分类attention_Self_63 是没有影响的, 简单的图像分类项目 上手 github 图像分类attention_transformer_68简单的图像分类项目 上手 github 图像分类attention_transformer_121

import torch
import torch.nn as nn


m = nn.MultiheadAttention(embed_dim=2, num_heads=1)

t1 = [[[1., 2.],   # q1, k1, v1
       [2., 3.],   # q2, k2, v2
       [3., 4.]]]  # q3, k3, v3

t2 = [[[1., 2.],   # q1, k1, v1
       [3., 4.],   # q3, k3, v3
       [2., 3.]]]  # q2, k2, v2

q, k, v = torch.as_tensor(t1), torch.as_tensor(t1), torch.as_tensor(t1)
print("result1: \n", m(q, k, v))

q, k, v = torch.as_tensor(t2), torch.as_tensor(t2), torch.as_tensor(t2)
print("result2: \n", m(q, k, v))

简单的图像分类项目 上手 github 图像分类attention_简单的图像分类项目 上手 github_122

对于每一个 简单的图像分类项目 上手 github 图像分类attention_深度学习_123 会加一个 shape 一样的位置编码。可以根据论文公式进行计算得到位置编码,也可以训练得到位置编码。

简单的图像分类项目 上手 github 图像分类attention_pytorch_124