前几篇文章我们介绍了 PyTorch 流水线并行的基本知识和自动平衡机制,本文我们介绍如何切分数据和运行时系统。



[源码解析] PyTorch 流水线并行实现 (3)--切分数据和运行时系统



目录



0x00 摘要

前几篇文章我们介绍了 PyTorch 流水线并行的基本知识和自动平衡机制,本文我们介绍如何切分数据和运行时系统。

最后得出运行时系统如下:

[源码解析] PyTorch 流水线并行实现 (3)--切分数据和运行时系统_封装

0x01 分割小批次

我们首先看看如何把一个 mini-batch 分割为多个 micro-batches。

1.1 使用

从下面示例代码可以看出来,具体使用scatter方法进行了分割。

# Divide a mini-batch into micro-batches.
batches = microbatch.scatter(input, self.chunks)

# Run pipeline parallelism.
pipeline = Pipeline(batches,
self.partitions,
self.devices,
copy_streams,
self._skip_layout,
checkpoint_stop)

pipeline.run()

# Merge the micro-batches into one mini-batch.
output = microbatch.gather(batches)
return output


1.2 PyTorch 基础

我们先看看 PyTorch 的一些基础代码。

1.2.1 chunk

chunk方法可以对张量分块,返回一个张量列表,其参数是:

  • ensor :要分割的张量。
  • chunks : 分割的块数
  • dim :沿着哪个轴分块

具体举例如下:

import numpy as np
import torch

data = torch.from_numpy(np.random.rand(3, 5))
print(str(data))

for i, data_i in enumerate(data.chunk(3, 0)): # 沿0轴分为3块
print(str(data_i))

输出
tensor([[0.1208, 0.3428, 0.4586, 0.9372, 0.6410],
[0.7889, 0.4480, 0.7607, 0.7903, 0.4118],
[0.8391, 0.6649, 0.8338, 0.3477, 0.3953]], dtype=torch.float64)

tensor([[0.1208, 0.3428, 0.4586, 0.9372, 0.6410]], dtype=torch.float64)
tensor([[0.7889, 0.4480, 0.7607, 0.7903, 0.4118]], dtype=torch.float64)
tensor([[0.8391, 0.6649, 0.8338, 0.3477, 0.3953]], dtype=torch.float64)


1.2.2 cat

cat 的用法则是把张量拼接在一起,或者把一个张量列表拼接起来。

Z = torch.cat( (X,Y),0 )  # 按维数0拼接,就是竖着拼
Z = torch.cat( (X,Y),1 ) # 按维数1拼接,就是横着拼


我们用示例看看:

X = torch.ones(2, 5)
Y = torch.ones(4, 5)
Z = torch.cat((X, Y), 0)
print(Z)


结果是:

tensor([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]])


1.3 分割 & 聚合

具体回到分割批次,我们来看看Scatter 代码。

def scatter(input: TensorOrTensors, chunks: int) -> List[Batch]:
"""Splits an input mini-batch into multiple micro-batches."""
inputs: Iterable[TensorOrTensors]

if isinstance(input, Tensor):
inputs = input.chunk(chunks) # 如果是张量,则直接分割
else:
rotated: List[Tensors] = []

for tensor in input: # 如果是张量数组,则遍历
tensors = tensor.chunk(chunks) # 对于每一个张量进行分割
rotated.append(cast(Tensors, tensors)) # 分割结果映射为 Tuple list

inputs = zip(*rotated) # 把 list 之中的Tuple 分别聚合

return [Batch(x) for x in inputs] # 映射成 Batch 列表返回


gather 方法则是把scatter的结果重新聚集起来,就是一个逆向操作。

def gather(outputs: List[Batch]) -> TensorOrTensors:
"""Concatenates output micro-batches into a mini-batch."""
output: TensorOrTensors

if outputs[0].atomic:
tensors = tuple(b.tensor for b in outputs)
output = torch.cat(tensors)
else:
rotated = [b.tensors for b in outputs]
output_buf = []

for tensors in zip(*rotated):
output_buf.append(torch.cat(tensors))

output = tuple(output_buf)

return output


1.4 剖析

我们看看如何使用,下面代码是把ab这个张量列表打散,分割成两个块。

def test_scatter_tuple():
ab = (torch.ones(2, 1), torch.zeros(4, 2), torch.zeros(6, 3))

a, b = scatter(ab, chunks=2)

