引言

预训练的BERT模型具有大量的参数,导致它们无法在边缘设备如智能手机上应用。

为了解决这个问题,我们可以基于知识蒸馏(knowledge distillation,KD)从一个大的预训练BERT迁移学习到一个小的BERT模型上。

知识蒸馏简介

知识蒸馏是一种模型压缩(model compression)技术,用于训练小模型来重现大的预训练模型的表现。也成为教师-学生学习(teacher-student learning),其中大的预训练模型就是教师(模型),而小的模型则为学生(模型)。

假设我们预训练了一个大模型来预测句子中下一个单词。我们称该大模型为教师网络。如果我们输入一个句子让该模型来预测该句子的下一个单词,那么它会返回词表中所有单词作为下一个单词的概率分布,如下图所示。这里为了简化,假设词表中只有5个单词。

几个知识蒸馏相关的BERT变体_损失函数

该概率分布主要通过应用输出层的logits到Softmax中得到,然后我们可以选择概率最大的单词作为预测的下一个单词。这里​​Homework​​​的概率最大,所以输出的下一个单词为​​Homework​​。

除了选择概率最大的单词外,我们还能从该概率分布中得到什么有用的信息吗?答案是肯定的。基于下图,我们可以看到,除了概率最大的单词外,还有一些单词的概率与其他单词相比也是较大的。即,单词​​Book​​​和​​Assignment​​​与其他单词如​​Cake​​​和​​Car​​相比,它们的概率更大。

几个知识蒸馏相关的BERT变体_TinyBERT_02

这说明,除了单词​​Homework​​​外,​​Book​​​和​​Assignment​​也和给定的句子有相关性。这称为暗知识(dark knowledge)。在知识蒸馏期间,我们希望学生能从教师中学到这些暗知识。

听起来不错,但是我们知道,一个好的模型通常对于正确类别会返回一个接近于1的概率,而对于其他类别返回接近与0的概率。确实是这样,考虑下面的例子。假设我们的模型返回了下面这样的概率分布。

几个知识蒸馏相关的BERT变体_TinyBERT_03

其中对于单词​​Homework​​返回了一个很高的概率,而其他单词返回的概率接近于0。除了正确答案之外,我们无法从该概率分布中获得更多的信息。所以,我们如何在这里抽取暗知识呢?

此时,我们可以使用带有温度的Softmax函数(与蒸馏相呼应),该温度通常成为Softmax温度。我们在输出层中使用该Softmax温度。它用于平滑概率分布,公式如下:
几个知识蒸馏相关的BERT变体_概率分布_04
在上面的公式中,几个知识蒸馏相关的BERT变体_DistilBERT_05就是温度。当几个知识蒸馏相关的BERT变体_概率分布_06时,就是标准的Softmax函数。增加几个知识蒸馏相关的BERT变体_DistilBERT_05值会使分布更平滑,同时带来更多其他类的信息。

比如,下图中,当几个知识蒸馏相关的BERT变体_概率分布_06时,我们得到了标准Softmax输出的概率分布。当几个知识蒸馏相关的BERT变体_TinyBERT_09时,输出的概率分布更加平滑,当几个知识蒸馏相关的BERT变体_TinyBERT_10时,概率分布更加平滑。所以通过增加几个知识蒸馏相关的BERT变体_DistilBERT_05值,我们可以得到一个平滑的概率分布,它可以给其他类别更多的信息。

几个知识蒸馏相关的BERT变体_编码器_12

一定程度内,增加几个知识蒸馏相关的BERT变体_DistilBERT_05不会影响输出概率的相对大小,比如​​​Homework​​​在几个知识蒸馏相关的BERT变体_TinyBERT_10时也是最高概率。但是不能无限增大,比如另几个知识蒸馏相关的BERT变体_DistilBERT_05接近无穷大,那么就没有意义了。

这样我们通过Softmax温度获取了暗知识。首先我们会预训练教师模型获得暗知识。然后,在知识蒸馏时,我们从教师模型中转移这些暗知识到学生模型。

训练学生网络

上小节中,我们了解了一个预测句子中下一个单词的预训练网络。该预训练网络就是教师网络。现在,我们来学习如何从教师网络中迁移知识到学生网络。注意学生网络不是预训练的,只有教师网络是预训练的,同时是带有Softmax温度的预训练。

