这个完全是基础知识啊~~ 哪不对 大佬们帮忙指出啊
# CUDA 矩阵乘法优化手段详解
Naive 实现的分析:到底差在哪里?
笔者面试过不少具有 CUDA 编程经验的校招同学,当提问使用 CUDA 编写一个 SGEMM Kernel 的时候,通常会获得这么一个答案:
__global__ void matrixMul(const float *A, const float *B, float *C,
int M, int N, int K) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
int ty = blockIdx.y * blockDim.y + threadIdx.y;
if(ty < M && tx < N) {
float c = 0;
for(int i = 0; i < K; ++i){
c += A[ty * K + i] * B[i * N + tx];
}
C[ty * N + tx] = c;
}
}
这样一个 Naive 的 Kernel 当然不是笔者所期待的,因为这个 Kernel 的性能基本可以断定连 cublas 的 1/10 都不到,显然不符合我们追求高性能的需求。那么这个 Naive 的实现究竟差在哪呢?
分析代码我们可以看到,计算一次 FMA(乘累加)之前需要读一次 A 和读一次 B,众所周知,读取 Global Memory 的代价很大,通常都需要几百个 cycle(时钟周期),而计算一次 FMA 通常只需要几个 cycle,大量的时间被花费在了访存上。也许有思维活络的同学立马想到,可以将 A 和 B 矩阵先搬运到 Shared Memory(SM 中低延迟的 on-chip memory,block 内线程共享,附 NVIDIA GPU 内存结构图)中降低访存的开销,这的确是一个很好的思路,但是这只能将访存代价从几百 cycle 降低到几十 cycle,并不改变问题的本质。问题的关键在于主体循环由两条 Load 指令与一条 FMA 指令构成,计算指令只占总体的 1/3,计算访存比过低,最终导致了访存延迟不能被隐藏,从而性能不理想。
让我们打开思路,若一个 thread 并不只计算一个结果,而是计算 4x4 个结果,并且使用 Shared Memory 优化,Hot Loop 会是什么样呢,伪代码如下所示:
float c[4][4] = {{0}};
float a_reg[4];
float b_reg[4];
for(int i = 0; i < K; i += TILE_K){
__syncthreads();
// transfer tile from global mem to shared mem
load_gmem_tile_to_smem(A, i, smemA);
load_gmem_tile_to_smem(B, i, smemB);
__syncthreads();
#pragma unroll
for(int j = 0; j < TILE_K; ++j) {
// load tile from shared mem to register
load_smem_tile_to_reg(smemA, j, a_reg);
load_smem_tile_to_reg(smemB, j, b_reg);
// compute matrix multiply accumulate 4x4
mma4x4(a_reg, b_reg, c);
}
}
分析可以得出从 smemA 读取到寄存器 a_reg 中,需要进行 4 次访存操作,B 同理,那么主体的计算访存指令比例变成了 16/8,相对于之前的情况,计算指令的占比大大提高了。足够大的计算访存比能提升计算单元的利用率,并能起到隐藏访存延迟的作用。我们可以进一步提升计算访存比,从而使得 kernel 的性能接近理论峰值。
矩阵分块与资源分配
显然我们不能只使用一个 block 计算一个超大矩阵,这样会造成大量 SM(Streaming Multiprocessor)的闲置浪费,这就需要对矩阵进行分块计算,如下图所示:
不同的分块大小在不同 shape 的矩阵乘法应用上性能各有优劣,本文选取 128x128 的分块举例。
从上一小节我们可以看到,提升计算访存比有很大的好处,那么计算访存比可以无限提升吗,答案是否定的。因为要提升计算访存比,单个 thread 就需要计算一个更大的块,这就需要更多的寄存器,但寄存器的个数是有限的。以 Turing 架构的 GPU 为例,单个 SM 的寄存器总量为 65536,因为指令编码的限制,单个 thread 能使用的最大寄存器个数为 255,并且寄存器个数并不是用得越多越好。这里需要引入一个 Occupancy(占用率)的概念,Occupancy 是指每个 SM 中活动线程束(Warp)数量与最大并发线程束数量的比值,高的 Occupancy 不一定意味着高性能,但可以通过切换执行 Warp 来起到一定隐藏延迟的作用。而每个 SM 中的 Active Warp 数量,取决于 block 使用的资源数量,具体为每个线程使用的寄存器个数与 Shared Memory 用量。Occupany可通过 CUDA Toolkit 中提供的 CUDA_Occupancy_Calculator.xls 工具获得。
考虑一个 block 计算 128x128 的分块,若每个线程计算 128 个结果,需要的 block size 为 128,单个线程需要 128 个寄存器储存计算结果,加上所需的 Gmem to Smem,Smem to Reg 等一些所需的寄存器,大概共需要至少 180 多个,计算 Occupany 可知此时的 Active Warp 数只有 8,Occupany 为 25%;若设置 block size 为 256,则每个线程仅需计算 64 个结果,调整寄存器和 Shared Memory 的使用量并观察 Occupany,可知若每个线程只使用 128 个寄存器,block 内的 Shared Memory 使用量限制在 32K,Active Warp 数可以达到 16,是一个更优的选择:
并且此时的配置计算访存比可以达到 64/4(使用向量读取),已经足够隐藏访存延迟。
极致的访存优化
通常情况下,在选取了合适的 block 资源配置,利用 Shared Memory 降低访存延迟,做好循环展开之后,SGEMM Kernel 的性能已经能达到一个不错的水平(80% cublas),但这并不是我们旅程的终点。首先,我们可以使用向量读取指令LDS.128优化 Shared Memory 访问(对应 float4 数据类型),这能大幅减少访存指令的数量,进一步提升计算访存比,由此我们需要将 A 矩阵存入 smemA 之前做一次转置:
同时,我们的 kernel 为 256 个线程计算 128x128 的分块,为了能够合并访问 Shared Memory,我们将 256 个线程划为二维,令:
int tx = threadIdx.x % 16;
int ty = threadIdx.x / 16;
并按照如下方式向量读取 Shared Memory 中的数据:
最终单个线程计算 2x2 个 4x4 的结果,结果布局如图所示:
并且通过 micro benchmark 可以探测出,Turing(Tesla T4) 的 Global Memory 的访存延迟约 300 cycle,Shared Memory 的访存延迟在约 30 cycle,需要充分利用 Prefetch 的思想,隐藏 Global Memory 读入中间寄存器、将来自 Global Memory 的数据块写入 Shared Memory、从 Shared Memory 中读出数据块的访存延迟,以免计算单元因为 stall 而空闲太久,最终的伪代码如下所示:
#define TILE_K 16
__shared__ float4 smemA[2][TILE_K * 128 / 4];
__shared__ float4 smemB[2][TILE_K * 128 / 4];
float4 c[8][2] = {{make_float4(0.f, 0.f, 0.f, 0.f)}};
float4 ldg_a_reg[2];
float4 ldg_b_reg[2];
float4 a_reg[2][2];
float4 b_reg[2][2];
// transfer first tile from global mem to shared mem
load_gmem_tile_to_reg(A, 0, ldg_a_reg);
load_gmem_tile_to_reg(B, 0, ldg_b_reg);
store_reg_to_smem_tile_transpose(ldg_a_reg, 0, smemA[0]);
store_reg_to_smem_tile(ldg_b_reg, 0, smemB[0]);
__syncthreads();
// load first tile from shared mem to register
load_smem_tile_to_reg(smemA[0], 0, a_reg[0]);
load_smem_tile_to_reg(smemB[0], 0, b_reg[0]);
int write_stage_idx = 1; //ping pong switch
do {
i += TILE_K;
// load next tile from global mem
load_gmem_tile_to_reg(A, i, ldg_a_reg);
load_gmem_tile_to_reg(B, i, ldg_b_reg);
int load_stage_idx = write_stage_idx ^ 1;
#pragma unroll
for(int j = 0; j < TILE_K - 1; ++j) {
// load next tile from shared mem to register
load_smem_tile_to_reg(smemA[load_stage_idx], j + 1, a_reg[(j + 1) % 2]);
load_smem_tile_to_reg(smemB[load_stage_idx], j + 1, b_reg[(j + 1) % 2]);
// compute matrix multiply accumulate 8x8
mma8x8(a_reg[j % 2], b_reg[j % 2], c);
}
if(i < K) {
// store next tile to shared mem
store_reg_to_smem_tile_transpose(ldg_a_reg, 0, smemA[write_stage_idx]);
store_reg_to_smem_tile(ldg_b_reg, 0, smemB[write_stage_idx]);
// use double buffer, only need one sync
__syncthreads();
// switch
write_stage_idx ^= 1;
}
// load first tile from shared mem to register of next iter
load_smem_tile_to_reg(smemA[load_stage_idx ^ 1], 0, a_reg[0]);
load_smem_tile_to_reg(smemB[load_stage_idx ^ 1], 0, b_reg[0]);
// compute last tile mma 8x8
mma8x8(a_reg[1], b_reg[1], c);
} while (i < K);
store_c(c, C);
注:此处偷懒假设了 M、N、K 都是 4 的倍数,若非 4 的倍数则 Global Memory 不能使用 float4 进行读取,结果也不能用 float4 进行写回,而且为了合并写回,需要通过 Shared Memory 交换 warp 内的结果,保证每个 warp 执行一条 Store 指令能够写回一片连续的内存空间。
至此我们获得了一个充分优化的 SGEMM Kernel。另外 Ampere GPU 新增了LDGSTS指令,数据块从 Global Memory 到 Shared Memory 的过程不需要经过中间寄存器,可以进一步的优化 SGEMM 的性能。
性能对比
为了避免 cublas 选取到 split K 的 Kernel,我们将 K 固定为 1024,取 M, N = 2048, 4096, 8192 和 16384 作为测试用例,对比了上述 SGEMM Kernel 与 cublas 的性能(测试 GPU 为 Tesla T4,锁定核心频率为 1100):
可以看到所实现的 SGEMM Kernel 达到了 cublas 平均 97.5% 的性能。
超越 cublas:使用 SASS 调优 Kernel
到这里,可能有同学依然有一个疑问,我们似乎把所有能想到的优化手段都用上了,为什么写出来的 CUDA C Kernel 依然离 cublas 有一定的差距,答案是 cublas 所使用的 kernel 中有一大部分并不是通过 nvcc 编译的 CUDA Kernel,而是使用 NVIDIA GPU 的汇编语言(Shader Assembly,简称 SASS)编写的深度调优版本。
尽管 nvcc 编译器在不断的进步,特别是 CUDA 11 中的 nvcc,所编译的 Kernel 与手工汇编优化版本之间的差距已大幅缩小,但仍然无法完全避免寄存器 Bank conflict 的影响以及充分利用寄存器的 Reuse Cache(这两个概念下面会进行详细的介绍),使得差距仍然存在。即使 PTX 这样的伪汇编语言,也无法精确控制寄存器的分配,和 CUDA C 面临着一样的困境。
所以为了充分挖掘 GPU 的性能极限,需要对 GPU 指令和寄存器进行精确控制,就必须交由 GPU 原生汇编语言 SASS 完成。这方面已经有了很多研究,如出自 Citadel 的深入研究 NV GPU 架构的 Dissecting the NVidia XXX GPU architecture via microbenchmarking 系列论文,这一系列文章对底层架构做了系统的测试、分析和总结,虽然其中某些结论可能并不准确,但总体来讲有很高的参考价值。同时催生了不少开源汇编器如 KeplerAs、maxas(最成熟,影响深远)、turingas 和 CuAssembler 等一系列开源 SASS 汇编器,使得使用 SASS 编写高性能 Kernel 变成了可能。
寄存器 Bank conflict
我们知道 Shared Memory 有 Bank conflict,而寄存器的 Bank conflict 也是类似的概念。NVIDIA GPU 每个 SM 有独立的 Register File,而 Register File 被分为若干个 Bank,以 Maxwell 为例,若一条指令所需的源寄存器有 2 个以上来自于同一 Bank,则会产生 conflict,指令会相当于重发射,浪费一个 cycle。Maxwell/Pascal 的 Register File 的 Bank 数为 4,寄存器的id%4即为该寄存器的所属 bank(如 R0 属于 Bank 0,R5 属于 Bank 1),FFMA R1, R0, R4, R1这样的指令就会产生寄存器 Bank conflict。而 Turing 架构做了改进,Register File 被分为 2 个 Bank,每个 Bank 有 2 个 Port,若非三个源寄存器 id 同奇偶则不会产生冲突,大大缓解了寄存器 Bank conflict。
maxas 中的 Maxwell SGEMM SASS Kernel 为了缓解寄存器 Bank conflict,就对参与 FFMA 计算的寄存器做了精巧的分配(参考 maxas 的 SGEMM 文档),如下图所示:
经过对 C 的巧妙排布,寄存器 Bank conflict 大大减少,但依然无法完全避免(如上图中黑框标识的部分,A/B 所使用的寄存器会产生 Bank conflict),这部分冲突就需要用到寄存器 Reuse 来消除。
Register Reuse
寄存器 Reuse 是 NVIDIA 为了缓解寄存器 Bank conflict 的问题,在 Maxwell 开始引入的一种机制,NVIDIA 在读取指令操作数的 Collector 单元加入了寄存器的 Reuse Cache。Reuse Cache 是只读的,指令获取 Operand 是否通过此 Cache 由该指令的 control code(maxas 的 control code wiki中有详细的介绍)所指定,使用 cuobjdump 反汇编一些 Kernel 可以发现一些寄存器后有 .reuse的 flag,即表示该寄存器从 Reuse Cache 而非 Register File 中取值,从而消除寄存器 Bank conflict:
# Maxwell GPU
FFMA R2, R64.reuse, R73, R2; # R64 进入 Reuse Cache
FFMA R3, R64.reuse, R72, R3; # R64 从 Reuse Cache 中获取,避免与 R72 冲突
但是使用 .reuse需要满足一定条件(寄存器将被改写前不能设置 .reuse),胡乱设置 reuse flag 会有可能获取的是历史值,造成计算错误,根据笔者的理解,.reuse 更像是使该寄存器的值在 Reuse Cache 中 hold 住的标识。nvcc 编译 CUDA Kernel 也会使用 Reuse Cache 去规避一些寄存器 Bank conflict,但是因为寄存器分配及指令排布的原因,Reuse 的利用率并不高,反汇编我们刚才写的 SGEMM Kernel,对主循环的所有 FFMA 指令做个统计,可以发现 Reuse Cache 仅达到 20% 左右,而 maxas 的 SASS Kernel 通过设计使得 Reuse 的利用率可以达到 49%。
最终通过 SASS 精细调优的 SGEMM Kernel 的性能可以全面超越 cublas,感兴趣的同学们可以自行编译 maxas 中的 SGEMM Kernel 在 Maxwell 或者 Pascal GPU 上进行测试。最后,虽然使用 SASS 能充分挖掘 GPU 的性能,但面临有三大问题:1. 第三方 NV GPU 汇编器依赖于对 GPU 架构的逆向研究,可能因为没有探究到全部的硬件底层细节而存在未知的 BUG;2. 汇编 Kernel 难于开发,更难于调试;3. NV 每一代 GPU 的 ISA(指令集)都不尽相同,需要不断开发对应的汇编器和汇编 Kernel。正因为这几大问题的存在,使得使用 SASS 编写 Kernel 是个费时费力的工作,除非有追求极致性能的需求,否则不建议轻易尝试。
GEMM 的延伸:优化卷积运算
我们都知道优化卷积运算可以通过 im2col 将卷积映射为矩阵乘法来实现,对于上述 SGEMM Kernel,只需要将 Global Memory 的数据搬运到 Shared Memory 这一过程稍作修改,由对应位置的映射变为 im2col 映射,SGEMM Kernel 就摇身一变成为了计算 Conv 的 Kernel,这即是 cudnn 卷积运算的 Implicit Gemm 算法。而在 im2col 过程中,若直接计算指针的偏移量的话,会引入大量的整数除法和取余运算,这是一笔不小的开销,所以可以将地址的偏移量在 host 端预先计算好,作为 param 传入 kernel 中,则可以在需要时从常量内存中读取,避免整数除法和取余,实现 Implicit Precomp Gemm。
总结
本文详细介绍了如何编写一个高效率的 CUDA SGEMM Kernel,并且介绍了使用 SASS 编程这一极限优化性能的手段,并稍稍延伸展开了通过 Implicit Gemm 优化卷积运算的思路,希望可以给予有志于极致挖掘硬件性能的同学们一定的启发
# 自己写的CUDA矩阵乘法能优化到多快?
1. Introduction
最近研究了一下Nvidia GPU搭载的Tensor Core,开始手写半精度浮点类型(half or fp16)的矩阵乘法算子(c = a * b,其中a、b、c均为fp16类型),并尝试将其优化到cublas的性能水平。
本文源代码参见nicolaswilde/cuda-tensorcore-hgemm (github.com)。
下图是我在RTX3090上测试得到的我自己手写的几个kernel和CUBALS_GEMM_DFALT在M = N = K(256 ~ 16384)下的性能对比,其中加粗蓝色是cublas、加粗绿色是我优化的最终版本、四条灰色曲线是优化的几个中间版本。可以看到,myHGEMMAlignedV5性能基本超过CUBALS_GEMM_DFALT,实现了优化目标。
myHGEMM vs CUBLAS
RTX3090共有82个SM,每个SM有4个Tensor Core,每个Tensor Core有256 FLOP/Cycle的fp16算力,实测RTX3090最高运行在1.9GHz左右,因此其fp16峰值算力约为82 * 4 * 256 * 1.9G ~ 159 TFLOPS。我的HGEMM Kenel最高跑到了131 TFLOPS,约为峰值算力的82%。
关于cublas,cublas中的cublasGemmEx可以指定参加矩阵乘法运算的数据类型,并且可以指定40多种算法,下图是cublas在M = N = K(256 ~ 16384)的性能表现,最高可以达到126 TFLOPS。从性能曲线上看我严重怀疑这40多种算法最后调用的是相同的kernel可是我又没有证据:
CUBLAS HGEMM Performance
作为CUDA矩阵乘法优化系列的第二篇,本文将重点关注Tensor Core的相关行为。基础的CUDA编程方法和基于CUDA Core的单精度矩阵乘法算子优化请首先查看:
三个月前这篇文章的测试平台还是四年前入手的GTX 1060,如今鸟枪换大炮,本文使用RTX 3090进行测试,以尝试一下Ampere这代最新架构的GPU。别问我RTX 3090多少钱买的[doge]根本买不起,从某云GPU平台租的2块钱/小时,前后花了我100多大洋...
言归正传,从Tensor Core讲起。对Tensor Core已经有所了解的同学们可以直接跳到第四节。\
2. Tensor Core
Nvidia从Volta这代GPU开始引入Tensor Core,其目的是用于加速以AI推理和训练为代表的、以大规模矩阵乘法或类矩阵乘法为典型负载的这么一类应用。毕竟CUDA Core的运算能力有限,在矩阵乘法这种典型的计算密集型的负载上会有大量的访存带宽浪费,Tensor Core的加入就能够在计算矩阵乘法时利用起GPU动辄大几百GB/s的内存带宽。
下图是Ampere A100中每个SM的结构图,可以看到Tensor Core实际上就是SM Block中的一个功能部件,同原本的CUDA Core处于相同地位。当然也有些许的区别:例如INT32的向量指令,一个Warp中32个线程分别在16个INT32的运算部件中执行;而Tensor Core指令则是32个线程合作,取32个线程的操作数共同在一个Tensor Core中完成矩阵乘法操作。
A100 SM Architecture
Volta架构中的Tensor Core这里略去不讲,个人猜测是由于物理设计的原因,Volta的Tensor Core在寄存器中还要分Thread Group,同一数据还要存储两次,看起来非常不简洁,关心Volta Tensor Core的可以读一读ISPASS2019的这篇论文《Modeling Deep Learning Accelerator Enabled GPUs》。
从Turing这一代开始,矩阵元素在寄存器中的摆放方式就非常规整,下图是一个8 - 8 - 128bit的矩阵乘法的示意图。Turing Tensor Core支持(u)int8和fp16的数据类型,Ampere Tensor Core进一步支持了bf16和tf32数据类型,还有一些不常用的INT4、INT2、INT1。以本文中测试的half(也就是fp16)为例,下图中这个最基本的Tensor Core操作计算了一个8x8x8的矩阵乘法。
Matrix 8-8-128bit Layout
Turing Tensor Core为了减少指令数目并缓解寄存器压力,一条Tensor Core指令可以支持16x8x8的fp16的矩阵乘法,对应的SASS指令也就是HMMA.1688,其寄存器排布如下图所示:
Matrix 16 - 8 - 128bit Layout
Ampere Tensor Core一条Tensor Core指令可以支持16x8x16的fp16的矩阵乘法,因此我们后续反汇编查看到指定compute capability = 86的SASS代码中清一色的都是HMMA.16816指令了。
Matrix 16 - 8 - 256bit Layout
从Volta第一次引入Tensor Core开始,到Ampere的Tensor Core,基本的演进除了数据类型的增加,更重要的是峰值性能的增加。V100一个SM Block中的两个Tensor Core每拍一共可以计算128个乘累加;而A100一个SM Block中只有一个Tensor Core,每拍可以计算256个乘累加,也就是全流水8拍执行一条HMMA.16816;然而!Sadly,rtx 30系列显卡的Tensor Core竟然阉割了,每拍只有128个乘累加,全流水16拍执行一条HMMA.16816,这样一来rtx 3090一共82个SM,每个SM有4个Tensor Core,标称Boost Clock 1.695GHz,因此峰值性能为82 * 4 * 256 * 1.695G ~ 142 TFLOPS。
ensor Core Evolution
3. Tensor Core的编程方法
3.1 C++ API
CUDA C++中包装了Tensor Core的高级API,.../CUDA/v??.?/include/crt/mma.h中定义了这些API。具体地来说,需要声明matrix_a/matrix_b/accumulator这三种矩阵的fragment(一个fragment对应一个warp的所有线程的某一个或几个寄存器),使用load_matrix_sync和store_matrix_sync将矩阵写入寄存器或将矩阵写回shared memory或global memory,使用mma_sync来调用Tensor Core计算矩阵乘法。For example:
nvcuda::wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> frag_a;
nvcuda::wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> frag_a;
nvcuda::wmma::fragment<wmma::accumulator, 16, 16, 16, half> frag_c;
nvcuda::wmma::fill_fragment(frag_c, 0.0);
nvcuda::wmma::load_matrix_sync(frag_a, (shared memory or global memory pointer), (stride_a));
nvcuda::wmma::load_matrix_sync(frag_b, (shared memory or global memory pointer), (stride_b));
nvcuda::wmma::mma_sync(frag_c, frag_a, frag_b, frag_c);
nvcuda::wmma::store_matrix_sync((shared memory or global memory pointer), frag_c, (stride_c), wmma::mem_row_major);
这里不同compute capability的GPU支持不同大小的fragment,具体的可以查看《CUDA C++ Programming Guide》。
3.2 PTX指令
在《Parallel Thread Execution ISA》中,9.7.13.3节和9.7.13.4节分别给出了两种指令:wmma指令和mma指令,个人感觉这两类指令可以说是非常类似,其中wmma指令更像是Volta架构的遗留产物。
wmma指令包括:
// wmma.load
wmma.load.a.sync.aligned.layout.shape{.ss}.atype r, [p] {, stride};
wmma.load.b.sync.aligned.layout.shape{.ss}.btype r, [p] {, stride};
wmma.load.c.sync.aligned.layout.shape{.ss}.ctype r, [p] {, stride};
// wmma.store
wmma.store.d.sync.aligned.layout.shape{.ss}.type [p], r {, stride};
// wmma.mma
wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype d, a, b, c; // fp16
wmma.mma.sync.aligned.alayout.blayout.shape.s32.atype.btype.s32{.satfinite} d, a, b, c; // int8 uint8
wmma.mma.sync.aligned.alayout.blayout.shape.f32.atype.btype.f32 d, a, b, c; // bf16
wmma.mma.sync.aligned.alayout.blayout.shape.f32.atype.btype.f32 d, a, b, c; // tf32
wmma.mma.sync.aligned.alayout.blayout.shape{.rnd}.f64.f64.f64.f64 d, a, b, c; // fp64
wmma.mma.sync.aligned.row.col.shape.s32.atype.btype.s32{.satfinite} d, a, b, c; // int4 uint4
wmma.mma.op.popc.sync.aligned.row.col.shape.s32.atype.btype.s32 d, a, b, c; // int1
mma指令包括:
// mma
mma.sync.aligned.m8n8k4.alayout.blayout.dtype.f16.f16.ctype d, a, b, c; // fp16
mma.sync.aligned.m16n8k8.row.col.dtype.f16.f16.ctype d, a, b, c; // fp16
mma.sync.aligned.m16n8k16.row.col.dtype.f16.f16.ctype d, a, b, c; // fp16
mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 d, a, b, c; // bf16 tf32
mma.sync.aligned.m16n8k8.row.col.f32.atype.btype.f32 d, a, b, c; // bf16 tf32
mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 d, a, b, c; // bf16 tf32
mma.sync.aligned.shape.row.col{.satfinite}.s32.atype.btype.s32 d, a, b, c; // int8 uint8
mma.sync.aligned.shape.row.col{.satfinite}.s32.atype.btype.s32 d, a, b, c; // int4 uint4
mma.sync.aligned.shape.row.col.s32.b1.b1.s32.bitOp.popc d, a, b, c; // int1
// load matrix
ldmatrix.sync.aligned.shape.num{.trans}{.ss}.type r, [p];
wmma和mma指令中的矩阵运算指令可以说是非常相似了,但是load指令有所不同:wmma.load指令对矩阵的每行都按照stride访问,而ldmatrix指令则可以对每4个线程对应的元素指定一个地址,所以ldmatrix的访问方式更加灵活,矩阵元素在shared memory中的排放就可以灵活地调整以避免bank conflict。
3.3 SASS指令
不管使用C++ API还是嵌入式的PTX指令,最终都要编译成SASS机器码,fp16类型对应的上述的矩阵load和矩阵乘法均被编译成LSDM指令和HMMA指令。
4. 一个简单的Tensor Core HGEMM Kernel
CUDA矩阵乘法算子的矩阵分块的考量在这篇文章中已经介绍过:
从计算访存比的角度来说,计算访存比跟(1 / BM + 1 / BN)成正比,也就是说为了让访存带宽不成为瓶颈,我们倾向于让BM和BN越大越好;但是由于BM * BN的accumulator要存放在寄存器中,寄存器数目限制了BM和BN不能无限大。关于BK的取值,首先BK至少需要是nvcuda::wmma::fragment中定义矩阵的K维度的整数倍;当BK太小(例如取BK = 16)时,核心循环中HMMA指令占比不高,一些循环相关的地址计算的指令会导致性能下降;当BK >= 32时,因为BK不影响计算访存比,我们发现性能基本不会再随BK而提高了;另外还有(BM + BN) * BK还受到shared memory大小的约束。
GEMM Block Tiling
这里我们取BM = 128,BN = 256,BK = 32,thread_per_block = 256。这样每次K循环中,256个线程每个线程需要取16个矩阵A的元素,取32个矩阵B的元素;8个warp每个warp负责计算64x32x64的矩阵乘法。为了方便起见假设M/N/K对齐到128/256/32,也就是没有处理corner case。这份代码调用的C++ wmma的API,代码如下:
__global__ void myHGEMMAlignedV1(
half * __restrict__ a, half * __restrict__ b, half * __restrict__ c,
const int M, const int N, const int K) {
const int BM = 128;
const int BN = 256;
const int BK = 32;
int bx = blockIdx.x;
int by = blockIdx.y;
int tid = threadIdx.x;
int wid = tid >> 5;
const int APAD = 8;
const int BPAD = 8;
__shared__ half s_a[BM][BK + APAD];
__shared__ half s_b[BK][BN + BPAD];
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> frag_a[2][4];
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> frag_b[2][4];
wmma::fragment<wmma::accumulator, 16, 16, 16, half> frag_c[4][4];
#pragma unroll
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
wmma::fill_fragment(frag_c[i][j], 0.0);
}
}
int load_a_smem_m = (tid >> 2) << 1;
int load_a_smem_k = (tid & 3) << 3;
int load_b_smem_k = (tid >> 5) << 2;
int load_b_smem_n = (tid & 31) << 3;
int load_a_gmem_m = by * BM + load_a_smem_m;
int load_b_gmem_n = bx * BN + load_b_smem_n;
int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_smem_k, K);
int load_b_gmem_addr = OFFSET(load_b_smem_k, load_b_gmem_n, N);
int comp_c_frag_m = wid & 1;
int comp_c_frag_n = wid >> 1;
for (int bk = 0; bk < K / BK; bk++) {
FLOAT4(s_a[load_a_smem_m ][load_a_smem_k]) = FLOAT4(a[load_a_gmem_addr ]);
FLOAT4(s_a[load_a_smem_m + 1][load_a_smem_k]) = FLOAT4(a[load_a_gmem_addr + K]);
FLOAT4(s_b[load_b_smem_k ][load_b_smem_n]) = FLOAT4(b[load_b_gmem_addr ]);
FLOAT4(s_b[load_b_smem_k + 1][load_b_smem_n]) = FLOAT4(b[load_b_gmem_addr + N]);
FLOAT4(s_b[load_b_smem_k + 2][load_b_smem_n]) = FLOAT4(b[load_b_gmem_addr + 2 * N]);
FLOAT4(s_b[load_b_smem_k + 3][load_b_smem_n]) = FLOAT4(b[load_b_gmem_addr + 3 * N]);
load_a_gmem_addr += BK;
load_b_gmem_addr += BK * N;
__syncthreads();
wmma::load_matrix_sync(frag_a[0][0], &s_a[comp_c_frag_m * 64 ][ 0], BK + APAD);
wmma::load_matrix_sync(frag_a[0][1], &s_a[comp_c_frag_m * 64 + 16][ 0], BK + APAD);
wmma::load_matrix_sync(frag_a[0][2], &s_a[comp_c_frag_m * 64 + 32][ 0], BK + APAD);
wmma::load_matrix_sync(frag_a[0][3], &s_a[comp_c_frag_m * 64 + 48][ 0], BK + APAD);
wmma::load_matrix_sync(frag_a[1][0], &s_a[comp_c_frag_m * 64 ][16], BK + APAD);
wmma::load_matrix_sync(frag_a[1][1], &s_a[comp_c_frag_m * 64 + 16][16], BK + APAD);
wmma::load_matrix_sync(frag_a[1][2], &s_a[comp_c_frag_m * 64 + 32][16], BK + APAD);
wmma::load_matrix_sync(frag_a[1][3], &s_a[comp_c_frag_m * 64 + 48][16], BK + APAD);
wmma::load_matrix_sync(frag_b[0][0], &s_b[ 0][comp_c_frag_n * 64 ], BN + BPAD);
wmma::load_matrix_sync(frag_b[0][1], &s_b[ 0][comp_c_frag_n * 64 + 16], BN + BPAD);
wmma::load_matrix_sync(frag_b[0][2], &s_b[ 0][comp_c_frag_n * 64 + 32], BN + BPAD);
wmma::load_matrix_sync(frag_b[0][3], &s_b[ 0][comp_c_frag_n * 64 + 48], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][0], &s_b[16][comp_c_frag_n * 64 ], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][1], &s_b[16][comp_c_frag_n * 64 + 16], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][2], &s_b[16][comp_c_frag_n * 64 + 32], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][3], &s_b[16][comp_c_frag_n * 64 + 48], BN + BPAD);
#pragma unroll
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
wmma::mma_sync(frag_c[i][j], frag_a[0][i], frag_b[0][j], frag_c[i][j]);
wmma::mma_sync(frag_c[i][j], frag_a[1][i], frag_b[1][j], frag_c[i][j]);
}
}
__syncthreads();
}
int store_c_gmem_m = by * BM + comp_c_frag_m * 64;
int store_c_gmem_n = bx * BN + comp_c_frag_n * 64;
int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
#pragma unroll
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
wmma::store_matrix_sync(&c[store_c_gmem_addr + i * 16 * N + j * 16], frag_c[i][j], N, wmma::mem_row_major);
}
}
}
其中需要注意的地方是,为了避免LDSM指令从shared memory中取数时发生bank conflict,因此shared memory中每行矩阵后面都加了16 Bytes的pad,有兴趣的同学可以画一画矩阵在shared memory中的排布,思考一下为什么每行加16 Bytes就可以避免bank conflict。
这里避免bank conflict的方式非常naive,会造成shared memory的浪费(虽然shared memory也够用了)。CUTLASS中采用了这样一种排布方式:https://developer.download.nvidia.cn/video/gputechconf/gtc/2019/presentation/s9593-cutensor-high-performance-tensor-operations-in-cuda-v2.pdf,可以在不额外增加shared memory占用的情况下,同时避免读写shared memory时的冲突。这种方法因为在load矩阵时需要为每四个线程指定一个shared memory的地址,不能使用stride访问,所以编程时C++ API和PTX的wmma指令都不适用,需要使用PTX中的ldmatrix指令。
5. Global Memory到Shared Memory的异步拷贝
在Ampere架构以前,global memory到shared memory的数据拷贝需要寄存器的参与,即先从global memory加载到寄存器,再从寄存器写到shared memory;Ampere架构引入了global memory到shared memory的异步拷贝的特性,不需要在寄存器中转数据,还有利于节省中间寄存器的使用。
Global memory到shared memory的异步拷贝,cuda cooperative_groups和pipeline中均有C++ API的支持,但是该接口cooperative_groups::memcpy_async(group, p_smem, p_gmem, size)仅支持了连续数据的拷贝,而矩阵乘法算子中加载的数据并不连续,需要间隔stride访问,因此我使用了PTX嵌入式汇编。
PTX指令中的异步拷贝指令共有四条,除了指定dst、src和size,还可以指定L1和L2 cache的一些行为:
cp.async.ca.shared.global{.level::cache_hint}{.level::prefetch_size}
[dst], [src], cp-size{, src-size}{, cache-policy} ;
cp.async.cg.shared.global{.level::cache_hint}{.level::prefetch_size}
[dst], [src], 16{, src-size}{, cache-policy} ;
cp.async.ca.shared.global{.level::cache_hint}{.level::prefetch_size}
[dst], [src], cp-size{, ignore-src}{, cache-policy} ;
cp.async.cg.shared.global{.level::cache_hint}{.level::prefetch_size}
[dst], [src], 16{, ignore-src}{, cache-policy} ;
.level::cache_hint = { .L2::cache_hint }
.level::prefetch_size = { .L2::64B, .L2::128B, .L2::256B }
cp-size = { 4, 8, 16 }
异步拷贝指令后需要使用cp.async.commit_group指令+cp.async.wait_group指令,或者cp.async.wait_all指令来等待指定的拷贝指令完成数据拷贝。
我们把GEMM Kernel中矩阵A和矩阵B global memory到shared memory的数据拷贝替换成异步拷贝:
__global__ void myHGEMMAlignedV2(
half * __restrict__ a, half * __restrict__ b, half * __restrict__ c,
const int M, const int N, const int K) {
const int BM = 128;
const int BN = 256;
const int BK = 32;
int bx = blockIdx.x;
int by = blockIdx.y;
int tid = threadIdx.x;
int wid = tid >> 5;
const int APAD = 8;
const int BPAD = 8;
__shared__ half s_a[BM][BK + APAD];
__shared__ half s_b[BK][BN + BPAD];
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> frag_a[2][4];
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> frag_b[2][4];
wmma::fragment<wmma::accumulator, 16, 16, 16, half> frag_c[4][4];
#pragma unroll
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
wmma::fill_fragment(frag_c[i][j], 0.0);
}
}
int load_a_smem_m = (tid >> 2) << 1;
int load_a_smem_k = (tid & 3) << 3;
int load_b_smem_k = (tid >> 5) << 2;
int load_b_smem_n = (tid & 31) << 3;
int s_a_base_addr = __cvta_generic_to_shared(s_a[0]);
int s_b_base_addr = __cvta_generic_to_shared(s_b[0]);
int load_a_smem_addr_0 = s_a_base_addr + OFFSET(load_a_smem_m, load_a_smem_k, BK + APAD) * sizeof(half);
int load_a_smem_addr_1 = load_a_smem_addr_0 + (BK + APAD) * sizeof(half);
int load_b_smem_addr_0 = s_b_base_addr + OFFSET(load_b_smem_k, load_b_smem_n, BN + BPAD) * sizeof(half);
int load_b_smem_addr_1 = load_b_smem_addr_0 + (BN + BPAD) * sizeof(half);
int load_b_smem_addr_2 = load_b_smem_addr_0 + 2 * (BN + BPAD) * sizeof(half);
int load_b_smem_addr_3 = load_b_smem_addr_0 + 3 * (BN + BPAD) * sizeof(half);
int load_a_gmem_m = by * BM + load_a_smem_m;
int load_b_gmem_n = bx * BN + load_b_smem_n;
int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_smem_k, K);
int load_b_gmem_addr = OFFSET(load_b_smem_k, load_b_gmem_n, N);
int comp_c_frag_m = wid & 1;
int comp_c_frag_n = wid >> 1;
for (int bk = 0; bk < K / BK; bk++) {
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_a_smem_addr_0), "l"(&a[load_a_gmem_addr ]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_a_smem_addr_1), "l"(&a[load_a_gmem_addr + K]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_b_smem_addr_0), "l"(&b[load_b_gmem_addr ]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_b_smem_addr_1), "l"(&b[load_b_gmem_addr + N]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_b_smem_addr_2), "l"(&b[load_b_gmem_addr + 2 * N]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_b_smem_addr_3), "l"(&b[load_b_gmem_addr + 3 * N]));
load_a_gmem_addr += BK;
load_b_gmem_addr += BK * N;
asm ("cp.async.commit_group;\n" ::);
asm ("cp.async.wait_group 0;\n" ::);
__syncthreads();
wmma::load_matrix_sync(frag_a[0][0], &s_a[comp_c_frag_m * 64 ][ 0], BK + APAD);
wmma::load_matrix_sync(frag_a[0][1], &s_a[comp_c_frag_m * 64 + 16][ 0], BK + APAD);
wmma::load_matrix_sync(frag_a[0][2], &s_a[comp_c_frag_m * 64 + 32][ 0], BK + APAD);
wmma::load_matrix_sync(frag_a[0][3], &s_a[comp_c_frag_m * 64 + 48][ 0], BK + APAD);
wmma::load_matrix_sync(frag_a[1][0], &s_a[comp_c_frag_m * 64 ][16], BK + APAD);
wmma::load_matrix_sync(frag_a[1][1], &s_a[comp_c_frag_m * 64 + 16][16], BK + APAD);
wmma::load_matrix_sync(frag_a[1][2], &s_a[comp_c_frag_m * 64 + 32][16], BK + APAD);
wmma::load_matrix_sync(frag_a[1][3], &s_a[comp_c_frag_m * 64 + 48][16], BK + APAD);
wmma::load_matrix_sync(frag_b[0][0], &s_b[ 0][comp_c_frag_n * 64 ], BN + BPAD);
wmma::load_matrix_sync(frag_b[0][1], &s_b[ 0][comp_c_frag_n * 64 + 16], BN + BPAD);
wmma::load_matrix_sync(frag_b[0][2], &s_b[ 0][comp_c_frag_n * 64 + 32], BN + BPAD);
wmma::load_matrix_sync(frag_b[0][3], &s_b[ 0][comp_c_frag_n * 64 + 48], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][0], &s_b[16][comp_c_frag_n * 64 ], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][1], &s_b[16][comp_c_frag_n * 64 + 16], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][2], &s_b[16][comp_c_frag_n * 64 + 32], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][3], &s_b[16][comp_c_frag_n * 64 + 48], BN + BPAD);
#pragma unroll
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
wmma::mma_sync(frag_c[i][j], frag_a[0][i], frag_b[0][j], frag_c[i][j]);
wmma::mma_sync(frag_c[i][j], frag_a[1][i], frag_b[1][j], frag_c[i][j]);
}
}
__syncthreads();
}
int store_c_gmem_m = by * BM + comp_c_frag_m * 64;
int store_c_gmem_n = bx * BN + comp_c_frag_n * 64;
int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
#pragma unroll
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
wmma::store_matrix_sync(&c[store_c_gmem_addr + i * 16 * N + j * 16], frag_c[i][j], N, wmma::mem_row_major);
}
}
}
这里需要注意嵌入式的PTX汇编中,shared memory的指针需要特殊处理一下。因为用&smem[...]这样得到的是generic的指针(8字节),直接该8字节值作为shared memory的地址可能会超出shared memory的地址范围,所以需要使用__cvta_generic_to_shared()或者将该8字节值与上0xFFFFFF,使该指针指向shared memory的地址空间。详情参见Problem about PTX instruction cp.async.ca.shared.global - CUDA Programming and Performance - NVIDIA Developer Forums。
Global Memory到Shared Memory异步拷贝的加入大概带来了5 TFLOPS ~ 10 TFLOPS的性能提升:
With/Without AsyncCopy Performance
6. Double Buffer
Double Buffer的目的是,在从global memory向shared memory加载下一次计算使用的数据时,刚好进行本次计算,以掩盖访存延迟。其实我也没有非常清楚double buffer的算子最好应该怎么写,下面的代码只是我自己尝试的一种测出来有性能提升的写法,另外我看到nvidia forum上有使用C++ API中的pipeline来实现double buffer的。
__global__ void myHGEMMAlignedV3(
half * __restrict__ a, half * __restrict__ b, half * __restrict__ c,
const int M, const int N, const int K) {
const int BM = 128;
const int BN = 256;
const int BK = 32;
int bx = blockIdx.x;
int by = blockIdx.y;
int tid = threadIdx.x;
int wid = tid >> 5;
const int APAD = 8;
const int BPAD = 8;
extern __shared__ half smem[];
half *s_a = smem;
half *s_b = smem + 2 * BM * (BK + APAD);
int s_a_db_offset = BM * (BK + APAD);
int s_b_db_offset = BK * (BN + BPAD);
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> frag_a[2][4];
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> frag_b[2][4];
wmma::fragment<wmma::accumulator, 16, 16, 16, half> frag_c[4][4];
#pragma unroll
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
wmma::fill_fragment(frag_c[i][j], 0.0);
}
}
int load_a_smem_m = (tid >> 2) << 1;
int load_a_smem_k = (tid & 3) << 3;
int load_b_smem_k = (tid >> 5) << 2;
int load_b_smem_n = (tid & 31) << 3;
int s_a_base_addr = __cvta_generic_to_shared(s_a);
int s_b_base_addr = __cvta_generic_to_shared(s_b);
int load_a_smem_addr_0 = s_a_base_addr + OFFSET(load_a_smem_m, load_a_smem_k, BK + APAD) * sizeof(half);
int load_a_smem_addr_1 = load_a_smem_addr_0 + (BK + APAD) * sizeof(half);
int load_b_smem_addr_0 = s_b_base_addr + OFFSET(load_b_smem_k, load_b_smem_n, BN + BPAD) * sizeof(half);
int load_b_smem_addr_1 = load_b_smem_addr_0 + (BN + BPAD) * sizeof(half);
int load_b_smem_addr_2 = load_b_smem_addr_0 + 2 * (BN + BPAD) * sizeof(half);
int load_b_smem_addr_3 = load_b_smem_addr_0 + 3 * (BN + BPAD) * sizeof(half);
int load_a_gmem_m = by * BM + load_a_smem_m;
int load_b_gmem_n = bx * BN + load_b_smem_n;
int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_smem_k, K);
int load_b_gmem_addr = OFFSET(load_b_smem_k, load_b_gmem_n, N);
int comp_c_frag_m = wid & 1;
int comp_c_frag_n = wid >> 1;
{
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_a_smem_addr_0), "l"(&a[load_a_gmem_addr ]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_a_smem_addr_1), "l"(&a[load_a_gmem_addr + K]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_b_smem_addr_0), "l"(&b[load_b_gmem_addr ]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_b_smem_addr_1), "l"(&b[load_b_gmem_addr + N]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_b_smem_addr_2), "l"(&b[load_b_gmem_addr + 2 * N]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_b_smem_addr_3), "l"(&b[load_b_gmem_addr + 3 * N]));
asm ("cp.async.commit_group;\n" ::);
asm ("cp.async.wait_group 0;\n" ::);
__syncthreads();
}
for (int bk = 1; bk < K / BK; bk++) {
int smem_sel = (bk & 1) ^ 1;
int smem_sel_next = ((bk - 1) & 1) ^ 1;
load_a_gmem_addr += BK;
load_b_gmem_addr += BK * N;
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_a_smem_addr_0 + smem_sel_next * s_a_db_offset * (int)sizeof(half)), "l"(&a[load_a_gmem_addr ]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_a_smem_addr_1 + smem_sel_next * s_a_db_offset * (int)sizeof(half)), "l"(&a[load_a_gmem_addr + K]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_b_smem_addr_0 + smem_sel_next * s_b_db_offset * (int)sizeof(half)), "l"(&b[load_b_gmem_addr ]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_b_smem_addr_1 + smem_sel_next * s_b_db_offset * (int)sizeof(half)), "l"(&b[load_b_gmem_addr + N]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_b_smem_addr_2 + smem_sel_next * s_b_db_offset * (int)sizeof(half)), "l"(&b[load_b_gmem_addr + 2 * N]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_b_smem_addr_3 + smem_sel_next * s_b_db_offset * (int)sizeof(half)), "l"(&b[load_b_gmem_addr + 3 * N]));
wmma::load_matrix_sync(frag_a[0][0], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 ) * (BK + APAD) + 0], BK + APAD);
wmma::load_matrix_sync(frag_a[0][1], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 16) * (BK + APAD) + 0], BK + APAD);
wmma::load_matrix_sync(frag_a[0][2], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 32) * (BK + APAD) + 0], BK + APAD);
wmma::load_matrix_sync(frag_a[0][3], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 48) * (BK + APAD) + 0], BK + APAD);
wmma::load_matrix_sync(frag_a[1][0], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 ) * (BK + APAD) + 16], BK + APAD);
wmma::load_matrix_sync(frag_a[1][1], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 16) * (BK + APAD) + 16], BK + APAD);
wmma::load_matrix_sync(frag_a[1][2], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 32) * (BK + APAD) + 16], BK + APAD);
wmma::load_matrix_sync(frag_a[1][3], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 48) * (BK + APAD) + 16], BK + APAD);
wmma::load_matrix_sync(frag_b[0][0], &s_b[smem_sel * s_b_db_offset + comp_c_frag_n * 64 ], BN + BPAD);
wmma::load_matrix_sync(frag_b[0][1], &s_b[smem_sel * s_b_db_offset + comp_c_frag_n * 64 + 16], BN + BPAD);
wmma::load_matrix_sync(frag_b[0][2], &s_b[smem_sel * s_b_db_offset + comp_c_frag_n * 64 + 32], BN + BPAD);
wmma::load_matrix_sync(frag_b[0][3], &s_b[smem_sel * s_b_db_offset + comp_c_frag_n * 64 + 48], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][0], &s_b[smem_sel * s_b_db_offset + 16 * (BN + BPAD) + comp_c_frag_n * 64 ], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][1], &s_b[smem_sel * s_b_db_offset + 16 * (BN + BPAD) + comp_c_frag_n * 64 + 16], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][2], &s_b[smem_sel * s_b_db_offset + 16 * (BN + BPAD) + comp_c_frag_n * 64 + 32], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][3], &s_b[smem_sel * s_b_db_offset + 16 * (BN + BPAD) + comp_c_frag_n * 64 + 48], BN + BPAD);
#pragma unroll
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
wmma::mma_sync(frag_c[i][j], frag_a[0][i], frag_b[0][j], frag_c[i][j]);
wmma::mma_sync(frag_c[i][j], frag_a[1][i], frag_b[1][j], frag_c[i][j]);
}
}
asm ("cp.async.commit_group;\n" ::);
asm ("cp.async.wait_group 0;\n" ::);
__syncthreads();
}
int smem_sel = ((K / BK) & 1) ^ 1;
wmma::load_matrix_sync(frag_a[0][0], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 ) * (BK + APAD) + 0], BK + APAD);
wmma::load_matrix_sync(frag_a[0][1], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 16) * (BK + APAD) + 0], BK + APAD);
wmma::load_matrix_sync(frag_a[0][2], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 32) * (BK + APAD) + 0], BK + APAD);
wmma::load_matrix_sync(frag_a[0][3], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 48) * (BK + APAD) + 0], BK + APAD);
wmma::load_matrix_sync(frag_a[1][0], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 ) * (BK + APAD) + 16], BK + APAD);
wmma::load_matrix_sync(frag_a[1][1], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 16) * (BK + APAD) + 16], BK + APAD);
wmma::load_matrix_sync(frag_a[1][2], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 32) * (BK + APAD) + 16], BK + APAD);
wmma::load_matrix_sync(frag_a[1][3], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 48) * (BK + APAD) + 16], BK + APAD);
wmma::load_matrix_sync(frag_b[0][0], &s_b[smem_sel * s_b_db_offset + comp_c_frag_n * 64 ], BN + BPAD);
wmma::load_matrix_sync(frag_b[0][1], &s_b[smem_sel * s_b_db_offset + comp_c_frag_n * 64 + 16], BN + BPAD);
wmma::load_matrix_sync(frag_b[0][2], &s_b[smem_sel * s_b_db_offset + comp_c_frag_n * 64 + 32], BN + BPAD);
wmma::load_matrix_sync(frag_b[0][3], &s_b[smem_sel * s_b_db_offset + comp_c_frag_n * 64 + 48], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][0], &s_b[smem_sel * s_b_db_offset + 16 * (BN + BPAD) + comp_c_frag_n * 64 ], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][1], &s_b[smem_sel * s_b_db_offset + 16 * (BN + BPAD) + comp_c_frag_n * 64 + 16], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][2], &s_b[smem_sel * s_b_db_offset + 16 * (BN + BPAD) + comp_c_frag_n * 64 + 32], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][3], &s_b[smem_sel * s_b_db_offset + 16 * (BN + BPAD) + comp_c_frag_n * 64 + 48], BN + BPAD);
#pragma unroll
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
wmma::mma_sync(frag_c[i][j], frag_a[0][i], frag_b[0][j], frag_c[i][j]);
wmma::mma_sync(frag_c[i][j], frag_a[1][i], frag_b[1][j], frag_c[i][j]);
}
}
int store_c_gmem_m = by * BM + comp_c_frag_m * 64;
int store_c_gmem_n = bx * BN + comp_c_frag_n * 64;
int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
#pragma unroll
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
wmma::store_matrix_sync(&c[store_c_gmem_addr + i * 16 * N + j * 16], frag_c[i][j], N, wmma::mem_row_major);
}
}
}
这里需要注意的是,double buffer会用到两倍的shared memory,当使用的shared memory超过48 KB时,需要使用dynamic shared memory,即extern shared half smem[];这样声明一块动态共享内存,调用kernel时需要指定动态共享内存大小,且smem的寻址方式需要按照一维数组来使用:
const int BM = 128, BN = 256, BK = 32;
dim3 blockDim(256);
int BX = (N + BN - 1) / BN;
int BY = (M + BM - 1) / BM;
dim3 gridDim(BX, BY);
cudaFuncSetAttribute(gemmBK32WmmaAsyncDSMemDB, cudaFuncAttributeMaxDynamicSharedMemorySize, 98304);
unsigned int dsmem = 2 * (BM * (BK + 8) + BK * (BN + 8)) * sizeof(half);
gemmBK32WmmaAsyncDSMemDB<<<gridDim, blockDim, dsmem>>>(a, b, c, M, N, K);
Double Buffer的效果可谓是立竿见影,带来了大概20 TFLOPS ~ 25 TFLOPS的提升:
With/Without Double Buffer Performance
7. 提高L2 Cache的局部性
RTX3090一共有82个SM,经过计算gemmBK32WmmaAsyncDSMemDB这个kernel每个SM只能容纳一个block,当大规模矩阵乘法的block数目超过82时,会按照gridDim.z -> gridDim.y -> gridDim.x这样的循环顺序进行调度。
例如当M = N = K = 16384时,矩阵C会被分块成128 * 64个Tile,如果按照正常的调度顺序,先调度矩阵C第一行64个Tile对应的block加上第二行的前18个block,这样虽然矩阵A的局部性很好,但是矩阵B的访存局部性极差。我们现在希望第一次先调度第一行到第五行的前16个block,加上第六行的前2个block,这样矩阵A和矩阵B的局部性就得到了平衡。
修改一下调用kernel时的代码,利用其默认的调度顺序,加上gridDim.z这一维,这里NSPLIT就代表矩阵C的一行一次调度到NSPLIT这么多就转到下一行:
const int BM = 128, BN = 256, BK = 32;
dim3 blockDim(256);
int BX = (N + BN - 1) / BN;
int BY = (M + BM - 1) / BM;
const int NSPLIT = 4096;
int split_num = (N + NSPLIT - 1) / NSPLIT;
dim3 gridDim((BX + split_num - 1) / split_num, BY, split_num);
cudaFuncSetAttribute(gemmBK32WmmaAsyncDSMemDB, cudaFuncAttributeMaxDynamicSharedMemorySize, 98304);
unsigned int dsmem = 2 * (BM * (BK + 8) + BK * (BN + 8)) * sizeof(half);
gemmBK32WmmaAsyncDSMemDB<<<gridDim, blockDim, dsmem>>>(a, b, c, M, N, K);
相应地修改kernel:
__global__ void myHGEMMAlignedV4(
half * __restrict__ a, half * __restrict__ b, half * __restrict__ c,
const int M, const int N, const int K) {
// ...
// int bx = blockIdx.x; // 原来是这样
int bx = blockIdx.z * gridDim.x + blockIdx.x; // 现在是这样
if (bx >= N / BN || by >= M / BM)
return;
// ...
}
想法很丰满,现实很骨感,测试发现NSPLIT = 256时性能很差,而NSPLIT = 512/1024/2048/4096/8192时和myHGEMMAlignedV3相差无几,只有在接近16384的几个样本点性能表现明显更好。为了让优化代码不要白写,我最终选取了NSPLIT = 4096:
NSPLIT Performance
8. 给编译器一些发挥的空间
事实上优化到myHGEMMAlignedV3这里添加了double buffer之后,就已经达到并大致超过了cublas。最后让编译器给主循环进行循环展开,看看能再有多大的性能提升:
__global__ void myHGEMMAlignedV4(
half * __restrict__ a, half * __restrict__ b, half * __restrict__ c,
const int M, const int N, const int K) {
// ...
#pragma unroll 32
for (int bk = 1; bk < K / BK; bk++) {
// ...
}
// ...
}
测试结果如下图所示,循环展开比不循环展开时又提高了约15 TFLOPS,基本全面超过cublas。在M = N = K > 4096时,循环展开到8之后性能基本不会再有提升;但MNK较小时,直到展开32次仍然还有提升。
UNROLL Performance
9. At last
至此就是本次文章的全部内容,学习了一下Tensor Core的使用方法,并顺利地将fp16的矩阵乘法优化到了cublas的性能。当然,其中一部分性能还来自于我假设MNK都是对齐的,没有判断corner case。
总体性能图(原本很清晰的图上传之后就不清晰了qwq):
Performance下篇文章预计会研究一下卷积算子。