如果你要训练一个模型大概会考虑哪些因素?
- 模型多大?参数
- 占用显存多少,能不能装的下
- 我需要多少算力来支撑
本文就针对一个标准的Transfomer模型的套路和大家简单说一下
为了后文大家看算式明白,我们先约定一下每个变量代表的意义
- L: Transfomer有多少层
- H:代表两个意义,第一个意义是hiddensize的维度,第二个就是token被embedding以后的维度,这两值本来也相等
- h: 小写的h代表多头注意力的数量,即有几个attention 头
- B:batchsize
- S:序列的长度,比如GPT 2K,LLama2 4K,就是这个东西
- V: 词表里词的数量
然后我们逐一看一下我们都要算哪些模块
如上图所示是一个标准的Transfomer架构,但是这东西我讲过除了T5和一些特定的网络以外大家都不这么用了,目前的主流是Causal-decoder only,也就是做CLM的任务,自回归的生成,因为现在的LLM主流任务主要是做生成的,对这个知识点感兴趣的读者,可以先移步:小周带你读论文-2之"草履虫都能看懂的Transformer老活儿新整"Attention is all you need(1) (qq.com)
所以我们也主要考虑的就是Causal-decoder-only的架构
架构,LLama系套壳全是这样的,魔改的其实也加不了太多东西,所以以下我要讲的算法可以认为是通用的
Self-Attention层:
Self-Attention的时候你拿到的形状是[B,S,H]张量
我们看一下Self-Attention层都干啥
首先要生成4个权重矩阵吧?
即
、
每个权重矩阵的形状是[H,H]
前3个权重矩阵,分别负责生成QKV和输入的embedding向量[B,S,H],要各自都做一次点积[B,S,H]*[H,H], 需要计算BSH^2次,合计3BSH^2次,生成的张量形状都为[B,S,H]
然后进入到QKV的环节,首先是Q*K的转置,除以K的维度开方然后softmax
完整式子
因为K的维度就等于H,所以可以写成
),我们现在把多头注意力的机制考虑进来
- h为多头数量
- H'为每个多头分到的head_dim
看下面代码就明白了
class LongLlamaAttention(nn.Module):
def __init__(self, config: LongLlamaConfig, mem_config: Optional[LongLlamaMemConfig] = None):
super().__init__()
self.config = config
# 隐层的维度,4096
self.hidden_size = config.hidden_size
# attention 中 head 的数量,32
self.num_heads = config.num_attention_heads
# attention 中每个 head 的维度,4096 // 32 = 128
self.head_dim = self.hidden_size // self.num_heads
# 位置向量长度,2048
self.max_position_embeddings = config.max_position_embeddings
# cache中缓存的 stentence 最大长度,2048
self.max_cache = self.max_position_embeddings
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
# 生成 query、key、value 时,用到的线性映射层
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
# 旋转位置编码
self.rotary_emb = LongLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
# memory attention 相关参数
self.mem_config = mem_config
运算量等于[B,h,S,H’]*[B,h,H',S], 运算量为BHS^2(H'和h最后还是合成H,看上边的代码),形状为[B,h,S,S]
这一步做完了,我们要计算上一步的求解和V矩阵点积的结果
即[B,h,S,S]*[B,h,S,H'],计算量 BhS^2H'即BHS^2,计算后的形状为[B,h,S,H']
Attetion最后一步就要过线性层Wo,把多头给降回单头
算力计算是[B,h,S,H']*[H,H],因为h*H'=H,所以化简为[B,S,H]*[H,H], 即BSH^2,形状为[B,S,H](这把就完成了进来啥形状,出来啥形状了)
当然因为残差网络的存在,所以我们还要加一个input(X)进来,这一次的加法可以忽略掉了
我们知道神经网络的计算都是一次加法,一次乘法,(其实加法是n-1,就别那么矫情了,就都按2次算好弄)
那么整个Self-Attention阶段的算力要求,对于每一个模型参数为
BSH^2+BHS^2+BHS^2)
即:
本节完