正如下图,我们将输入句子喂给教师和学生网络,然后得到概率分布作为输出。我们知道教师网络是预训练的,所以它输出的概率分布就是我们的目标输出。教师网络的输出成为软目标(soft target),由学生网络做的预测成为软预测(soft prediction)。

几个知识蒸馏相关的BERT变体_损失函数_16

现在,我们计算软目标和软预测之间的交叉熵,然后训练学生网络以最小化该交叉熵损失,该损失称为蒸馏损失(distillation loss)。从下图可以看到,我们将教师和学生网络中的温度几个知识蒸馏相关的BERT变体_DistilBERT_05设为同一个大于几个知识蒸馏相关的BERT变体_概率分布_18的值。

几个知识蒸馏相关的BERT变体_DistilBERT_19

这样通过反向传播我们就可以最小化蒸馏损失来训练学生网络。除了蒸馏损失外,我们还使用另一个损失,称为学生损失(student loss)。

为了理解学生损失,我们先来理解软目标和硬目标(hard target)之间的区别。如下图所示,由教师网络返回的概率分布成称软目标,而硬目标,我们将最大概率设为1,其他单词设为0。

几个知识蒸馏相关的BERT变体_TinyBERT_20

现在,我们来理解软预测和硬预测(hard prediction)的区别。软预测是由基于大于几个知识蒸馏相关的BERT变体_概率分布_18温度的学生网络得到的概率分布,而硬预测是由基于温度几个知识蒸馏相关的BERT变体_概率分布_06

学生损失基本上就是硬目标和硬预测之间的交叉熵损失。下图可以帮助我们理解如何计算学生损失和蒸馏损失。首先,我们来看学生损失。为了计算学生损失,我们在学生网络中使用几个知识蒸馏相关的BERT变体_概率分布_06的Softmax函数,得到硬预测。而硬目标软目标中概率最大的位置设为几个知识蒸馏相关的BERT变体_概率分布_18,其他位置设为几个知识蒸馏相关的BERT变体_DistilBERT_25得到的。然后我们将学生损失计算为硬预测硬目标之间的交叉熵。

几个知识蒸馏相关的BERT变体_DistilBERT_26

为了计算蒸馏损失,我们使用大于几个知识蒸馏相关的BERT变体_概率分布_18的Softmax函数温度,我们将蒸馏损失计算为软预测软目标之间的交叉熵损失。

我们最终的损失函数是学生损失和蒸馏损失之间的加权和:
几个知识蒸馏相关的BERT变体_编码器_28
几个知识蒸馏相关的BERT变体_编码器_29几个知识蒸馏相关的BERT变体_DistilBERT_30是用于计算学生损失和蒸馏损失之间加权平均的超参数。我们通过最小化上面的损失函数来训练学生网络。

这样,在知识蒸馏中,我们把预训练的网络作为教师网络。我们通过蒸馏训练学生网络获得教师网络的知识。通过最小化上面的损失函数来训练学生网络。

DistilBERT - 蒸馏版的BERT

DistilBERT是一个更小、更快、更便宜、轻量级版本的BERT。它使用了知识蒸馏。DistilBERT的最终思想是,我们采用一个预先训练好的大型BERT模型,通过知识蒸馏将其知识转移到一个小型BERT。

大型预训练BERT称为教师BERT(teacher BERT),而小型BERT称为学生BERT(student BERT)。

DistilBERT比大型BERT快60%,同时小40%。

教师-学生结构

我们先来理解这种教师-学生结构。

教师BERT

教师BERT是一个大型预训练BERT模型。我们使用预训练的BERT-base模型作为教师。

因为BERT是使用掩码语言建模任务进行预训练的,我们可以使用预训练的BERT模型来预测掩码单词。

几个知识蒸馏相关的BERT变体_损失函数_31

上图就是BERT做掩码建模任务的过程,输入一句话,它可以输出被掩码单词属于词表中每个单词的概率分布。该概率分布包含我们需要转移到学生BERT中的暗知识。

学生BERT

与教师BERT不同,学生BERT不是预训练好的。学生BERT需要从教师BERT中学习。相比教师BERT,学生BERT包含的网络层数更少。教师BERT包含110M个参数,而学生BERT只包含66M个参数。

因为学生BERT中包含更少的网络层,与教师BERT(BERT-base)相比,它能训练得更快。

DistilBERT的作者将学生BERT隐藏状态维度设为768,与教师BERT相同。他们发现减少学生BERT的维度对于计算性能没有太大的影响。所以他们关注于减少网络层数。

