BurstAttention: An Efficient Distributed Attention Framework for Extremely Long Sequences

相关链接:arxiv github office 关键字:BurstAttentionDistributed ComputingLong SequencesTransformerLarge Language Models (LLMs)

摘要

有效的注意力模块在Transformer基础的大型语言模型(LLMs)的成功中发挥了关键作用,但这些注意力模块的二次时间和内存复杂性也在处理长序列时构成了挑战。一种潜在的解决长序列问题的方法是使用分布式集群在多个设备(例如GPU)上并行计算注意力模块。然而,采用分布式方法不可避免地引入额外的内存开销来存储局部注意力结果,并且产生额外的通信成本来将局部结果汇总成全局结果。在本文中,我们提出了一个名为“BurstAttention”的分布式注意力框架,以优化全局集群和本地设备级别的内存访问和通信操作。在我们的实验中,我们将BurstAttention与其他竞争的分布式注意力解决方案进行了比较,这些解决方案用于长序列处理。在不同长度设置下的实验结果表明,与这些竞争基线相比,BurstAttention在处理长序列方面提供了显著优势,减少了40%的通信开销,并在8个A100上以32k序列长度训练期间实现了2倍的加速。

核心方法

image.png

BurstAttention的核心方法包含:

  1. 全局集群优化:针对分布式计算过程中的全局通信开销进行优化,减少数据传输量。
  2. 本地设备优化:在单个设备内部优化注意力模块的计算和数据组织方式,以利用局部性原理减少内存访问的时间和能耗。
  3. 注意力计算并行化:采用并行策略对注意力计算进行拆分,增加并行度,实现性能提升。

实验说明

训练加速性能比较

序列长度 BurstAttention 竞争基线 加速比 通信开销降低
32K ✔️ 2倍 40%

实验表明,在处理32k长度的长序列时,BurstAttention相比于其他竞争的分布式注意力解决方案,在训练加速和通信开销两个方面都有显著的优势。

结论

BurstAttention作为一种高效的分布式注意力框架,能够显著提高处理极长序列的效率。它通过优化全局集群和本地设备上的内存访问和通信操作,降低了通信成本,并在8 X A100GPU上处理32K长度序列时实现了显著的训练加速。这表明BurstAttention是一个解决长序列问题的有力工具,有望为处理长输入序列的Transformer模型和其他相关模型带来实质性的性能提升。