• 摘要:transformer 具有学习长期依赖的潜力,但在语言建模设置中受到固定context length的限制。我们提出了 Transformer XL,可以在不破坏时间一致性的前提下扩展context length。它由一种段级递归机制(segment-level recurrence mechanism)和一种新的位置编码方案组成。我们的方法不仅能够捕获较长期的上下文依赖关系,而且可以解决上下文碎片化问题。Transformer XL 学习的上下文长度比 RNN 长80%,比普通 Transformer 长450%,在短序列和长序列上都取得了更好的性能,而且在评估过程中比普通 Transformer 快1800+倍… (关于性能的部分省略)

注:本文中的 Transformer 其实特指 Transformer-decoder,也就是 GPT 模型


文章目录

  • 1 传统模型的做法 & 问题
  • 2 本文方法
  • 2.1 片段递归(Segment-Level Recurrence)
  • 2.2 相对位置编码(Relative Positional Encodings)
  • 2.3 一个 trick
  • 3. 补充


1 传统模型的做法 & 问题

  • 传统 Transformer 模型能处理的序列长度是固定的,由 attention 层的尺寸决定,必须将序列数据调整为此固定长度才能输入模型。其训练和推断过程一般如下图所示
  1. 训练时,若序列数据长度比固定长度短,则通过 padding 方式补全;若序列数据长度比固定长度长,通常将长序列划分为多个具有固定长度 segments,训练时仅在各个 segment 内部计算 attention,而 segments之间没有联系。
  2. 推断时,每个 step 对具有固定长度的 segment 进行计算,预测出下一个 token 后,将 segment 范围整体右移进行 AutoRegress
  • 传统方法有如下问题
  1. 上下文长度受限模型能够建模的 max context length 被限制为此固定长度,这会影响推理性能
  2. 上下文碎片问题:出于效率的考虑,划分 segments 时没有考虑句子的自然边界,导致分割出来的 segments 在语义上是不完整的,这在一定程度程度上会误导模型
  3. 推理速度慢:Transformer decoder 计算的是 Masked self attention,也就是每个 token 只以自身产生 query,和自己之前的所有 token 生成的 key 计算 attention value 并汇聚信息,AutoRegress 过程右移 segments 的操作不会影响之前的信息汇聚结果,因此推断 AutoRegress 过程中任意相邻两步重叠的那些 hidden value(重叠的部分黄色点)应当是不需要重新计算的。由于传统模型的固定长度限制,AutoRegress 过程的每次右移都会导致模型忽略最早的一个 token,不得不对整个 segments 的所有 hidden value 进行重新计算。模型支持的 context length 越长,堆叠的 Transformer Block 越多,这种重复计算就越多,这会大大降低测试效率

2 本文方法

2.1 片段递归(Segment-Level Recurrence)

  • 为了解决上述问题,Transformer-XL 提出可以在计算当前 segment 时,缓存并利用上一个 segment 中所有 layer 的 hidden state 序列,而且上一个segment的所有隐向量序列只参与前向计算,不再进行反向传播,这就是所谓的segment-level Recurrence。如下图所示
  1. 训练阶段:每个 segment 长度保持为固定长度 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer xl,但是计算 New segment 时可以通过绿线访问之前缓存的 hidden state value(不回传梯度),这缓解了上下文碎片问题
  • 注意示意绿线维持了每次输入模型的序列长度为 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_深度学习_02
  • 理论上,需要在每一步计算时仅将相应的 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_缓存_03。这种实现方法虽然效果相同,但是没有达成节省显存的效果
  • 由于以上原因,有些借用 Transformer-XL 作为 backbone 的方法中训练阶段不进行缓存,仅在测试时缓存(如 gato 和 DB1)以扩展等效上下文长度并加速推断过程
  1. 推断阶段:仍然是每个 step 右移一位做 AutoRegress,输入序列长度也仍然是模型的固定长度 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer xl,区别在于计算 attention 时仅由上一步 AutoRegress 生成的 token 来产生 query,缓存的前驱 hidden state value 仅生成 key 不产生 query,从而避免了重叠 hidden value 计算。另外缓存机制还变相扩展了有效序列长度,如图所示,最后一个 Transformer Block 对应的 context length 为固定长度 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer xl_05,每往前一个 block 扩展 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_缓存_06
  • 微观上看,提升计算效率的根本原因在于生成 query 的 token 数量减少了,这意味着计算 attention & 汇聚 value 的操作减少了。传统模型的每一层 Transformer Block 有 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_深度学习_02 个 query,这样才能生成 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_深度学习_02
  • 设 Transformer Block 层数为 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer xl_09,简单 AutoRegress 得到的等价上下文长度可达(图中绘制了 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_缓存_10 的情况)论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_深度学习_11
  • 由于使用了特殊的位置编码(详见下文1.2.2节),在固定长 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_深度学习_02
  • 下面给出片段递归的形式化表述:设 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_深度学习_13 是相邻的两个 segment,模型输入序列长度为 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_缓存_14,包含 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer_15 层 Transformer Block,每个 hidden value 维度为 d,将 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer xl_16 中第 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_深度学习_17 层 hidden node value 记为 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_语言模型_18,则 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_语言模型_19 中第 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_深度学习_17 层 hidden node value 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer_21 如下计算
    论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer xl_22 其中 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer xl_23 是 stop-gradient,表示不再对 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer xl_16 的隐向量做梯度回传,论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_语言模型_25 是向量拼接符号。

