这是一个大佬手撕Flash Attention的呀~~

本文重心在于介绍 Flash Attention 的算法思想及其实现方式,并对提高 Transformer 运算效率的相关工作做简要介绍。

前言

自 2022 年 11 月 OpenAI 发布 ChatGPT 以来,这一年多来大语言模型 (Large Language Model, LLM) 的发展十分迅速,国内外众多厂商纷纷加入“百模大战”。但是,由于大语言模型的参数量非常巨大(通常为十亿、百亿甚至千亿量级),加之训练语料很庞大,模型的训练成本十分高昂。

当前,Transformer 已经成为了大语言模型的默认网络结构,为了降低大语言模型的训练成本,一些工作尝试对 Transformer 的计算成本进行优化,比如降低注意力运算的时间成本或者显存占用等。

本文介绍 Flash Attention,一种优化的注意力算法。Flash Attention 论文链接如下:

https://arxiv.org/pdf/2205.14135

本文从注意力机制出发,分析原始的注意力机制为什么需要优化,并简要介绍前人在优化注意力机制方面做了哪些工作,再介绍 Flash Attention,并基于 Numpy 手把手实现 Flash Attention 的主体算法

本文所有代码已开源:

https//gist.github.com/xiabingquan/a4a9a743f97aadd531ed6218be20afd2

如有写得不对或者不清楚的地方还请不吝赐教,在此谢过!

感谢以下用户的指正:@INTuition 

由于博主缺少 MLSys 相关背景,因此本文重心在于介绍 Flash Attention 的算法思想及其实现方式,并对提高 Transformer 运算效率的相关工作做简要介绍,而 IO 复杂度分析等内容则略过。本文末尾附了一些其他博主写的个人觉得比较好的讲解Flash Attention的文章链接,读者阅读完本文之后可以作为补充阅读。

本文共约 1.4w 字,阅读约需要 30 分钟。

本文的组织结构如下(PC 端点击左侧目录可直接跳转):

  • Transformer 简介:简单介绍 Transformer 的基础知识,以介绍 self-attention 为主;
  • Attention 为什么慢:介绍 Transformer 中的 attention 的速度瓶颈;
  • 如何提高 Transformer 的计算效率:简单介绍提高 Transformer 计算效率的相关工作;
  • Flash Attention:进入正题,详细介绍 Flash Attention 的算法思想和细节;
  • 实验效果:简单介绍 Flash Attention 的实际效果;
  • 总结:本文总结。

Transformer 简介