训练学生BERT

我们可以使用和预训练教师BERT时一样的数据集来训练学生BERT。

这里我们从RoBERTa中借鉴一些策略,比如我们只训练掩码语言建模任务,并在该任务中,我们使用动态掩码(dynamic masking),同时我们也采用较大的批大小。

如下图所示,我们将掩码句子喂给教师BERT和学生BERT,分别得到一个基于词表的概率分布输出。接着,我们计算软目标和软预测之间的蒸馏损失和交叉熵损失。

几个知识蒸馏相关的BERT变体_DistilBERT_32

除了蒸馏损失,我们还计算学生损失,它是掩码语言建模损失,即,硬目标(真实标签)和硬预测(几个知识蒸馏相关的BERT变体_概率分布_06的标准Softmax预测)之间的交叉熵损失,如下图所示:

几个知识蒸馏相关的BERT变体_TinyBERT_34

除此之外,我们还计算余弦嵌入损失(cosine embedding loss)。它基本上是教师和学生BERT所学的表示之间的距离度量。最小化余弦嵌入损失使得学生的表示更加准确,更接近于教师的嵌入。

这样,我们最终的损失函数为下列三个损失之和:

  • 蒸馏损失
  • 掩码语言建模损失(学生损失)
  • 余弦嵌入损失

通过最小化上面三个损失之和来训练我们的学生BERT(DistilBERT)。在训练之后,我们的学生BERT会获得教师BERT的知识。

DistilBERT为我们提供了接近97%的原始BERT-Base模型的准确结果,同时推理速度快了60%。因为DistilBERT更轻量,所以我们可以很容易地将它部署在边缘设备上。

DistilBERT在8块16G V100 GPU上训练了近90个小时。预训练好的DistilBERT已经由🤗释放出来了。

TinyBERT简介

TinyBERT是另一个使用知识蒸馏技术的BERT变体。通过使用DistilBERT,我们学习了如何将知识从教师BERT的输出层转移到学生BERT。但除此之外,我们能转移教师BERT其他层的知识到学生BERT吗?是的!

在TinyBERT中,除了转移教师输出层(预测层)中的知识到学生,我们也可以转移嵌入层和编码器层的知识。

举个例子,假设我们有一个N层编码器层的教师BERT。下图描绘了预先训练的教师BERT模型,在这个模型中,我们输入一个掩码句子,它返回我们词表中所有被掩码的单词的logits。

几个知识蒸馏相关的BERT变体_DistilBERT_35

在DistilBERT中,(1)我们拿教师BERT输出层产生的logits来训练学生BERT以产生同样的logits。除此之外,现在,在TinyBERT中,我们也拿(2)教师BERT产生的隐藏状态和注意力矩阵来训练学生BERT产生同样的隐藏状态和注意力矩阵。接下来,我们也拿(3)教师BERT嵌入层的输出训练学生BERT以产生同样的嵌入。

这样,在TinyBERT中,除了转移教师BERT输出层的知识外,我们还转移中间层的知识。这有助于帮助学生BERT从教师BERT中学到更多信息。比如,注意力权重封装了语言信息。

另外,在TinyBERT中,我们使用了一个两阶段的学习框架,我们在预训练和微调阶段都应用了蒸馏。在下节中,我们将学习这两阶段学习如何帮助我们。现在我们已经对TinyBERT有了一个基本的概念和概述,让我们更详细地探索它。

教师-学生结构

为了理解 TinyBERT 到底是如何工作的,首先让我们理解一下所使用的前提和符号。下图显示了教师和学生 BERT:

几个知识蒸馏相关的BERT变体_编码器_36

首先,我们来看教师BERT,然后再看学生BERT。

理解教师BERT

从上图中我们可以看到,教师BERT包含N个编码器层。首先我们将输入句子喂给嵌入层得到输入嵌入。然后,我们将这些输入嵌入传递给编码器层。这些编码器层基于自注意机制学习输入句子的上下文信息并返回句子表示。接下来,我们把这些句子表示喂给预测层。

预测层基本上就是前馈网络。如果我们做掩码语言建模任务,那么预测层会返回词表中所有单词被预测为掩码单词的logits。

我们使用预训练的BETR-base模型作为教师BERT。我们知道BERT-base模型包含12个编码器和12个注意力头,同时产生的表示的大小(隐藏状态维度几个知识蒸馏相关的BERT变体_编码器_37)是768。教师BERT包含110M个参数。