这里 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_语言模型_26,从而 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer xl_27, 于是 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_缓存_28 长度都是 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_深度学习_29论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer xl_30 长度为 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer xl,输出 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer_32 长度为 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer xl。这表达的是缓存整个前驱 segment 来生成额外的 key 和 value,和上文示意图、训练阶段分析及作者代码均不同,只是核心思想的示意

2.2 相对位置编码(Relative Positional Encodings)

  • 仍考虑上面图中绘制的模型输入长度 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_深度学习_34 的情景,传统方法中相邻两个 segment 的绝对位置编码为 0 1 2 3 0 1 2 3
  1. 如果维持位置编码不变,当通过绿线(缓存)构造跨 segment 序列时会出现重复的绝对位置编码,误导模型
  2. 如果每次构造好长 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer xl
  • 为了构造一致的位置编码,作者提出以 “当前要预测的 token(即生成 query 的 token)位置” 为原点构造相对位置编码,这种情况下无论哪个长为 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_深度学习_34 的序列的位置编码都是 3 2 1 0,代表与 “当前要预测的 token 位置” 的相对距离。这种编码方案可以自然地不断向前扩展,保证时序一致性。具体的,这种位置编码具有以下特点
  1. 任意 attention head,无论 token 处于什么位置,生成的 query 向量都应该一致
  2. 任意两个 token 之间,只要间距相同,则相对位置信息相同
  3. 作者将 token 的内容信息和位置信息拆分,即 key/query 的内容信息/位置信息是在四个不同的空间中产生的,这有助于学习更好的特征
  4. relative pos embedding 是由正余弦公式生成的(类似BERT),这样模型就能学到关于位置嵌入的归纳偏差,结合 3 可以实现序列长度泛化
  • 下面给出相对位置编码的形式化描述。先看传统模型的绝对位置编码是如何在 attention 计算时发挥作用的

    论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_深度学习_37 其中 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_缓存_38 是要计算 attention 的两个 token 的位置索引,论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer xl_39 分别代表 token content embedding 和 absolute pos embedding,二者相加得到 token embedding,再分别由投影矩阵 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer xl_40 生成 query vector 和 key vector,最后做向量内积得到 attention score 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_缓存_41,注意其中 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer_42 通常是可学习的。这个计算过程可以展开为四个部分,接下来下面我们按以上四点对展开式进行处理
  1. 第一步把 learnable absolute pos embedding 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_缓存_43 换成 unlearnable relative pos embedding 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer xl_44,注意这是正余弦公式生成的,不可学习
  2. 第二步正常展开
  3. 第三步区分 token content embedding 和 relative pos embedding 的 key 投影矩阵,得到 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer xl_45论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_缓存_46
  4. 第四步将所有 query 统一为同一个可学习向量,并对 key 的 content 和 position 加以区分,得到可学习的 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_深度学习_47
  5. 最后我们还可以整理一下得到
    论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_缓存_48

2.3 一个 trick

  • GLU Variants Improve Transformer 这篇 2020 年的文章分析了对 transformer 模型中 FFD 层的诸多改进,在后续 Transformer-based 模型中得到的广泛使用,这里也提一嘴。
  • 标准的 FFD 层是一个两层 MLP
    论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer_49 其中 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer_50论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_语言模型_51 为激活函数,通常是 ReLU。GLU 这篇文章研究了此两层 MLP 的诸多变体,基本思想是把隐空间投影矩阵 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_深度学习_52 变成两个,再把两个隐变量分别激活后复合得到输出,提升 FFD 层的容量。形式化表示如下
    论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer_53 其中 论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer_54论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer_55 表示按对应位置元素相乘,论文速览【序列模型】—— 【Transformer-XL】Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context_transformer_56
  • 下图整理了各种变体的形式和性能表现

    可见其中 GEGLU 和 SwiGLU 是表现比较好的

3. 补充