本节介绍 Transformer 的基础知识。由于除注意力机制以外的其他内容和本文内容无关,因此本节主要介绍注意力机制。Transformer的详细解释及其代码实现可参考这篇文章(https://zhuanlan.zhihu.com/p/648127076)。

Transformer 是深度学习领域一种非常流行的模型结构,由 Ashish Vaswani 等人于2017年提出[1],主要用于序列到序列 (sequence-to-sequence)[2] 相关任务,如机器翻译、语音识别等。Transformer 主要基于注意力机制搭建,不使用循环神经网络 (RNN) 和卷积神经网络 (CNN) 等结构。

Transformer 包括编码器和解码器两部分,由于当前主流的大语言模型几乎都基于只含解码器而不含编码器的仅解码器 (decoder-only) 模型,因此此处主要介绍仅解码器模型中的 Transformer 解码器,该解码器通过多个解码器层堆叠而成,每层包含自注意力层、前馈神经网络、层归一化、残差连接等组件。

其中,自注意力层接收一个特征序列作为输入,并将该序列输入作为查询 (Query, 下文简称 Q)、键 (Key, 下文简称 K) 和值 (Value, 下文简称 V),使用缩放点积 (Scaled-dot Production) 来计算 Q 和 K 之间的注意力权重矩阵,然后再通过注意力权重和 V 来计算自注意力层的输出。

自注意力层的主体代码如下。简单起见,此处省略自注意力层中的 Q、K、V 各自的线性映射、Dropout、多头注意力、掩码机制等内容。

import unittest

import torch
import torch.nn as nn
from torch.nn import functional as F

class StandardAttention(object):
    def __init__(self) -> None:
        """
        Attention module implemented in Numpy.

        Formula:
            P = QK^T
            S = softmax(P / sqrt(d_k))
            O = SV

        Reference:
            <<Attention Is All You Need>>
        URL:
            https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf

        """
        pass

    def _validity_check(self, q: np.ndarray, k: np.ndarray, v: np.ndarray) -> None:
        assert q.ndim == 3, "q should be a 3D tensor"      # [batch_size, seq_len, hidden_size]
        assert k.ndim == 3, "k should be a 3D tensor"
        assert v.ndim == 3, "v should be a 3D tensor"
        assert q.shape[0] == k.shape[0], "batch_size of q and k should be the same"
        assert q.shape[2] == k.shape[2], "hidden_size of q and k should be the same"
        assert q.shape[2] == v.shape[2], "hidden_size of q and v should be the same"

    def forward(self, q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
        self._validity_check(q, k, v)
        batch_size, q_len, hidden_size = q.shape
        denom = np.sqrt(hidden_size)
        attn = np.matmul(q, k.transpose(0, 2, 1))       # [batch_size, q_len, k_len]
        attn = np.exp((attn - attn.max(axis=-1, keepdims=True)) / denom)
        attn = attn / attn.sum(axis=-1, keepdims=True)
        out = np.matmul(attn, v)                        # [batch_size, q_len, hidden_size]
        return out

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


def self_attention(x):
    return StandardAttention()(x, x, x)


class TestSelfAttention(unittest.TestCase):
    def test_forward(self):
        input_dim = 10
        batch_size = 32
        seq_len = 20

        x = torch.randn(batch_size, seq_len, input_dim)
        output = self_attention(x)
        expected = F.scaled_dot_product_attention(x, x, x)

        self.assertTrue(torch.allclose(output, expected, atol=1e-6, rtol=1e-6))

if __name__ == '__main__':
    unittest.main()

我们可以通过 PyTorch 库所给的 F.scaled_dot_production 函数来验证 self_attention 函数的正确性。单元测试的结果此处略过。

Attention 为什么慢?

上一节提到,Transformer 的主要组成部分为 attention,因此优化 Transformer 重点在于优化 attention 的计算。那么,attention 为什么需要优化呢?或者说,注意力机制为什么慢?

此处的“快慢”是相对而言的。严格意义上来说,相比于传统的 RNN,Transformer 中的 attention 可以并行地处理序列所有位置的信息(RNN 只能串行处理),因此计算效率并不低,但是仍然有可以进一步改进的空间。

众所周知,对于科学计算程序而言,按照算数运算和内存读取各自所花的时间比例,科学计算通常分为计算密集型 (compute-bound) 和内存密集型 (memory-bound) 两类。其中,计算密集型运算的时间瓶颈主要在于算数计算,比如大型矩阵的相乘等,而内存密集型运算的时间瓶颈主要在于内存的读写时间,比如批归一化、层归一化等等

我们可以从计算和内存两方面来分析“attention为什么慢”这个问题,分别对应于时间复杂度和空间复杂度两个方面。

Flash Attention~2_权重

图1. GPU的内存层级。图源:Flash Attention原文

如图 1 所示,GPU 的内存可以分为 HBM 和 SRAM 两部分。例如,A100 GPU 具有 40-80 GB 的高带宽内存 (上图中的 HBM,即我们平时说的“显存”),带宽为 1.5-2.0 TB/s,并且每个流式多处理器都有 192 KB 的片上 SRAM,带宽约为 19 TB/s。片上 SRAM 比 HBM 快一个数量级,但容量要小很多个数量级。在 GPU 运算之前,数据和模型先从 CPU 的内存(上图中的 DRAM)移动到 GPU 的 HBM,然后再从 HBM 移动到 GPU 的 SRAM,CUDA kernel 在 SRAM 中对这些数据进行运算,运算完毕后将运算结果再从 SRAM 移动到 HBM。

将 HBM 和 SRAM 之间的数据交换考虑在内,attention 的计算过程可以用如下图所示的算法表示。

Flash Attention~2_人工智能_02

图2. 考虑数据交换的Attention算法。图源:Flash Attention原文

通过前面的空间复杂度分析,attention 运算需要占据的显存空间随着序列长度 nn 的增长呈平方级增长。由于运算需要在 GPU 的 SRAM上 完成,这一过程需要不停地在 HBM 和 SRAM 之间交换数据,因此会导致大量的时间都消耗在 SRAM 和 HBM 之间的数据的换入换出上。

综合上面的关于 attention 的时间和空间复杂度的分析,为了加速 attention 运算,我们可以从降低时间复杂度和降低空间复杂度两个角度入手,接下来逐一进行介绍部分相关工作。

如何提高 attention 的计算效率

本节简单介绍提高 attention 运算效率的一些相关工作。本节内容主要是为了内容的完整性考虑,和 Flash Attention的具体内容无关,不影响后文 Flash Attention 的理解。

路径1:降低 attention 的计算复杂度

计算复杂度方面,一些工作尝试提出近似的 attention 算法,来降低 attention 的理论上的计算复杂度。主要可以分为稀疏 (sparse) 估计、低秩 (low-rank) 估计等。

Flash Attention~2_空间复杂度_03

虽然降低 attention 的计算复杂度在理论上非常具有吸引力,但是在实际应用中仍然存在一些短板,比如以下两点:

  • 性能比不上原始 attention。不论是稀疏估计、低秩估计还是其他,这些方法都采用了某种近似算法来估算注意力权重矩阵,难免会丢失信息。目前主流的还是原始的attention;
  • 无法减少内存读取的时间消耗。这些方法只能降低 attention 的计算复杂度,但是无法对 attention 运算过程中的空间复杂度等进行控制,无法减少内存读写带来的时间损耗。

路径2:降低attention的空间复杂度

空间复杂度方面,这方面工作的基本思路是降低 attention 对于显存的需求,减少 HBM 和 SRAM 之间的换入换出,进而减少 attention 运算的时间消耗。

值得一提的是,“减少 attention 对于显存的需求”和“减少 HBM 和 SRAM 之间的换入换出”这两者之间并不等价,前者重点在于减少显存消耗,比如 memory-efficient attention(https//arxiv.org/pdf/2112.05682),而后者重在降低数据交换的时间成本,比如 <<DATA MOVEMENT IS ALL YOU NEED: A CASE STUDY ON OPTIMIZING TRANSFORMERS>>(https//proceedings.mlsys.org/paper_files/paper/2021/file/bc86e95606a6392f51f95a8de106728d-Paper.pdf) 这篇文章。

为降低空间复杂度,一种具有代表性的方法是 kernel fusion。kernel fusion 的思想很简单,即将需要通过多个 CUDA kernel 来分步完成的操作融合到一个或者少数几个 CUDA kernel,从而减少数据在HBM和SRAM之间换入换出的次数,进而节省运算时间。

Flash Attention~2_权重_04

Flash Attention 的做法其实也是 kernel fusion,只是对应的 kernel 专门针对数据的换入换出进行了优化 (IO-aware),尽可能最小化 HBM 和 SRAM 之间的数据交换次数。

Flash Attention 介绍

本节介绍 Flash ttention 的动机、具体方法和实现细节,并基于 Numpy 实现 Flash Attention 的主体算法(代码已开源,链接(https//gist.github.com/xiabingquan/a4a9a743f97aadd531ed6218be20afd2))。

本文以实现 Flash Attention的前向过程为主,后向传播、masking、Dropout 等略过。

和 Transformer 的原始 attention 相比,Flash Attention 有以下三点特点:

  • 运算速度更快 (Fast);
  • 更节省显存 (Memory-Efficient);
  • 计算结果相同 (Exact)

这三点刚好和 Flash Attention 论文名《FLASHATTENTION: Fast and Memory-Efficient Exact Attention with IO-Awareness》相对应。得益于 Flash Attention 的这几点特性,自 PyTorch 2.0(https//pytorch.org/blog/accelerated-pytorch-2/) 开始,Flash Attention 已经被集成到 PyTorch 官方库中,使用者可以直接通过 torch.nn.functional.scaled_dot_product_attention(https//pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) 进行调用。

摘要

Flash Attention 的动机是尽可能避免大尺寸的注意力权重矩阵在 HBM 和 SRAM 之间的换入换出。具体方法包含两个部分:tiling 和 recomputation

tiling 的基本思路:不直接对整个输入序列计算注意力,而是将其分为多个较小的块,逐个对这些块进行计算,增量式地进行 softmax 的规约。规约过程中只需要更新某些中间变量,不需要计算整个注意力权重矩阵

recomputation 的基本思路:基于 tiling 技巧,在反向传播过程中不保留整个注意力权重矩阵,而是只保留前向过程中 tiling 的某些中间变量,然后在反向传播过程中重新计算注意力权重矩阵。recomputation 可以看作是一种基于 tiling 的特殊的 gradient checkpointing,因此后文主要介绍 tiling,想进一步了解 recomputation 的读者可以翻阅 Flash Attention 原文。

得益于上述技巧,Flash Attention 可以同时做到又快(运算速度快)又省(节省显存)。

基于Tiling技巧的Softmax

本节主要介绍 Flash Attention 中用到的 tiling 技巧。Tiling 技巧不是 Flash Attention 的首创,该技巧在之前的工作中已有探索[3][4][5]。

Tiling 技巧的核心思想是,尽可能避免对整个序列进行操作,而是通过维护一些中间变量来递推式地完成某些操作,从而减少内存的消耗。

以 softmax 为例,原始的 softmax 可以用如下算法表示:

Flash Attention~2_空间复杂度_05

图3. 原始softmax。图源:《Online normalizer calculation for softmax》

该算法的实现如下。为了展示 softmax 运算的详细过程,以下代码没有使用 PyTorch、Numpy 等科学计算库,或者Python原生的 max、min 等归约函数,而仅仅使用 Python 原生的数值运算符对浮点数的列表进行操作。

class SoftMax(object):
    def forward(self, x: List[float]):

        # loop 1: get the maximum value
        max_x = -np.inf
        for t in x:
            max_x = t if t > max_x else max_x

        # loop 2: get the accumulative sum of exp(x_i - x_max)
        accum_exp = 0.
        for t in x:
            accum_exp += np.exp(t - max_x)

        # loop 3: get the softmax output by dividing the exponential of `x-max(x)` with `accum_exp`
        output = [0. for _ in range(len(x))]
        for i, t in enumerate(x):
            output[i] = np.exp(t - max_x) / accum_exp

        return output

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

从上面的代码可以看出,softmax 函数需要三个循环,第一个循环计算数组的最大值,第二个循环计算 softmax 的分母,第三个循环计算 softmax 输出。

使用 tiling 技巧的 softmax 的算法如下图所示。

Flash Attention~2_权重_06

图4. 使用tiling技巧的softmax。图源:《Online normalizer calculation for softmax》

该算法的实现如下:

class SoftMaxWithTiling(object):
    def forward(self, x: List[float]):
        # loop 1: get the maximum value of x and the accumulated exponential values
        max_x = -np.inf
        accum_exp = 0.
        for t in x:
            max_x_new = t if t > max_x else max_x
            accum_exp = np.exp(max_x - max_x_new) * accum_exp + np.exp(t - max_x_new)
            max_x = max_x_new

        # loop 2: get the softmax output by dividing the exponential of `x-max(x)` with `accum_exp`
        out = [0. for _ in range(len(x))]
        for i, t in enumerate(x):
            out[i] = np.exp(t - max_x) / accum_exp

        return out

单元测试的代码如下,单元测试的结果此处略过。

class SoftMaxTest(unittest.TestCase):
    def test_softmax(self):

        n_test = 10
        for _ in range(n_test):
            n_elem = np.random.randint(1, 11)
            x = np.random.randn(n_elem).tolist()
            expected = torch.nn.functional.softmax(torch.tensor(x), dim=-1).tolist()

            out = SoftMax()(x)
            self.assertTrue(np.allclose(expected, out, atol=1e-4))

            out_with_tiling = SoftMaxWithTiling()(x)
            self.assertTrue(np.allclose(expected, out_with_tiling, atol=1e-4))


if __name__  == "__main__":
    unittest.main()

Flash Attention~2_权重_07

Flash Attention~2_空间复杂度_08

通过 tiling 的方式,softmax 的循环数从三个减到了两个,从而可以降低内存消耗。

Flash Attention的Numpy实现

Flash Attention 同样基于上述的tiling技巧实现,但是和上述的 sofmax 有两点不同:

  • attention 的计算过程需要对 QQ 和 KK 进行内积,并且需要维护 attention 的输出矩阵 OO ;
  • 在上述 tiling 形式的 softmax 中,我们的每一步只更新一个元素,但是 Flash Attention 将输入分为多个块,每个块包含多个元素。

Flash Attention 的完整算法如图5所示。

Flash Attention~2_换出_09

图5. Flash Attention完整算法。图源:Flash Attention原文

由于我们无法直接从 Python 层面在 GPU 的 SRAM 和 HBM 之间进行数据交换,因此我们使用 load 和 write 方法来分别模拟 HBM -> SRAM 和 SRAM -> HBM 的数据传输过程:

def load(self, arr, st, ed, step):
    # Simulate the process that moves data from HBM to SRAM
    return arr[:, st * step: ed * step]

def write(self, arr, val, st, ed, step):
    # Simulate the process that moves data from SRAM to HBM
    arr[:, st * step: ed * step] = val

接下来去我们结合代码来逐步理解该算法:

Flash Attention~2_权重_10

out = np.zeros((batch_size, q_len, hidden_size))
l = np.zeros((batch_size, q_len))
m = np.zeros((batch_size, q_len))
m.fill(-np.inf)

Flash Attention~2_空间复杂度_11

for i in range(Tr):

Flash Attention~2_人工智能_12

m_new = np.maximum.reduce([mi, mij])
l_new = np.exp(mi - m_new) * li + np.exp(mij - m_new) * lij

Flash Attention~2_换出_13

temp = li[..., np.newaxis] * np.exp(mi - m_new)[..., np.newaxis] * oi + np.exp(mij - m_new)[..., np.newaxis] * np.matmul(pij, vj)
temp /= l_new[..., np.newaxis]
self.write(out, temp, i, i + 1, self.row_block_size)

Flash Attention~2_空间复杂度_14

(14) 循环结束;

(15) 循环结束;

Flash Attention~2_权重_15

return out

注:上述代码只是Flash Attention原文算法1的直观实现,可能和底层C++实现在细节上存在一些出入。官方实现请请翻阅Flash Attention的原始仓](https//github.com/Dao-AILab/flash-attention)。

为验证上述 Flash Attention 实现的正确性,我们可以通过对比上述实现的 Flash Attention、“Transformer 简介”一节实现的 self_attention 函数以及 PyTorch 官方库的 nn.functional.scaled_dot_production 函数的运算结果(单元测试的完整代码见github仓库(https//gist.github.com/xiabingquan/a4a9a743f97aadd531ed6218be20afd2))。单元测试通过。

Flash Attention~2_空间复杂度_16

图6. Flash Attention单元测试结果

实验效果

为验证Flash Attention在实际训练场景中的有效性,Flash Attention论文原文对比了分别基于原始attention和Flash Attention的BERT和GPT2模型的训练时间以及模型性能等,还基于Flash Attention做了长上下文语言模型建模相关实验,此处略过,请参考论文原文(https//arxiv.org/abs/2205.14135)。

这里贴一些Flash Attention仓库(https//github.com/Dao-AILab/flash-attention)中的图,让大家对Flash Attention的时间加速比以及空间节省情况有一个更直观的认识。

Flash Attention~2_权重_17

图7. Flash Attention加速情况

Flash Attention~2_权重_18

图8. Flash Attention节省显存情况

注:上述为A100的测试结果,不代表其他GPU的情况。

总结

本文介绍了 Flash Attention,一种相比于原始attention运算速度更快、更节省显存的精确注意力算法。 

Flash Attention 的特点在于尽量减少 GPU 的 HBM 和片上 SRAM 之间的数据交换,从而达到加速运算以及节省显存的目的。

Flash Attention 的核心方法是 tiling 和 recomputation。其中 tiling 递推式地计算 softmax,避免了计算整个注意力权重矩阵,而 recomputation 则基于前向运算中的 tiling 保存的某些中间变量,在反向传播时重新计算注意力权重矩阵。

自 PyTorch 2.0 起,Flash Attention已经集成到 PyTorch 官方库中。使用者可以通过 torch.nn.functional.scaled_dot_prodoction 进行调用。

当前,Flash Attention还在迭代中,Flash Attention-2(https//tridao.me/publications/flash2/flash2.pdf) 已经推出。

参考

  1. https://arxiv.org/abs/1706.03762
  2. https://arxiv.org/abs/1409.3215
  3. https://arxiv.org/pdf/2112.05682
  4. https://arxiv.org/pdf/1805.02867
  5. https://ieeexplore.ieee.org/document/8980322
  6. https://arxiv.org/abs/1805.02867