CRINEG Loss:学习什么语言不建模 《The CRINGE Loss:Learning what language not to model》

论文地址:https://arxiv.org/pdf/2211.05826.pdf

相关博客
【自然语言处理】【文本生成】CRINEG Loss:学习什么语言不建模【自然语言处理】【文本生成】使用Transformers中的BART进行文本摘要【自然语言处理】【文本生成】Transformers中使用约束Beam Search指导文本生成【自然语言处理】【文本生成】Transformers中用于语言生成的不同解码方法【自然语言处理】【文本生成】BART:用于自然语言生成、翻译和理解的降噪Sequence-to-Sequence预训练【自然语言处理】【文本生成】UniLM:用于自然语言理解和生成的统一语言模型预训练【自然语言处理】【多模态】OFA:通过简单的sequence-to-sequence学习框架统一架构、任务和模态

一、简介

自然语言模型的训练_自然语言模型的训练

近些年来,随着自然语言模型的训练_损失函数_02的崛起,语言模型和对话代理变的越来越强大,以至于可以进行各种交互。然而,采用标准的语言模型训练、缩放模型尺寸和大量的训练数据仍然有大量的问题无法解决。特别地,模型仍然会遭受毒性和偏见的影响、缺乏连贯性以及不能解决用户的意图。相反,越来越多的工作正在研究如何将这些信息融入训练目标中,从而超越标准的语义建模目标函数。

在本文中,研究一种设定:训练集中包含一组正样本(语言模型训练中经常使用)和一组负样本(模型不应该生成的)。本文提出了新的学习方法,自然语言模型的训练_自然语言处理_03损失函数,作为一种在这种数据上训练且概念简单的方法,其易于实现并比现有的方法性能好。正样本使用常用的最大似然方法。负样本的训练则是受Jiang et al.启发并进行泛化的"监督对比学习目标函数",并且不需要任何架构的改变,仅对损失函数进行最小的改变。上图1展示了自然语言模型的训练_损失函数_04损失函数在单个负序列上的概念图。由于该损失函数允许在负样本上高效的训练,通过在模型自己生成样本分类上训练来迭代的改善生成结果。

本文在三个具有正、负训练样本的任务上展示了该方法的效果。三个任务分别是:完全生成任务自然语言模型的训练_文本生成_05、矛盾避免任务自然语言模型的训练_自然语言处理_06和开发域任务导向对话任务自然语言模型的训练_自然语言模型的训练_07。此外,本文与广泛的baseline进行了比较。一般来说,自然语言模型的训练_损失函数_04损失函数单次迭代就超过了大多数的baselines。在本文提出的迭代方式上应用自然语言模型的训练_损失函数_04,可能看到额外的改善,并在所有的三个任务上带来最优的效果。

二、相关工作

1. 使用负样本训练

使用负样本来训练语言模型有几种实现方式。Welleck et al.提出了unlikelihood训练,其在训练目标中添加了一个额外的项,降低了负token相较于其他token的概率。他们表明这是检索语言模型中重复生成的有效方法。Jiang et al.也提出了对比学习目标函数来缓解文本退化。他们认为将前面自然语言模型的训练_损失函数_10个上下文token与正标签对比,有助于避免不期望的token,相较于unlikelihood训练。这个方法在减少正样本生成时的重复上很有效,但是其不能在任意的负样本上工作,因为其对任意给定的负token需要正确的正token。本文当前的工作就是受该方法的启发,并将其泛化到负样本训练上。

一个完全不同但流行的从负样本中学习的方法是训练一个分类器或者重排模型。本文中,并不是去更新模型参数,而是训练一个额外的模型来评估生成。通过使用语言模型来生成多个候选,重排器去决定那个候选分数最高。Nie et al.训练重排器以避免生成矛盾的问题。Nakano et al.发现在某些场景中重排可能优于强化学习。

模型指导方法,像自然语言模型的训练_深度学习_11自然语言模型的训练_自然语言处理_12自然语言模型的训练_自然语言模型的训练_13自然语言模型的训练_深度学习_14,在解码过程中对每个token都使用该模型,而不是在最终的生成上使用额外的模型。因此,第二个模型期望的属性被用于指导语言模型的生成。近期的自然语言模型的训练_损失函数_15模型不使用第二个模型,在相同的架构上共享语言模型和分类指导头。其在多个任务上都工作的很好,但是缺点是需要架构变化并且不能够轻易的应用在现有的模型上。

2. 迭代训练语言模型

Unlikelihood训练通过在模型自己的生成样本上训练来迭代的改善重复生成的问题。在人类偏好上训练语言模型已经成功应用在摘要、对话等任务上。Lu et al.使用生成的样本训练一个消除不期望行为的语言模型。他们标记和量化模型生成的样本,并通过在序列前添加奖励token来执行条件训练。自然语言模型的训练_损失函数_16模型来自人类反馈的强化学习来将语言模型对齐至instructions