理解学生BERT

学生BERT的架构和教师BERT相同,但只包含几个知识蒸馏相关的BERT变体_损失函数_38个编码器层。几个知识蒸馏相关的BERT变体_编码器_39是大于几个知识蒸馏相关的BERT变体_损失函数_38的。

如果我们使用4个编码器层的BERT模型作为学习BERT,同时设置表示大小(隐藏状态维度几个知识蒸馏相关的BERT变体_概率分布_41)为312。那么学习BERT只包含14.5M个参数。

那么蒸馏是如何做到的?我们如何将知识从教师BERT迁移到学生BERT。

TinyBERT中的知识蒸馏

在TinyBERT中除了从输出层(预测层)迁移知识,也会从其他层。我们看以下层中蒸馏是如何进行的:

  • Transformer层(编码器层)
  • 嵌入层(输入层)
  • 预测层(输出层)

下图显示了TinyBERT中的教师BERT和学生BERT:

几个知识蒸馏相关的BERT变体_概率分布_42

注意在教师BERT中,索引0代表嵌入层,1代表第一个编码器层,2代表第2个编码器层,N表示第N个编码器层,而N+1表示预测层。

同样地,在学生BERT中,索引0表示嵌入层,1表示第一个编码器层,2表示第2个编码器层,M表示第M个编码器层,M+1表示预测层。

通过如下形式将知识从教师迁移到学生BERT:
几个知识蒸馏相关的BERT变体_概率分布_43
上面的公式表示我们使用映射函数几个知识蒸馏相关的BERT变体_损失函数_44从教师BERT的第n层迁移知识到学生的第m层。

比如:

  • 几个知识蒸馏相关的BERT变体_概率分布_45表示从教师网络的第0层(嵌入层)迁移知识到学生网络的第0层(嵌入层)。
  • 几个知识蒸馏相关的BERT变体_损失函数_46表示从教师的第N+1层(预测层)迁移知识到学生的第M+1层(预测层)。

Transformer层蒸馏

Transformer层基本上是编码器层,我们知道在编码器层,会使用多头注意力计算注意力矩阵。然后编码器层返回隐藏状态表示作为输出。在Transformer蒸馏中,我们除了迁移教师网络的注意力矩阵,同时也迁移隐藏状态。因此,Transformer层蒸馏包含两种蒸馏:

  • 基于注意力的蒸馏
  • 基于隐藏状态的蒸馏
基于注意力的蒸馏

在基于注意力的蒸馏中,我们迁移的是注意力矩阵。但为什么要这么做呢?

注意力矩阵包含很多有用的信息,比如语言语法、指称信息等。这有助于更好地理解语言。

为了进行这种迁移,我们通过最小化学生网络和教师BERT之间的均方误差来训练学生网络。该基于注意力的损失几个知识蒸馏相关的BERT变体_概率分布_47描述如下:
几个知识蒸馏相关的BERT变体_损失函数_48
其中:

  • 几个知识蒸馏相关的BERT变体_DistilBERT_49代表注意力头数
  • 几个知识蒸馏相关的BERT变体_DistilBERT_50代表学生网络的第几个知识蒸馏相关的BERT变体_TinyBERT_51个头的注意力矩阵
  • 几个知识蒸馏相关的BERT变体_TinyBERT_52表示教师网络的第几个知识蒸馏相关的BERT变体_TinyBERT_51个头的注意力矩阵
  • 几个知识蒸馏相关的BERT变体_编码器_54就是均方误差

因此,通过最小化学生网络和教师BERT中注意力矩阵的均方误差来进行基于注意力的蒸馏。值得注意的是我们使用一个未归一化的注意力矩阵,即没有使用Softmax函数。因为为归一化的注意力矩阵在此设定中效果更好且更快收敛。该过程可以如下图示:

几个知识蒸馏相关的BERT变体_TinyBERT_55

基于隐藏状态的蒸馏

