我们知道,pin_memorynon_blocking 可以帮助加速 Pytorch 训练过程。这篇文章里,我以一个比较粗浅的理解分析一下,它们为什么能够加速训练。


pin_memory = True

当我们要在 GPU 上进行训练时,自然需要把数据从CPU(一般情况下,数据存储在 CPU 上)转移到 GPU 上。但是 CPU 与 GPU 之间的数据交互是比较慢的,特别是 CPU 中的 pageable memory (可分页内存)与 GPU 之间的交互,这个过程需要建立一个临时缓冲区(pinned memory)——如下图左边所示。

pytorch 冻结BN pytorch non_blocking_pytorch 冻结BN


在 dataloader 中 设置 pin_memory=True,程序一开始就把数据放在 pinned memory (锁页内存)上,这样向 GPU 传输数据时就会更快,如上图右边所示。这样一来,就避免了 CPU 内部内存的拷贝,节约了时间。

注意: pin_memory=True 虽好,也不要“贪杯”噢。

If you overuse pinned memory, it can cause serious problems when running low on RAM, and you should be aware that pinning is often an expensive operation. ——Use pinned memory buffers

一般设备间传输速度的数量级:

锁页内存和 GPU 显存之间的拷贝速度大约是 6GB/s
可分页内存和 GPU 显存间的拷贝速度大约是 3GB/s。
GPU 内存间速度是 30GB/s,CPU 间内存速度是 10GB/s ——详解 Pytorch 里的 pin_memory 和 non_blocking


non_blocking = True

顾名思义,non_blocking 是指“不要阻塞”,希望运算能够同时进行。那么什么样的运算能够同时进行呢?——比如数据从 CPU 向 GPU 传输时(此时占用的是 GPU 的资源),我们可以在 CPU 上进行运算,充分利用 CPU 的计算资源。

x = x.cuda(non_blocking=True) # 向 GPU 传输数据
compute_on_cpu() # 在 CPU 上进行计算
y = model(x) # 在 GPU 上训练模型

上面代码第一行执行过程中,由于设置了“非阻塞”——non_blocking = True,它允许第一行执行过程中,同时执行第二行代码。当然,如果第二行执行完毕之后,第一行还没执行完,第三行需要等待第一行执行完毕才能开始,因为它们都是在 GPU 上进行的。

可以想像,下面这段代码即使设置 non_blocking = True 也没有效果:

x = x.cuda(non_blocking=True) # 向 GPU 传输数据
y = model(x) # 在 GPU 上训练模型

使用指南

我们经常会同时设置 non_blocking = True 以及 pin_memory=True

Pinned Memory allows the non-blocking calls to actually be non-blocking

上面这句话什么意思呢?如果不设置 pin_memory=True,CPU 向 GPU 传输数据时,需要在 CPU 的内部内存间进行拷贝,建一个临时缓冲区,再由临时缓冲区向 GPU 传输。这样一来,传输数据的时候,CPU 也在忙碌,非阻塞就起不到作用了。

下面是示例代码:

pin_memory = True

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    pin_memory=pin_memory,
    num_workers=n_workers
)

times = []
n_runs = 10
def compute_on_cpu():
    # 模拟 CPU 上的计算
    time.sleep(0.1)

for i in range(n_runs):
    st = time.time()
    for x, y in train_dataloader:
       x, y = x.cuda(non_blocking=pin_memory), y.cuda(non_blocking=pin_memory)
       compute_on_cpu()
   times.append(time.time() - st)
print('average time:', np.mean(times))

另外,搭配 dataloader 的 num_workers,会进一步加快数据的传输速度。至于如何设置 num_workers,其中也有玄机——欲知后事如何,且听下回分解。