近日,上海高级算法研究院等机构的研究人员发表了一篇题为"Memory³: Language Modeling with Explicit Memory"的论文,提出了一种新颖的显式记忆机制,用于提升大语言模型(LLM)的性能和效率。本文将对这篇论文的核心思想和技术细节进行详细解读。
1. 研究背景与动机
近年来,大语言模型(如GPT-3、LLaMA等)在各种自然语言处理任务上取得了惊人的成果。然而,这些模型也面临着一些挑战:
- 训练和推理成本高昂:需要大量计算资源来训练和部署这些模型。
- 知识利用效率低:每次生成token时都会激活所有参数,造成大量冗余计算。
- 长文本处理能力有限:由于self-attention的计算复杂度,难以处理超长文本。
为了解决这些问题,研究人员提出了Memory³模型,其核心思想是为LLM引入显式记忆机制,将部分知识从模型参数中外部化,从而实现更高效的知识存储和访问。
2. Memory³模型的核心设计
2.1 记忆层次结构
Memory³引入了三种记忆形式,形成了一个记忆层次结构:
- 隐式记忆(Implicit Memory):模型参数,用于存储抽象知识和高频使用的具体知识。
- 显式记忆(Explicit Memory):本文提出的新型记忆形式,介于模型参数和外部文本之间。
- 外部信息(External Information):用于检索增强生成(RAG)的原始文本。
这种层次结构类似于人脑的记忆机制,允许模型根据知识的使用频率和抽象程度,将其存储在最合适的记忆形式中。
2.2 显式记忆机制
显式记忆的核心思想是:
- 写入:在推理之前,将参考文本转换为显式记忆,存储在硬盘上。
- 读取:推理时,检索相关的显式记忆,并通过self-attention机制整合到计算中。
具体实现:
- 显式记忆是从参考文本编码得到的注意力key-value向量的子集。
- 使用稀疏化技术大幅压缩存储空间(压缩率可达160倍)。
- 推理时,每生成64个token就检索5个新的显式记忆。
2.3 模型架构
Memory³模型基于Transformer架构,主要修改包括:
- 前半部分的attention层被设置为"记忆层",可以生成和访问显式记忆。
- 使用分组查询注意力(GQA)来减少key-value头的数量。
- 引入并行位置编码,避免"中间丢失"现象。
2.4 两阶段预训练
Memory³采用了一种新颖的两阶段预训练方案:
- 预热阶段:不使用显式记忆,类似传统LLM预训练。
- 持续训练阶段:引入显式记忆,模型学习如何生成和利用这些记忆。
这种设计的动机是:模型需要先建立基本的语言理解能力,才能有效地利用显式记忆。
3. 技术细节
3.1 显式记忆的稀疏化
为了减少存储空间和计算开销,Memory³对显式记忾进行了多维度的稀疏化:
- 层:只在前半部分的attention层使用记忆机制。
- 头:通过GQA减少key-value头的数量(8个vs. 40个)。
- token:每个key-value头只选择8个最重要的token(从128个中选择)。
- 向量维度:可选地使用向量量化器进行压缩(压缩率约11.4倍)。
通过这些技术,将原本7.17PB的显式记忆库压缩到了45.9TB或4.02TB。
3.2 模型规模设计
研究者提出了一个有趣的观点:模型参数主要用于存储抽象知识,而具体知识可以外部化为显式记忆。基于这一思想,他们设计了一个独特的模型规模选择方法,最终确定的模型结构为:
- 44个Transformer块
- 40个查询头,8个key-value头
- 头维度80,隐藏维度3200
- 词表大小60416
总的非嵌入参数量为2.4B。
3.3 训练数据处理
研究者采用了多步骤的数据处理流程:
- 收集:来自多个公开数据集,包括英文和中文文本。
- 去重:使用MinHash算法。
- 规则过滤:设计了多个启发式规则来移除低质量文本。
- 模型打分:使用微调过的BERT模型为文本质量打分。
最终的预训练数据集包含约4万亿个token。
3.4 知识库构建
为了支持显式记忾机制,研究者构建了一个包含1.1亿个参考文本块的知识库。每个文本块长度不超过128个token,主要来源于高质量的文本数据,如维基百科、新闻和学术书籍等。
4. 实验结果
研究者进行了全面的实验评估,主要结果包括:
- 一般能力:在多个标准基准测试(如MMLU、HellaSwag等)上,Memory³-2B模型的表现超过了许多参数量更大的模型,如Llama2-13B。
- 对话能力:在MT-Bench上的得分超过了Vicuna-7B、Falcon-40B-Instruct等模型。
- 专业任务:在法律和医学领域的任务上,Memory³-2B的表现优于多个RAG模型。
- 事实性与幻觉:在TruthfulQA等评估幻觉的数据集上,Memory³-2B取得了最好的成绩。
- 推理速度:在使用显式记忆的情况下,Memory³-2B的推理速度仍快于大多数RAG模型。
5. 创新点与贡献
- 提出了显式记忆这一新型记忆形式,为LLM构建了一个类人脑的记忆层次结构。
- 设计了高效的显式记忆稀疏化方法,大幅减少了存储和计算开销。
- 提出了基于知识分布的模型规模设计方法,实现了更高效的参数利用。
- 开发了两阶段预训练方案,有效地将显式记忆整合到模型训练中。
- 在多个任务上展示了显式记忆的优势,特别是在提高事实性和减少幻觉方面。
6. 局限性与未来工作
尽管Memory³取得了令人印象深刻的成果,研究者也指出了一些局限性和未来的研究方向:
- 预训练过程中出现了不可修复的损失发散问题,导致训练提前终止。未来需要进一步优化训练稳定性。
- 当前的显式记忆更接近人类的情景记忆,未来可以探索如何构建更抽象的语义记忆。
- 可以进一步优化显式记忆的检索机制,例如使用模型内部的隐藏特征来进行检索。
- 探索如何在模型参数更新(如微调)后保持显式记忆的可读性。
- 研究如何更有效地学习抽象知识,进一步提高训练效率。
7. 总结
Memory³为大语言模型引入了一种新颖的显式记忆机制,通过构建类人脑的记忆层次结构,实现了更高效的知识存储和访问。这一方法不仅提高了模型性能,还显著降低了计算成本,为未来大语言模型的发展提供了新的思路。尽管仍有一些待解决的问题,但Memory³的成功无疑为LLM领域注入了新的活力,值得学术界和工业界的进一步关注与探索。