我们来看如何进行基于隐藏状态的蒸馏。隐藏状态基本就是编码器的输出,即表示。因此我们迁移的也是教师的隐藏状态。令几个知识蒸馏相关的BERT变体_TinyBERT_56表示学生的隐藏状态,几个知识蒸馏相关的BERT变体_损失函数_57表示教师的隐藏状态。那么我们通过最小化几个知识蒸馏相关的BERT变体_TinyBERT_56几个知识蒸馏相关的BERT变体_损失函数_57之间的均方误差来进行蒸馏:
几个知识蒸馏相关的BERT变体_概率分布_60
等等!学生和教师的隐藏状态维度不一定相同。假设用几个知识蒸馏相关的BERT变体_编码器_37表示教师的隐藏状态维度,几个知识蒸馏相关的BERT变体_概率分布_41表示学生的隐藏状态维度。我们知道教师BERT是BERT-base,而学生BERT是TinyBERT。因此,几个知识蒸馏相关的BERT变体_编码器_37通常是大于几个知识蒸馏相关的BERT变体_概率分布_41的。

所以,为了让它们的维度保持一致,我们需要进行一次线性变换,通过让几个知识蒸馏相关的BERT变体_TinyBERT_56乘上矩阵几个知识蒸馏相关的BERT变体_DistilBERT_66几个知识蒸馏相关的BERT变体_DistilBERT_66的值是可以学习的。那么我们可以重写损失函数为:
几个知识蒸馏相关的BERT变体_DistilBERT_68
从下图我们可以看到是如何将隐藏状态进行迁移的:

几个知识蒸馏相关的BERT变体_TinyBERT_69

嵌入层蒸馏

在嵌入层蒸馏中,我们迁移的肯定就是嵌入层知识。令几个知识蒸馏相关的BERT变体_DistilBERT_70表示学生的嵌入,几个知识蒸馏相关的BERT变体_损失函数_71表示教师的嵌入,那么我们可以通过最小化几个知识蒸馏相关的BERT变体_DistilBERT_70几个知识蒸馏相关的BERT变体_损失函数_71的均方误差来进行嵌入层蒸馏:
几个知识蒸馏相关的BERT变体_损失函数_74
同样,它们的维度很可能不同。因此我们也让学生的嵌入几个知识蒸馏相关的BERT变体_DistilBERT_70乘以矩阵几个知识蒸馏相关的BERT变体_DistilBERT_76进行线性转换。损失函数变成如下:
几个知识蒸馏相关的BERT变体_DistilBERT_77

预测层蒸馏

在预测层蒸馏中,我们迁移的就是最后的输出层知识,即教师BERT产生的logits。这与DistilBERT中的蒸馏损失类似。

通过最小化软目标和软预测之间的交叉熵损失来进行预测层蒸馏。令几个知识蒸馏相关的BERT变体_TinyBERT_78表示学生网络的logits,几个知识蒸馏相关的BERT变体_编码器_79表示教师网络的logits。那么损失函数可以表示如下:
几个知识蒸馏相关的BERT变体_DistilBERT_80
我们已经看到了TinyBERT不同层的蒸馏和不同的损失函数,那么最终的损失函数是怎样的呢?

最终的损失函数

包含所有层的损失函数描述如下:
几个知识蒸馏相关的BERT变体_TinyBERT_81
从中我们可以看到:

  • 几个知识蒸馏相关的BERT变体_编码器_82为0时,说明当前层是嵌入层所以我们使用嵌入层损失
  • 几个知识蒸馏相关的BERT变体_编码器_82在0到几个知识蒸馏相关的BERT变体_TinyBERT_84之间,说明当前层属于编码器层,所以使用隐藏状态损失和注意力层损失之和
  • 几个知识蒸馏相关的BERT变体_编码器_82几个知识蒸馏相关的BERT变体_损失函数_86,说明此时为预测层,所以使用预测层损失

最终的损失函数表示如下:
几个知识蒸馏相关的BERT变体_损失函数_87
几个知识蒸馏相关的BERT变体_概率分布_88表示第几个知识蒸馏相关的BERT变体_DistilBERT_89层的损失函数,几个知识蒸馏相关的BERT变体_DistilBERT_90作为超参数控制第几个知识蒸馏相关的BERT变体_DistilBERT_89层的重要性。

我们通过最小化上面的损失函数来训练学生BERT。

训练学生BERT(TinyBERT)

在TinyBERT中,我们使用如下两阶段学习框架:

  • 通用蒸馏
  • 任务特定蒸馏

该两阶段学习框架能在预训练和微调阶段进行蒸馏。

通用蒸馏

通用蒸馏(General distillation)基于预训练阶段。这里,我们使用大规模的预训练BERT(BERT-Base)作为教师,并迁移其知识到小的学生BERT(TinyBERT)。我们在所有层上都应用蒸馏。