assert a.tensors[0].size() == (1, 1)
assert b.tensors[0].size() == (1, 1)
assert a.tensors[1].size() == (2, 2)
assert b.tensors[1].size() == (2, 2)
assert a.tensors[2].size() == (3, 3)
assert b.tensors[2].size() == (3, 3)


我们画个图来看看。

    +-------------------------------------------------------------+
| ab |
| |
| +-----------+ +---------+ +----------+ |
| | | | | | 0 0 0 | |
| | | | 0 0 | | 0 0 0 | |
| | 1 | | 0 0 | | 0 0 0 | |
| | 1 | | 0 0 | | 0 0 0 | |
| | | | 0 0 | | 0 0 0 | |
| | | | | | 0 0 0 | |
| +-----------+ +---------+ +----------+ |
| |
+-------------------------------+-----------------------------+
|
|
|
a, b = scatter(ab, chunks=2)
|
|
|
|
|
v


+------------------------------+ +-----------------------------+
| a | |b |
| +---+ +-----+ +--------+ | | +---+ +-----+ +--------+ |
| | 1 | | 0 0 | | 0 0 0 | | | | 1 | | 0 0 | | 0 0 0 | |
| +---+ | 0 0 | | 0 0 0 | | | +---+ | 0 0 | | 0 0 0 | |
| +-----+ | 0 0 0 | | | +-----+ | 0 0 0 | |
| +--------+ | | +--------+ |
+------------------------------+ +-----------------------------+


使用下面的示例代码也可以看到如何聚合。

def test_gather_tensors():
a = torch.zeros(1, 1)
b = torch.zeros(1, 1)
ab = gather([Batch(a), Batch(b)])

assert ab.size() == (2, 1)


def test_gather_tuples():
a = (torch.zeros(1, 1), torch.zeros(2, 2))
b = (torch.zeros(1, 1), torch.zeros(2, 2))
ab = gather([Batch(a), Batch(b)])

assert isinstance(ab, tuple)
assert ab[0].size() == (2, 1)
assert ab[1].size() == (4, 2)


0x02 运行

我们接下来看看运行时的一些基础设施,具体包括 Stream,Task,Worker。

2.1 Stream

Stream 类是用来封装 CUDA stream 和 CPU stream。代码位于:torchgpipe/stream.py。

CUDA流表示一个GPU操作队列,即某个设备绑定的,按照顺序执的核(kernel)序列。我们可以把一个流看作是GPU之上的一个任务。用户向流的队列上添加一系列操作,GPU会按照添加到流中的先后顺序而依次执行这一系列操作。在同一个流之中,所有操作是串行序列化,因此这些操作永远不会并行。因此,要想并行,两个操作必须位于不同的 stream 中。不同流中的核函数可以交错,甚至可能重叠。

class CPUStreamType:
pass

# The placeholder on place of streams for the CPU device instead of CUDA.
CPUStream = CPUStreamType()

# It represents both CUDA streams and the CPU stream.
AbstractStream = Union[torch.cuda.Stream, CPUStreamType]


本文用到的相关操作为 use_stream。

torch.cuda.stream(stream) 的作用是选择给定流的上下文管理器。

@contextmanager
def use_stream(stream: AbstractStream) -> Generator[None, None, None]:
""":func:`torch.cuda.stream` for either CPU or CUDA stream."""
if not is_cuda(stream):
yield
return

with torch.cuda.stream(as_cuda(stream)):
yield

def is_cuda(stream: AbstractStream) -> bool:
"""Returns ``True`` if the given stream is a valid CUDA stream."""
return stream is not CPUStream


def as_cuda(stream: AbstractStream) -> torch.cuda.Stream:
"""Casts the given stream as :class:`torch.cuda.Stream`."""
return cast(torch.cuda.Stream, stream)


2.2 Task

Task 表示如何在一个分区上计算微批次数据(micro-batch)。它由两部分组成:

  • ​compute​​应在工作线程中并发执行。
  • ​finalize​​应在工作线程完成后执行。

可以理解为一个业务处理逻辑。如果有安卓经验的同学,可以理解为类似于 业务Message。其实 Android message也叫task,其封装了本任务携带的信息和处理该任务的handler。

这里的 Task 也是类似的,在构建Task 时候,就传入了 compute 方法和finalize方法,举例如下:

task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)


或者如下:

def compute(batch: Batch = batch,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
) -> Batch:
with use_skip_tracker(skip_tracker):
return batch.call(partition)

task = Task(streams[j], compute=compute, finalize=None)


具体Task定义如下,Task是绑定在 Stream 之上,即可以运行在任何device之上,这就用到了上一节的内容。

class Task:
"""A task represents how to compute a micro-batch on a partition.

It consists of two parts: :meth:`compute` and :meth:`finalize`.
:meth:`compute` should be executed in worker threads concurrently.
:meth:`finalize` should be executed after when worker threads complete to
execute :meth:`compute`.

:meth:`compute` might be boosted by worker threads. Because it produces
several CUDA API calls by user code. In PyTorch, parallel CUDA API calls
are not serialized through GIL. So more than one CUDA API call can be
produced at the same time.
"""

def __init__(self,
stream: AbstractStream,
*,
compute: Callable[[], Batch],
finalize: Optional[Callable[[Batch], None]],
) -> None:
self.stream = stream
self._compute = compute
self._finalize = finalize

def compute(self) -> Batch:
with use_stream(self.stream): # 绑定在stream之上
return self._compute() # 调用传入的业务代码

def finalize(self, batch: Batch) -> None:
if self._finalize is None:
return
with use_stream(self.stream): # 绑定在stream之上
self._finalize(batch) # 调用传入的业务代码


2.3 Worker

worker是用来运行task的,每个 device 有一个 worker 来负责执行这个 device 上的 task。如果有安卓经验的同学,可以理解为是 Looper。

需要注意,worker只是一个函数,如果运行,还需要一个线程作为寄托。这就是后续 spawn_workers 的工作。

def worker(in_queue: InQueue,
out_queue: OutQueue,
device: torch.device,
grad_mode: bool,
) -> None:
"""The main loop of a worker thread."""
torch.set_grad_enabled(grad_mode)

with use_device(device):
while True:
task = in_queue.get() # 从输入队列中获取task

if task is None:
break

try:
batch = task.compute() # 计算task
except Exception:
exc_info = cast(ExcInfo, sys.exc_info())
out_queue.put((False, exc_info))
continue

out_queue.put((True, (task, batch))) # 把task和计算结果放到输出队列

done = (False, None)
out_queue.put(done)


2.4 生成 worker

这里使用了 @contextmanager 注解,这是实现了上下文管理协议的对象,主要用于保存和恢复各种全局状态,关闭文件等,并为try...except...finally提供了一个方便使用的封装。

spawn_workers 为每个 device 生成了一个 Thread,这个 Thread 的执行函数是 worker

spawn_workers 不止生成了若干 workers,也生成了一对消息队列 (in_queues, out_queues) ,这个 (in_queues, out_queues) 在Pipeline 生命周期之内全程都存在,具体来说是:

  • spawn_workers 内部会针对每一个device生成一个 in_queue, out_queue。所以可保证每个device之上是串行来执行业务操作。
in_queue, out_queue = workers[device]


  • 这些 queues 被添加到 (in_queues, out_queues) 之中。
in_queues.append(in_queue)
out_queues.append(out_queue)


  • 之后就是使用 (in_queues, out_queues) 作为各个task 之间传递信息的上下文。
  • in_queues 里面的顺序就是 device 的顺序,也就是partition的顺序。out_queues 亦然。

具体代码如下:

@contextmanager
def spawn_workers(devices: List[torch.device],
) -> Generator[Tuple[List[InQueue], List[OutQueue]], None, None]:
"""Spawns worker threads. A worker thread is bound to a device."""
in_queues: List[InQueue] = []
out_queues: List[OutQueue] = []

# Spawn workers.
workers: Dict[torch.device, Tuple[InQueue, OutQueue]] = {}

def normalize_device(device: torch.device) -> torch.device:
if device.type == 'cuda' and device.index is None:
return torch.device('cuda', index=torch.cuda.current_device())

if device.type == 'cpu' and device.index is not None:
return torch.device('cpu')

return device

for device in devices:
device = normalize_device(device) # 得到使用的设备

try:
in_queue, out_queue = workers[device] # 临时放置queue
except KeyError: # 如果 device 还没有生成对应的queues,则生成
in_queue = Queue() # 生成新的queue
out_queue = Queue()

# 取出queue
workers[device] = (in_queue, out_queue) # 赋值给workers