三、自然语言模型的训练_损失函数_04损失函数

自然语言模型的训练_损失函数_04损失函数是一种在正、负序列上训练模型的方法。对于正样本,利用常见的最大似然方法。负样本则通过语言模型预测的top token之一和序列的每个token进行对比。上图展示了如何在负样本序列上进行训练。

更正式的来说,最终的优化目标由两项组成:正样本序列的交叉熵和负样本序列的自然语言模型的训练_损失函数_04。前者就是标准的使用方法,对于来自正样本序列自然语言模型的训练_深度学习_20的token 自然语言模型的训练_深度学习_21
自然语言模型的训练_深度学习_22
其中自然语言模型的训练_自然语言模型的训练_23表示模型为token 自然语言模型的训练_自然语言模型的训练_24输出的logit。对于负样本,将序列中的每个token与正token进行对比。在训练数据中通常会提供负样本序列,但是不知道对于给定序列中的负token,其可替换的正token应该是哪个。本文的方法是从模型当前的top-k预测中采样(若负token在top-k中,则忽略负token)。这里,通过模型预测的top-k logitssoftmax构造的类别分布进行采样。选择对比loss为
自然语言模型的训练_文本生成_25
其中,自然语言模型的训练_自然语言模型的训练_26表示提供的负样本标注token的logit分数,自然语言模型的训练_文本生成_27从模型top-k预测中采样的正token。该方法背后的直觉是使用模型作为近似的数据库来提供可选的正token。或者,从另一个角度看,确保已知的负token排名低于模型认为排名top-k的token。算法1提供了单次预测的伪代码。

算法1:对于单个负token的自然语言模型的训练_自然语言处理_28损失函数

需要:一个token索引序列自然语言模型的训练_自然语言模型的训练_29、一个标注的负token 自然语言模型的训练_自然语言处理_30、一个生成模型自然语言模型的训练_文本生成_31、一个标量自然语言模型的训练_文本生成_32

  • 将序列输入至模型并得到每个token的分数,即自然语言模型的训练_深度学习_33
  • 获得索引不为自然语言模型的训练_损失函数_34的top-k token预测分数,即自然语言模型的训练_自然语言处理_35
  • 从上面的集合中采样正token,即自然语言模型的训练_自然语言模型的训练_36
  • 拼接正、负token分数,并按照自然语言模型的训练_深度学习_37计算损失函数,即自然语言模型的训练_文本生成_38

为了同时在正负样本上进行训练,采用两个损失函数的加权求和
自然语言模型的训练_损失函数_39
其中,自然语言模型的训练_自然语言模型的训练_40是控制负样本影响的超参数。自然语言模型的训练_损失函数_04损失函数很容易实现,并仅需要在损失函数上进行简单的修改。

自然语言模型的训练_损失函数_04迭代训练

提出的自然语言模型的训练_损失函数_04损失函数允许在正样本和负样本上高效的训练模型。这使得通过学习自己生成的分类来迭代改善模型成为可能。这里遵循简单的策略,完成模型训练,标记模型在训练集上的生成,然后使用增强的训练集重复这个过程。模型生成的标签可以在整个循环中人工评估来获得,本文提出在原始的正、负样本上训练分类器,并使用自动标记的样本,类似于强化学习中的奖励模型。使用以下流程:

  1. 在数据集自然语言模型的训练_自然语言处理_44上微调模型;
  2. 基于原始的训练样本上下文,使用模型来生成额外的序列;
  3. 标注模型生成的正样本或者负样本,将其作为额外样本添加至数据集自然语言模型的训练_自然语言处理_44
  4. 使用更新后的数据重复上面的过程。

在本文的实验中,发现即使仅迭代两轮也能够带来显著的效果改善。伪代码如算法2所示。

算法2自然语言模型的训练_自然语言处理_28训练循环

需求:一个具有正、负序列的数据集自然语言模型的训练_自然语言处理_47,一个生成模型自然语言模型的训练_文本生成_31,一个为文本序列分配二元标签的函数自然语言模型的训练_深度学习_49(一个人或者一个在自然语言模型的训练_自然语言处理_47上训练的分类器)

  • 使用原始数据集初始化自然语言模型的训练_自然语言处理_51,即自然语言模型的训练_自然语言处理_52
  • for Iterations=1,N do
  • 自然语言模型的训练_自然语言模型的训练_53损失函数在数据集自然语言模型的训练_自然语言处理_51上训练模型直至收敛,即自然语言模型的训练_损失函数_55
  • 自然语言模型的训练_损失函数_56的提示(prompts)生成序列,即自然语言模型的训练_自然语言模型的训练_57
  • 自然语言模型的训练_文本生成_58
  • 自然语言模型的训练_文本生成_59