我们知道教师BERT-Base模型是在通用数据集(维基百科和多伦多BookCorpus数据集)上预训练。在应用蒸馏时,我们也使用同样地数据集。

蒸馏后,我们学生BERT会包含来自教师的知识,然后我们称预训练的学生BERT为通用TinyBERT。

在通用蒸馏后,我们得到一个通用TinyBERT,就是一个预训练的学生BERT。现在我们可以为下游任务来微调这个通用TinyBERT了。

任务特定蒸馏

任务特定蒸馏基本在微调阶段。不像DistilBERT,除了在预训练阶段应用蒸馏,在TinyBERT中,我们也在微调阶段应用蒸馏。

首先,我们拿一个预训练的BERT-Base模型然后为特定任务微调它,然后我们把该微调后的BERT-Base作为教师。通用TinyBERT为学生。我们进行蒸馏从微调的BERT-Base迁移知识到通用TinyBERT。在蒸馏后,我们的通用TinyBERT会包含来自教师的特定任务的知识,现在我们可以称该通用TinyBERT为一个微调的TinyBERT。

下面的表格帮我们排除一些关于通用蒸馏和任务特定蒸馏的困扰:

通用蒸馏(预训练阶段)

特定任务蒸馏(微调阶段)

教师

预训练的BERT-Base

微调的BERT-Base

学生

TinyBERT(小的BERT)

通用TinyBERT(预训练的TinyBERT)

结果

蒸馏完成后,学生BERT会包含从教师学来的知识,我们称该预训练的学生BERT为通用TinyBERT

蒸馏完成后,通用TinyBERT会包含来自教师任务相关的知识,我们称该通用TinyBERT为微调的TinyBERT,因为它已经为特定任务微调过了

既然需要在微调阶段进行蒸馏,我们可能需要更多与下游任务相关的数据。所以我们可以使用数据增强方法来获得增强的数据集。我们然后基于该数据集来微调通用TinyBERT。

数据增强方法

假设有一个句子:​​Paris is a beautiful city​​​。首先用BERT分词器(tokenizer)进行分词然后用列表​​X​​​来保持结果:​​X = [Paris, is, a, beautiful, city]​​。

我们再拷贝一份,得到​​X_masked=[Paris, is, a, beautiful, city]​​。

对于列表中的每个单词索引​​i​​,我们进行以下步骤:

  1. 检查​​X[i]​​​是否为single-piece单词,若是,我们用​​[MASK]​​​标记来掩码它。然后,我们使用BERT-Base模型来预测该掩码的单词。我们拿预测概率最大的K个单词作为候选单词列表​​candidates​​。假设K=5。
  2. 如果​​X[i]​​​不是single-piece单词,那么就不掩码。而是,我们利用glove嵌入查询与之最相似的K个单词存入候选单词列表。然后,我们从0到1的均匀分布中采用一个值几个知识蒸馏相关的BERT变体_DistilBERT_92。我们引入一个新变量称为阈值,几个知识蒸馏相关的BERT变体_概率分布_93。假设几个知识蒸馏相关的BERT变体_TinyBERT_94。然后我们进行下一步。
  3. 如果几个知识蒸馏相关的BERT变体_DistilBERT_92小于等于几个知识蒸馏相关的BERT变体_概率分布_93,那么从候选单词列表中随机抽取一个单词替换​​​X_masked[i]​​。
  4. 否则,什么都不做。

我们对句子中的每个单词 ​​i​​​ 执行前面的步骤,并将更新后的​​X_masked​​​列表添加到一个名为 ​​data_aug​​​ 的列表中。我们对数据集中的每个句子重复这种数据增强方法 几个知识蒸馏相关的BERT变体_编码器_39 次。假设 几个知识蒸馏相关的BERT变体_编码器_98,那么对于每个句子,我们执行数据增强步骤并获得10个新句子。

现在我们已经了解了数据增强方法是如何工作的,我们回到上面的例子:

X = [Paris, is, a, beautiful, city]

我们拷贝​​X​​​到新列表​​X_masked​​中:

X_masked = [Paris, is, a, beautiful, city]

现在,对于列表中的每个单词​​i​​,我们进行下面步骤:

如果​​i=0​​​,我们有​​X[0]=Paris​​​,我们看是否为single-piece单词,答案是肯定的,我们用​​[MASK]​​标记替换该单词:

X_masked = [[MASK], is, a, beautiful, city]

此时,我们使用BERT-Base模型预测最有可能是原词的K个单词,然后存入候选集列表。这里假设几个知识蒸馏相关的BERT变体_TinyBERT_99;假设我们得到的候选集列表如下:

candidates = [ Paris, it, that] 

接着,抽一个随机数几个知识蒸馏相关的BERT变体_编码器_100,假设几个知识蒸馏相关的BERT变体_编码器_101,因此我们从候选集中随机抽取单词替换它,假设我们抽取的单词为​​​it​​​,那么​​X_masked​​列表就变成了:

X_masked = [it, is, a, beautiful, city]

现在,我们可以把​​X_masked​​​加到​​data_aug​​列表中了。

这样,我们重复上面的步骤几个知识蒸馏相关的BERT变体_编码器_39次就可以得到几个知识蒸馏相关的BERT变体_编码器_39个句子加到增强数据集中。

有了这样的增强数据集,我们就可以对通用TinyBERT进行微调了。

简之,在 TinyBERT 中,我们在所有的层进行蒸馏,我们也在预训练和微调阶段都应用蒸馏。

TinyBERT的效率比BERT-Base模型高96%,小了7.5倍,推理速度快9.4倍。我们可以在​​这里​​下载经过预训练的TinyBERT。

我们已经知道如何从一个大的预训练的BERT蒸馏知识到小的BERT中,那么我们能否迁移知识到一个简单神经网络中呢?

从BERT到神经网络的知识迁移

Distilling Task-Specific Knowledge from BERT into Simple Neural Networks 论文阐述了如何运用知识蒸馏迁移BERT知识到简单神经网络中。

教师-学生架构

首先看一下这里教师和学生是怎样的架构。

教师BERT

可以使用预训练的BERT作为教师。这里,我们使用BERT-large作为教师。

首先我们基于相关任务微调该教师模型。

假设我们的下游任务为情绪分析任务。那么我们拿一个预训练的Bert-large模型,然后再情绪分析任务数据集上微调它,把微调后的Bert-large模型就可以作为教师了。

学生网络

在此任务中,学生网络就是一个简单的双向LSTM。我们来看一下单句分类任务的学生网络架构。

假设我们进行情绪分析任务,有一个句子:​​I Love Paris​​。首先我们得到该句子的嵌入向量,然后喂给Bi-LSTM,得到了双向的隐藏状态。

接着,我们将前向和反向的隐藏状态喂给带ReLU激活的全连接层,它输出logits,然后喂给softmax函数得到该句子属于正面还是负面的概率:

几个知识蒸馏相关的BERT变体_DistilBERT_104

现在,让我们来看看学解决句子匹配任务的学生网络架构。假设我们想了解给定的两句话是否相似。在这种情况下,我们的学生网络是孪生BiLSTM。

首先,我们得到sentence1和sentence2嵌入向量,然后分别喂给Bi-LSTM1和Bi-LSTM2,得到双向的隐藏状态。假设几个知识蒸馏相关的BERT变体_编码器_105几个知识蒸馏相关的BERT变体_DistilBERT_106​分别表示从Bi-LSTM1和Bi-LSTM2中获得的双向隐藏状态。然后我们使用一个如下拼接-比较操作:
几个知识蒸馏相关的BERT变体_编码器_107
其中几个知识蒸馏相关的BERT变体_TinyBERT_108表示元素级相乘。

下面,我们将拼接的结果喂给带有ReLU激活的全连接层,得到logits,然后喂给softmax函数得到属于相似还是不相似的概率:

几个知识蒸馏相关的BERT变体_编码器_109

我们已经看了两个学生网络架构的例子,下面我们看看如何训练学生网络。

训练学生网络

如上文所述,我们拿到一个预训练的BERT模型后,我们先在相关任务数据集上微调它,然后再把它作为教师模型。因此,教师模型是预训练的、微调的BERT模型。我们这里学生模型为BiLSTM。

那么损失函数是什么?它是学生损失和蒸馏损失的加权和:
几个知识蒸馏相关的BERT变体_DistilBERT_110
如果我们设几个知识蒸馏相关的BERT变体_编码器_111,那么变成了:
几个知识蒸馏相关的BERT变体_概率分布_112
我们知道蒸馏损失一般是软目标和软预测之间的交叉熵损失。但是这里,我们使用均方误差最为蒸馏损失,因为在这种设定中它比交叉熵损失表现要好:
几个知识蒸馏相关的BERT变体_概率分布_113
几个知识蒸馏相关的BERT变体_编码器_79表示教师网络的logits,几个知识蒸馏相关的BERT变体_TinyBERT_78表示学生网络的logits。