t = Thread(
target=worker, # Thread的执行程序是 worker 函数
args=(in_queue, out_queue, device, torch.is_grad_enabled()),
daemon=True,
)
t.start() # 启动工作线程

in_queues.append(in_queue) # 插入queue
out_queues.append(out_queue) # 插入queue

try:
yield (in_queues, out_queues) # 返回给调用者
finally:
# Close workers.
for in_queue in set(in_queues):
in_queue.put(None)

# Join running workers.
running = set(out_queues)
while running:
out_queue = running.pop()
ok, payload = out_queue.get()

done = (False, None)
if (ok, payload) == done:
continue

running.add(out_queue)


2.5 使用

2.5.1 何时生成worker

使用例子位于 torchgpipe/pipeline.py,在 Pipeline 类之中的 run 函数中会生成workers。我们可以看到,对于 Pipeline 来说,有意义的就是 (in_queues, out_queues)。

    def run(self) -> None:
"""Runs pipeline parallelism.

It modifies the given batches in place.

"""
batches = self.batches
partitions = self.partitions
devices = self.devices
skip_layout = self.skip_layout

m = len(batches)
n = len(partitions)

skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches]

with spawn_workers(devices) as (in_queues, out_queues): # 生成 workers,并且得到队列
for schedule in clock_cycles(m, n): # 这里是按照算法有次序的运行多个fence, compute
self.fence(schedule, skip_trackers)
# 把队列传递进去
self.compute(schedule, skip_trackers, in_queues, out_queues)


2.5.2 剖析

Torchgpipe 使用了 Python 的 Queue 数据结构。

Queue 类实现了一个基本的先进先出(FIFO)容器。

A multi-producer, multi-consumer queue.


其主要方法是:

  • Queue.get([block, [timeout]]) 读队列,从队列尾部移除元素,timeout为等待时间,如果队列满,则阻塞。
  • Queue.put(item, [block, [timeout]]) 写队列,将元素添加到序列尾端,timeout为等待时间,如果队列空,则阻塞。

我个人更习惯于把 (in_queues, out_queues) 理解为类似 Linux 的 管道(Pipe)。

Linux 管道是一种最基本的IPC机制,作用于有血缘关系的进程之间,完成数据传递,具体特性如下:

  • 管道是由核函数管理的一个FIFO文件,其实是一个缓冲区,相当于我们放入内存中的一个管道,两个进程分别处于管道两端,通过这个管道来传递信息。
  • 管道的一端连接一个进程的输出。这个进程会向管道中放入信息。当管道被放满信息的时候,尝试放入信息的进程会等待,直到另一端的进程取出信息。
  • 管道的另一端连接另一个进程的输入,这个进程取出被放入管道的信息。当管道中没有信息的话,从管道中读取的进程会等待,直到另一端的进程放入信息。

具体回到 TorchPipe,我们提前看看论文的内容:

对于这种细粒度的顺序控制,torchgpipe把checkpointing 使用两个单独的autograd函数Checkpoint和Recompute来实现。在任务 \(F^{'}_{i,j}\) 的执行时间之内,生成一对具有共享内存的Checkpoint和Recompute。该共享内存在向后传播中被使用,用于将通过执行Recompute生成的本地计算图传输到Checkpoint来进行反向传播。

于是,这里就有很多并行处理的需求,于是我们可以看到 Pipeline 类的 compute 方法(省略部分代码)中有​​向 in_queues 之中放入 Task,从 out_queues 之中去除 Task 的执行结果​​。

    def compute(self,
schedule: List[Tuple[int, int]],
skip_trackers: List[SkipTrackerThroughPotals],
in_queues: List[InQueue],
out_queues: List[OutQueue],
) -> None:


# With checkpointing, the autograd graph looks like this diagram:
# ┌─────┸──────┐
# │ Copy │
# └─────┰──────┘ (fence)
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┃ (compute)
# ┌─────┸──────┐
# │ Wait │ [1] Synchronize the current stream with the copy stream.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Checkpoint │ [2] Compute a partition within checkpointing.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Wait │ [3] Synchronize the copy stream with the current stream.
# └─────┰──────┘
# ┠ ─ ─ ─ ┐
# ┃ ┌─────┴─────┐
# ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
# ┃ └─────┬─────┘
# ┠ ─ ─ ─ ┘
# ┃
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┌─────┸──────┐ (fence)
# │ Copy │
# └─────┰──────┘
for i, j in schedule: # 并行执行
batch = batches[i]
partition = partitions[j]

# Synchronize with the copied input. ([1] in the diagram)

# Determine whether checkpointing or not.
if checkpoint:
def function(input: TensorOrTensors,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
) -> TensorOrTensors:
with use_skip_tracker(skip_tracker):
return partition(input)

chk = Checkpointing(function, batch)
# 生成一个Task
task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
del function, chk

else:
def compute(batch: Batch = batch,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
) -> Batch:
with use_skip_tracker(skip_tracker):
return batch.call(partition)
# 生成一个Task
task = Task(streams[j], compute=compute, finalize=None)
del compute

# Compute tasks in parallel. ([2] in the diagram)
in_queues[j].put(task) # 给第j个partition放入一个新的task。因为 i, j 已经在clock算法中设定了,所以前向传播就是按照这个来走的。

for i, j in schedule:
ok, payload = out_queues[j].get() # 取出第j个partition的运行结果
# .......

# 省略后续代码


2.6 总结

我们总结梳理一下大致业务逻辑(后文还会细化):

  1. 系统调用 spawn_workers 来生成若干 workers。
  2. spawn_workers 为每个 device 生成了一个 Thread,这个 Thread 的执行函数是 worker。spawn_workers 内部也会针对每一个device生成一个 in_queue, out_queue。所以可保证每个device之上是串行来执行业务操作。
  3. 这些 queues 被添加到 (in_queues, out_queues) 之中。然后把 (in_queues, out_queues) 返回给 Pipeline 主线程。之后就是使用 (in_queues, out_queues) 作为各个task 之间传递信息的上下文。
  4. Pipeline 主线程得到 (in_queues, out_queues) 之后,如果要通过 compute 方法运行一个Task,就找到其device对应的in_queue,把Task插进去。
  5. Worker Thread 阻塞在 in_queue 之上,如果发现有内容,就读取 Task,运行Task。
  6. Worker Thread 把运行结果插入到 out_queue之中。
  7. Pipeline 的 compute 方法会取出 out_queue 之中的运行结果,进行后续处理。

如下图所示:

                           +-------------------------------------------------------------------------+
| 1 |
| +--------------------------------------------------------------+ |
| | 3 (in_queues, out_queues) | |
| v | v
+--------------------------------+---------+ +------+----+-----------------------------------------------------------------------+
| Pipeline | | | spawn_workers |
| | | | |
| | | | +-------------------------------------+ |
| | | | | workers | |
| | | | | | t = Thread( |
| + | | | | target=worker, |
| spawn_workers(devices) | | | device 1 : in_queue 1, out_queue 1 | args=(in_queue, out_queue, device), |
| | | | | daemon=True, |
| | | | device 2 : in_queue 2, out_queue 2 | ) |
| +--------------------------------------+ | | | | t.start() |
| | compute | | | | device 3 : in_queue 3, out_queue 3 | + |
| | | | | | | | |
| | | | 4 | | | | |
| | in_queues[j].put(task) +-----------------------+ | +-------------------------------------+ | |
| | | | | +-----------------------------------------------------------------------------------+
| | | | | | 2
| | ok, payload = out_queues[j].get()<--------+ | +---------------------+ |
| | | | | | | in_queues | v
| +--------------------------------------+ | | | | |
| | | +------------> in_queue 1 +--------+ +---------------------------------------------------------------------+
+------------------------------------------+ | | in_queue 2 | | | Thread |
| | in_queue 3 | | | |
| | | | 5 | +------------------------------------------------------------+ |
| 7 +---------------------+ | | | Worker | |
| +---------------------+ | | | | |
| | out_queues | | | | device 1 task = in_queue.get() | |
| | | | task | | | |
+------------------+ out_queue 1 <--+ | +----------------------> in_queue 1 batch = task.compute() | |
(True, (task,,batch)) | out_queue 2 | | | | | |
| out_queue 3 +---------------------------+ out_queue 1 out_queue.put((True, (task, batch))) | |
| | 6 | | | |
+---------------------+ | +------------------------------------------------------------+ |
+---------------------------------------------------------------------+


手机如下:

[源码解析] PyTorch 流水线并行实现 (3)--切分数据和运行时系统_封装

至此,我们分析了如何切分数据和一些运行时机制,下一篇我们结合论文看看具体实现。