本文介绍最简单的pytorch分布式训练方法:使用torch.nn.DataParallel这个API来实现分布式训练。环境为单机多gpu,不妨假设有4个可用的gpu。

一、构建方法

使用这个API实现分布式训练的步骤非常简单,总共分为3步骤:
1、创建一个model,并将该model推到某个gpu上(这个gpu也将作为output_device,后面具体解释含义),不妨假设推到第0号gpu上,

device = torch.device("cuda:0")
model.to(device)

2、将数据推到output_device对应的gpu上,

data = data.to(device)

3、使用torch.nn.DataParallel这个API来在0,1,2,3四个gpu上构建分布式模型,

model = torch.nn.DataParallel(model, device_ids=[0,1,2,3], output_device=0)

然后这个model就可以像普通的单gpu上的模型一样开始训练了。

二、原理详解

2.1 原理图

  首先通过图来看一下这个最简单的分布式训练API的工作原理,然后结合代码详细阐述。

深度学习 分布式训练 分布式模型训练_pytorch


将模型和数据推入output_device(也就是0号)gpu上。

深度学习 分布式训练 分布式模型训练_深度学习 分布式训练_02


0号gpu将当前模型在其他几个gpu上进行复制,同步模型的parameter、buffer和modules等;将当前batch尽可能平均的分为len(device)=4份,分别推给每一个设备,并开启多线程分别在每个设备上进行前向传播,得到各自的结果,最后将各自的结果全部汇总在一起,拷贝回0号gpu。

深度学习 分布式训练 分布式模型训练_数据_03


在0号gpu进行反向传播和模型的参数更新,并将结果同步给其他几个gpu,即完成了一个batch的训练。

2.2 代码原理

  通过分析torch.nn.DataParallel的代码,可以看到具体的过程,这里重点看一下几个关键的地方。

# 继承自nn.Module,只要实现__init__和forward函数即可
class DataParallel(Module):
    # 构造函数里没有什么关键内容,主要是根据传进来的model、device_ids和output_device进行一些变量生成
    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallel, self).__init__()

        device_type = _get_available_device_type()
        if device_type is None:
            self.module = module
            self.device_ids = []
            return

        if device_ids is None:
            device_ids = _get_all_device_indices()

        if output_device is None:
            output_device = device_ids[0]

        self.dim = dim
        self.module = module
        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
        self.output_device = _get_device_index(output_device, True)
        self.src_device_obj = torch.device(device_type, self.device_ids[0])

        _check_balance(self.device_ids)

        if len(self.device_ids) == 1:
            self.module.to(self.src_device_obj)
    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError("module must have its parameters and buffers "
                                   "on device {} (device_ids[0]) but found one of "
                                   "them on device: {}".format(self.src_device_obj, t.device))

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        # 在每个gpu上都复制一个model
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        # 开启多线程进行前向传播,得到结果
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        # 将每个gpu上得到的结果都gather到0号gpu上
        return self.gather(outputs, self.output_device)

    def replicate(self, module, device_ids):
        return replicate(module, device_ids, not torch.is_grad_enabled())

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(self, replicas, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

    def gather(self, outputs, output_device):
        return gather(outputs, output_device, dim=self.dim)

再看一下parallel_apply这个关键的函数,

def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = list(map(lambda x: _get_device_index(x, True), devices))
    # 创建一个互斥锁,防止前后两个batch的数据覆盖
    lock = threading.Lock()
    results = {}
    grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()
    # 线程的target函数,实现每个gpu上进行推理,其中i为gpu编号
    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            # 根据当前gpu编号确定推理硬件环境
            with torch.cuda.device(device), autocast(enabled=autocast_enabled):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module(*input, **kwargs)
            # 锁住赋值,防止后一个batch的数据将前一个batch的结果覆盖
            with lock:
                results[i] = output
        except Exception:
            with lock:
                results[i] = ExceptionWrapper(
                    where="in replica {} on device {}".format(i, device))

    if len(modules) > 1:
        # 创建多个线程,进行不同gpu的前向推理
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])
    # 将不同gpu上推理的结果打包起来,后面会gather到output_device上
    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs

结论

  至此我们看到了torch.nn.DataParallel模块进行分布式训练的原理,数据和模型首先推入output_device对应的gpu,然后将分成多个子batch的数据和模型分别推给其他gpu,每个gpu单独处理各自的子batch,结果再打包回原output_device对应的gpu算梯度和更新参数,如此循环往复,其本质是一个单进程多线程的并发程序。
  由此我们也很容易得到torch.nn.DataParallel模块进行分布式的缺点,
1、每个batch的数据先分发到各gpu上,结果再打包回output_device上,在output_device一个gpu上进行梯度计算和参数更新,再把更新同步给其他gpu上的model。其中涉及数据的来回拷贝,网络通信耗时严重,GPU使用率低。
2、这种模式只支持单机多gpu的硬件拓扑结构,不支持Apex的混合精度训练等。
3、torch.nn.DataParallel也没有很完整的考虑到多个gpu做数据并行的一些问题,比如batchnorm,在训练时各个gpu上的batchnorm的mean和variance是子batch的计算结果,而不是原来整个batch的值,可能会导致训练不稳定影响收敛等问题。