学生损失就是一般的在硬目标和硬预测之间的交叉熵损失了。

我们通过最小化损失函数几个知识蒸馏相关的BERT变体_DistilBERT_116来训练学生网络。为了能从教师网络蒸馏知识到学生网络,我们需要一个更大的数据集。

这里,我们使用一种任务无关(task-agnostic)的数据增强方法。

数据增强方法

我们使用下面的方法来进行任务无关的数据增强:

  • 掩码
  • 基于词性的单词替换
  • n-gram采样

掩码方法

在掩码方法中,具有概率几个知识蒸馏相关的BERT变体_TinyBERT_117,我们随机用​​​[MASK]​​​标记遮盖句子中的一个单词,然后就创建了一个带有掩码标记的新句子。例如,假设我们正在进行情感分析任务,并在我们的数据集中,我们有句子​​I was listening to music​​​。现在,基于概率几个知识蒸馏相关的BERT变体_TinyBERT_117,我们随机遮盖一个单词。假设我们掩盖了music这个词,然后我们有一个新的句子:​​​I was listening to [MASK]​​。

但这样有用吗?带有​​[MASK]​​​标记的句子,我们的模型无法产生确信的logits因为​​[MASK]​​是一个未知标记。我们的模型基于被遮盖的句子产生的是一个相对欠确信的logits。这有助于模型理解每个单词对于句子所属标签的重要度。

基于词性的单词替换方法

在此方法中,有一个概率几个知识蒸馏相关的BERT变体_TinyBERT_119,我们用同词性单词替换某个单词。

比如,考虑句子​​Where did you go?​​​,我们知道这里​​did​​​是动词。现在我们可以用另一个动词来替换它。所以现在句子变成了​​Where do you go?​​,这样我们就得到了一个新句子。

n-gram采样方法

在此方法中,也有一个概率几个知识蒸馏相关的BERT变体_损失函数_120。我们只是从句子中随机生成一个n-gram,这里的几个知识蒸馏相关的BERT变体_TinyBERT_121从1到5之间随机选择。

我们已经了解了三种数据增强方法。但是如何应用它们呢?

数据增强过程

假设我们有一个句子​​Paris is a beautiful city​​​。令几个知识蒸馏相关的BERT变体_损失函数_122代表句子中的单词。现在对于句子中的每个单词几个知识蒸馏相关的BERT变体_编码器_123,我们创建一个叫做几个知识蒸馏相关的BERT变体_损失函数_124的变量,这里几个知识蒸馏相关的BERT变体_损失函数_124从0到1之间的均匀分布中随机生成。基于几个知识蒸馏相关的BERT变体_损失函数_124,我们进行如下操作:

  • 如果几个知识蒸馏相关的BERT变体_TinyBERT_127,那么我们遮盖单词几个知识蒸馏相关的BERT变体_损失函数_128
  • 如果几个知识蒸馏相关的BERT变体_TinyBERT_129,那么我们应用基于词性的替换方法

注意这两种方法是排他的,我们不能同时使用这两种方法。

这样,我们得到了一个修改了的合成句子。现在, 有概率几个知识蒸馏相关的BERT变体_损失函数_120,我们应用n-gram采样到这个合成句子上来获得最终的合成句子,然后加它到​​​data_aug​​中。

对于每个句子,我们执行前面的步骤几个知识蒸馏相关的BERT变体_编码器_39次数并获得几个知识蒸馏相关的BERT变体_编码器_39个新的合成句。好吧,如果我们处理的是句子对,那么我们如何获得合成句子对?在这种情况下,我们可以创建具有许多组合的合成句对。其中一些如下:

  • 我们可以只从第一个句子中创建合成句子,保持第二个句子不变
  • 保持第一句子不变,只从第二个句子中创建合成句子
  • 同时从这两个句子中创建合成句子

这样我们得到了更多的数据。

References

  1. Distilling the Knowledge in a Neural Network
  2. DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter
  3. TinyBERT: Distilling BERT for Natural Language Understanding
  4. Distilling Task-Specific Knowledge from BERT into Simple Neural Networks
  5. Getting Started with Google BERT