1、背景介绍
在pytorch的多卡训练中,通常有两种方式,一种是单机多卡模式(存在一个节点,通过torch.nn.DataParallel(model)实现),一种是多机多卡模式(存在一个节点或者多个节点,通过torch.nn.parallel.DistributedDataParallel(model),在单机多卡环境下使用第二种分布式训练模式具有更快的速度。pytorch在分布式训练过程中,对于数据的读取是采用主进程预读取并缓存,然后其它进程从缓存中读取,不同进程之间的数据同步具体通过torch.distributed.barrier()实现。
2、通俗理解torch.distributed.barrier()
def create_dataloader():
#使用上下文管理器中实现的barrier函数确保分布式中的主进程首先处理数据,然后其它进程直接从缓存中读取
with torch_distributed_zero_first(rank):
dataset = LoadImagesAndLabels()
from contextlib import contextmanager
#定义的用于同步不同进程对数据读取的上下文管理器
@contextmanager
def torch_distributed_zero_first(local_rank: int):
"""
Decorator to make all processes in distributed training wait for each local_master to do something.
"""
if local_rank not in [-1, 0]:
torch.distributed.barrier()
yield #中断后执行上下文代码,然后返回到此处继续往下执行
if local_rank == 0:
torch.distributed.barrier()
(1)进程号rank理解
在多进程上下文中,我们通常假定rank 0是第一个进程或者主进程,其它进程分别具有0,1,2不同rank号,这样总共具有4个进程。
(2)单一进程数据处理
通常有一些操作是没有必要以并行的方式进行处理的,如数据读取与处理操作,只需要一个进程进行处理并缓存,然后与其它进程共享缓存处理数据,但是由于不同进程是同步执行的,单一进程处理数据必然会导致进程之间出现不同步的现象,为此,torch中采用了barrier()函数对其它非主进程进行阻塞,来达到同步的目的。
(3)barrier()具体原理
在上面的代码示例中,如果执行create_dataloader()函数的进程不是主进程,即rank不等于0或者-1,上下文管理器会执行相应的torch.distributed.barrier(),设置一个阻塞栅栏,让此进程处于等待状态,等待所有进程到达栅栏处(包括主进程数据处理完毕);如果执行create_dataloader()函数的进程是主进程,其会直接去读取数据并处理,然后其处理结束之后会接着遇到torch.distributed.barrier(),此时,所有进程都到达了当前的栅栏处,这样所有进程就达到了同步,并同时得到释放。