上篇文章链接 LLM 参数,显存,Tflops? 训练篇(4) (qq.com)
上上篇文章链接LLM 参数,显存,Tflops? 训练篇(3) (qq.com)
上上上篇文章链接 LLM 参数,显存,Tflops? 训练篇(2) (qq.com)
上上上上篇文章链接 LLM 参数,显存,Tflops? 训练篇(1) (qq.com)
为了后文大家看算式明白,我们先约定一下每个变量代表的意义(和前一篇的命名方式一样)
- L: Transfomer有多少层
- H:代表两个意义,第一个意义是hiddensize的维度,第二个就是token被embedding以后的维度,这两值本来也相等
- h: 小写的h代表多头注意力的数量,即有几个attention 头
- B:batchsize
- S:序列的长度,比如GPT 2K,LLama2 4K,就是这个东西
- V: 词表里词的数量
- H'为每个多头分到的head_dim
上篇文章主要是讲在训练阶段,静态显存的占用,这块如果用FP16或者BF16的情况下,一般占用的显存为模型参数的2倍
静态显存占用的最基本逻辑就是load模型参数,但是别忘了在训练的过程中,同时还要保存另外两个重要的东西:
- 梯度
- 优化器
在一次用AdamW和混合精度训练的Epcho里,每一个模型参数,需要占用: - 2byte的模型静态参数权重(以16bit存储)
- 2byte的模型更新参数权重(以16bit存储)
- 2byte的梯度(以16bit存储)
- 2byte的梯度更新(以16bit存储)
- 4byte的一阶动量优化器更新(以32bit存储)
- 4byte的二阶方差优化器更新(以32bit存储)
整体的分布如上图所示,所以在训练的过程中,一个模型参数需要占用16bytes的内存。
除了第一项,其他后5项严格来说都不能算是静态占用。
除了训练时load的以上各种参数相关的权重以外,最终要的是输入模型进行训练的token的batchsize长度和单个训练的seq_number,这两个值会直接影响到我到底要load多少数据,这部分数据会和刚才讲的参数占用的显存一起构成训练过程中的显存消耗。
而这些就是我们一会要讲一下在训练过程中最消耗显存的部分,这里的激活(activations)指的是:前向传递过程中计算得到的,并在后向传递过程中需要用到的所有张量。这里的激活不包含模型参数和优化器状态,但包含了dropout操作需要用到的mask矩阵,不是激活函数的那个Activation。
同时也要有取有舍,比如:
只考虑中间激活大头,忽略掉一些小的,例如对于layer normalization,计算梯度时需要用到层的输入、输入的均值μ和方差(σ的平方),输入包含了 BSH 个元素,而输入的均值和方差分别包含了 BS个元素。由于通常H是比较大的(好几千数量级),所以 BSH≫BS。因此,对于layer normalization,中间激活近似估计为 BSH ,而不是 BSH+2BS (只计算了层的输入,不考虑均值和方差那部分输入的计算了)
因为我们假设是用混合精度来训练的,Activation值的存取是以半精度来保留的,也就是2bytes,dropout的mask矩阵特殊,它每个元素占用1个bytes。
现在我们一起逐层分析一下这部分Activation显存怎么计算。
先分析Self-Attention层:
1-x作为新进入的sequence经过这3个矩阵(线性层)就会出来QKV的值,这个3个矩阵的共同输入x,就是我们所谓的Activation
x的输入形状就是[B,S,H], 所以是2BSH bytes占用(16bit存着)
2-对于QK的矩阵乘和Softmax的计算,需要保留Q和K两个矩阵,都是[B,S,H],所以两个加起来是4BSH bytes的内存占用,Softmax又要保留Q*K的转置,考虑到多头的因素,这块需要2BaS^2的内存占用(这块和前面计算Tflops逻辑类似)
3-当计算完Softmax,此时会进行dropout的操作,为了dropout丢弃掉一部分参数,需要用一个mask矩阵来做这个选择,这个mask矩阵形状自然需要和Q*K的转置的矩阵相同,但是因为它只有1byte每个参数,所以它的占用内存为BaS^2
4-计算QKV,需要保存Softmax的值,既2BaS^2,与此同时还要保存V的值2BSH,所以这一步要保存2BaS^2+2BSH
5-在Attetion操作的最后,1要保存Wo的输出映射,这一部分和之前的矩阵一样都是2BSH,同时还会做一次dropout的动作,这里的dropoutmask矩阵和Wo相同也为BSH,所以这一步需要占用3BSH bytes的显存
综上,在Self-Attention这一步占用的中间显存Activation是11BSH+5BaS^2
再分析FFN/MLP层:
FFN层的计算逻辑就如上图所示,我们还是按照标准Transformer的网络来进行推算。
1-第一个线性层保留的输入Activation就是2BSH
2-然后要先进入到一个激活函数,为了保留这个激活函数,需要4倍的显存,既为8BSH
3-第二个线性层需要保留激活函数的输入来进行后面的计算,所以要占用8BSH的显存
4-最后的dropout mask矩阵,照例 BSH
综上,所有的MLP层需要的Activation显存占用需要19个BSH
由于在训练的过程中最大的数字(比如大于1000)基本都是S, seq_number和H, hidden_size占用,这些数字占大头,所以我们也可以简单的判断,MLP层在中间态显存占用的比例要大于Attention层
另外单层Transformer一般都有2层的Layer Norm, 这两层在文章开头已经讲了,会占用4BSH
所以整个Transformer 单层需要占用34BSH+5BaS^2
整个网络就占用:
在掌握了所有的训练时显存占用单位,我们可以尝试计算一下,比如还是用前文提到的Llama-65B的模型来做评估。
Llama-65B的一些训练参数:
- 64 headers
- 80 layers
- 2048 Seq_number
- 4M Batch_size
回到之前的算式,我们现在进行代入计算
- 2byte的模型静态参数权重 = 130G
- 2byte的模型更新参数权重 = 130G
- 2byte的梯度(以16bit存储)= 130G
- 2byte的梯度更新(以16bit存储)= 130G
- 4byte的一阶动量优化器更新(以32bit存储)= 260G
- 4byte的二阶方差优化器更新(以32bit存储)= 260G
关于Activation的占用,我们为了看一下区别,假设集中场景: - 如果Bath_size为1的话
Activatinotallow=80*(34*1*2048*8192+5*1*64*2048*2048)
所以batch_size=1的情况下要消耗150G内存的中间状态
如果是128呢,就是19.5T了...
综上,基本上Activation决定了你需要的卡的数量,一般情况下有三种选择:
1- 扩大卡的数量,这个和资金成本还有通信代偿挂钩
2- 降低batch_size和seq_number的数值, 这个会影响训练效率和最终的结果
3- 采用梯度重算机制来节省内存,本质上是时间换空间
本系列结束,后面会讲推理的算力和内存机制