如果你要训练一个模型大概会考虑哪些因素?

  •      模型多大?参数
  •      占用显存多少,能不能装的下
  •      我需要多少算力来支撑


        本文就针对一个标准的Transfomer模型的套路和大家简单说一下

     

LLM 参数,显存,Tflops? 训练篇(1)_点积

      为了后文大家看算式明白,我们先约定一下每个变量代表的意义

  •        L: Transfomer有多少层
  •        H:代表两个意义,第一个意义是hiddensize的维度,第二个就是token被embedding以后的维度,这两值本来也相等
  •        h: 小写的h代表多头注意力的数量,即有几个attention 头
  •        B:batchsize
  •        S:序列的长度,比如GPT 2K,LLama2 4K,就是这个东西
  •        V:  词表里词的数量


       然后我们逐一看一下我们都要算哪些模块

       如上图所示是一个标准的Transfomer架构,但是这东西我讲过除了T5和一些特定的网络以外大家都不这么用了,目前的主流是Causal-decoder only,也就是做CLM的任务,自回归的生成,因为现在的LLM主流任务主要是做生成的,对这个知识点感兴趣的读者,可以先移步:小周带你读论文-2之"草履虫都能看懂的Transformer老活儿新整"Attention is all you need(1) (qq.com)

       所以我们也主要考虑的就是Causal-decoder-only的架构

      

LLM 参数,显存,Tflops? 训练篇(1)_点积_02

架构,LLama系套壳全是这样的,魔改的其实也加不了太多东西,所以以下我要讲的算法可以认为是通用的

      Self-Attention层:

Self-Attention的时候你拿到的形状是[B,S,H]张量

      我们看一下Self-Attention层都干啥

      首先要生成4个权重矩阵吧?

      即

LLM 参数,显存,Tflops? 训练篇(1)_Self_03


        每个权重矩阵的形状是[H,H]

        前3个权重矩阵,分别负责生成QKV和输入的embedding向量[B,S,H],要各自都做一次点积[B,S,H]*[H,H], 需要计算BSH^2次,合计3BSH^2次,生成的张量形状都为[B,S,H]

      然后进入到QKV的环节,首先是Q*K的转置,除以K的维度开方然后softmax

      完整式子

LLM 参数,显存,Tflops? 训练篇(1)_权重_04


       因为K的维度就等于H,所以可以写成

   

LLM 参数,显存,Tflops? 训练篇(1)_Self_05

),我们现在把多头注意力的机制考虑进来

  •         h为多头数量
  •         H'为每个多头分到的head_dim

看下面代码就明白了

class LongLlamaAttention(nn.Module):    
    def __init__(self, config: LongLlamaConfig, mem_config: Optional[LongLlamaMemConfig] = None):
        super().__init__()
        self.config = config
        # 隐层的维度,4096
        self.hidden_size = config.hidden_size
        # attention 中 head 的数量,32
        self.num_heads = config.num_attention_heads
        # attention 中每个 head 的维度,4096 // 32 = 128
        self.head_dim = self.hidden_size // self.num_heads
        # 位置向量长度,2048
        self.max_position_embeddings = config.max_position_embeddings
        # cache中缓存的 stentence 最大长度,2048
        self.max_cache = self.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        # 生成 query、key、value 时,用到的线性映射层
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        # 旋转位置编码
        self.rotary_emb = LongLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
        # memory attention 相关参数
        self.mem_config = mem_config


      运算量等于[B,h,S,H’]*[B,h,H',S], 运算量为BHS^2(H'和h最后还是合成H,看上边的代码),形状为[B,h,S,S]

      这一步做完了,我们要计算上一步的求解和V矩阵点积的结果

LLM 参数,显存,Tflops? 训练篇(1)_点积_06

      即[B,h,S,S]*[B,h,S,H'],计算量 BhS^2H'即BHS^2,计算后的形状为[B,h,S,H']

      Attetion最后一步就要过线性层Wo,把多头给降回单头

LLM 参数,显存,Tflops? 训练篇(1)_点积_07

       算力计算是[B,h,S,H']*[H,H],因为h*H'=H,所以化简为[B,S,H]*[H,H], 即BSH^2,形状为[B,S,H](这把就完成了进来啥形状,出来啥形状了)

       当然因为残差网络的存在,所以我们还要加一个input(X)进来,这一次的加法可以忽略掉了

LLM 参数,显存,Tflops? 训练篇(1)_Self_08

       我们知道神经网络的计算都是一次加法,一次乘法,(其实加法是n-1,就别那么矫情了,就都按2次算好弄)

      那么整个Self-Attention阶段的算力要求,对于每一个模型参数为

BSH^2+BHS^2+BHS^2)

      即:

LLM 参数,显存,Tflops? 训练篇(1)_点积_09

       本节完

     

LLM 参数,显存,Tflops? 训练篇(1)_权重_10