python 因 GIL 的存在,处理计算密集型的任务时无法高效利用多核 CPU 的计算资源,这时就需要使用多进程来提高对 CPU 的资源利用。Python 多进程主要用 multiprocessing 模块实现,提供了进程、进程池、队列、管理者、共享数据、同步原语功能。

单进程版

为了便于演示 multiprocessing 的使用,我们使用素数检查模拟计算密集型任务。单进程版本的代码如下:

# encoding:utf8
from math import sqrt
CHECK_NUMBERS = 1000000
def check_prime(n):
if n == 2:
return True
for i in range(2, int(sqrt(n) + 1)):
if n % i == 0:
return False
return True
def run():
total = 0
for i in range(2, CHECK_NUMBERS + 1):
if check_prime(i):
total += 1
return total
if __name__ == "__main__":
import timeit
print(run())
print(timeit.timeit("run()", "from __main__ import run", number=1))

以上单进程的示例在我的计算机上输出结果为:

78498
4.788863064308802

即 1000000 以内共有 78498 个素数,耗时约 4.8 秒。

进程与队列

生成一个新的进程并启动如下:

processes = multiprocessing.Process(target=worker, args=param)
processes.start() # 启动
processes.join() # 主进程等待

子进程中要把结果返回可以通过 Queue 队列返回,Queue 的操作主要是 get 和 put,支持阻塞(可以设置超时),非阻塞,也能监测队列状态是否为空。

以下代码使用了多个进程,将 2 到 1000000 的大区间分为 4 个小区间,再分配给 4 个进程去分别计算小区间内共有多少素数,通过队列返回。主进程最后把每个进程统计的素数个数相加即是最终结果。

# encoding:utf8
import multiprocessing
from math import sqrt
from multi.single_thread_check_prime import check_prime
CHECK_NUMBERS = 1000000
NUM_PROCESSES = multiprocessing.cpu_count()
def worker(start, end, result_mq):
'''count prime numbers between start and end(exclusive)'''
total = 0
for n in range(start, end):
if check_prime(n):
total += 1
result_mq.put(total)
def divide_range(lower_end, upper_end, num_range):
'''divide a larger range into smaller ranges'''
step = int((upper_end - lower_end) / num_range)
ranges = []
subrange_upper = lower_end
while subrange_upper <= upper_end:
subrange_lowerend = subrange_upper
subrange_upper += step
if subrange_upper <= upper_end:
ranges.append((subrange_lowerend, subrange_upper))
continue
if subrange_lowerend < upper_end:
ranges.append((subrange_lowerend, upper_end))
return ranges
def run():
params = divide_range(2, CHECK_NUMBERS + 1, 4) # [(2, 250001), (250001, 500000), (500000, 749999), (749999, 999998), (999998, 1000001)]
result_mq = multiprocessing.Queue()
processes = []
for i in range(NUM_PROCESSES):
process = multiprocessing.Process(target=worker, args=list(params[i]) + [result_mq] )
processes.append(process)
process.start()
for process in processes:
process.join()
total = 0
for i in range(NUM_PROCESSES):
count = result_mq.get()
total += count
print(total)
return total
if __name__ == "__main__":
import timeit
print(timeit.timeit("run()", "from __main__ import run", number=1))

使用多进程后的输出结果为:

78498
1.6613719538839973

最终结果一致,约快了 2.9 倍。我的电脑 CPU 有 4 核,由于创建进程、进程间通信也都需要消耗资源,所以没法达到理想的 4 倍,但也已经是不错的提升了。

进程池

前面一个例子通过实例化 Process 的方法生成新的进程,再对每个进程调用 start 和 join,但其实通过进程池可以使代码更简洁。上面多进程的例子中,使用进程池后可以去除 Queue 队列的使用。首先在 worker 中去除消息队列 result_mq,直接返回结果如下:

def worker(sub_range):
'''count prime numbers between start and end(exclusive)'''
start, end = sub_range
total = 0
for n in range(start, end):
if check_prime(n):
total += 1
return total

这时在 run 函数里就可以用进程池的 map 方法,修改后的 run:

def run():
params = divide_range(2, CHECK_NUMBERS + 1, 4) # [(2, 250001), (250001, 500000), (500000, 749999), (749999, 999998), (999998, 1000001)]
pool = multiprocessing.Pool(processes=NUM_PROCESSES)
result = pool.map(worker, params)
total = sum(result)
print(total)
return total

由于不用在代码里逐个生成子进程,同时 map 方法可以直接返回结果,run 函数从原来的 19 行缩短为了 8 行。进程池不仅有 map 方法,还有 map_async, apply, apply_async 等方法,带 async 后缀的方法能够实现非阻塞调用,主进程不必等到子进程运行完毕才往下运行。比如上面的 run 函数可以修改成如下:

def run():
params = divide_range(2, CHECK_NUMBERS + 1, 4) # [(2, 250001), (250001, 500000), (500000, 749999), (749999, 999998), (999998, 1000001)]
pool = multiprocessing.Pool(processes=NUM_PROCESSES)
result = pool.map_async(worker, params)
pool.close()
# do something else here
...
pool.join()
total = sum(result.get())
print(total)
return total

管理者

管理者 Manager 可以存储需要在进程间共享的对象,其支持的类型包括 list, dict, Namespace, Lock, RLock, Semaphore, BoundedSemaphore, Condition, Event, Queue, Value, Array。相比于共用内存,Manager 可以让不同机器上的进程通过网络共享对象。Manager 的 register 方法还可以自定义新的类型或者可调用对象,具体使用见文档:Manager

sharedctypes

multiprocessing.sharedctypes 可以用来在共享内存中创建 c 类型数据,可以作为参数在创建子进程时传入。其中主要用到的有 Value, RawValue, Array, RawArray 。Value, Array 通过参数可以设置是否需要锁实现进程安全,如果对进程间同步没有要求使用 RawValue 和 RawArray 则有更高的运行效率。实例化使用如下:

n = Value('i', 7)
x = RawArray(‘h’, 7)
s = RawArray(‘i’, (9, 2, 8))

第一个参数指定类型,类型编码见array module,或者使用 ctypes 模块,如 ctypes.c_double,第二个参数为值,对于 Array 和 RawArray,当第二个参数为整型时则为数组的长度,数组元素初始化为0。对于 linux 操作系统,只要把 n, x, s 设置成全局变量,即可在子进程中使用,无需显式传参,但 window 操作系统则需要在实例化 Process 时传入参数。不可使用 Pool.map 把 n, x, s 作为参数传递,因为 map 使用的序列化而 n,x,s 不可序列化。具体使用见下一小节。

共享 numpy 数组

需要用到 numpy 时往往是数据量较大的场景,如果直接复制会造成大量内存浪费。共享 numpy 数组则是通过上面一节的 Array 实现,再用 numpy.frombuffer 以及 reshape 对共享的内存封装成 numpy 数组,代码如下:

# encoding:utf8
import ctypes
import os
import multiprocessing
import numpy as np
NUM_PROCESS = multiprocessing.cpu_count()
def worker(index):
main_nparray = np.frombuffer(shared_array_base, dtype=ctypes.c_double)
main_nparray = main_nparray.reshape(NUM_PROCESS, 10)
pid = os.getpid()
main_nparray[index, :] = pid
return pid
if __name__ == "__main__":
shared_array_base = multiprocessing.Array(
ctypes.c_double, NUM_PROCESS * 10, lock=False)
pool = multiprocessing.Pool(processes=NUM_PROCESS)
result = pool.map(worker, range(NUM_PROCESS))
main_nparray = np.frombuffer(shared_array_base, dtype=ctypes.c_double)
main_nparray = main_nparray.reshape(NUM_PROCESS, 10)
print main_nparray
mmap

mmap 把文件映射到内存,也可以用于进程间通信,可以像字符串或文件一样对其进行操作,操作较为简单。

import mmap
# write a simple example file
with open("hello.txt", "wb") as f:
f.write("Hello Python!\n")
with open("hello.txt", "r+b") as f:
# memory-map the file, size 0 means whole file
mm = mmap.mmap(f.fileno(), 0)
# read content via standard file methods
print mm.readline() # prints "Hello Python!"
# read content via slice notation
print mm[:5] # prints "Hello"
# update content using slice notation;
# note that new content must have same size
mm[6:] = " world!\n"
# ... and read again using standard file methods
mm.seek(0)
print mm.readline() # prints "Hello world!"
# close the map
mm.close()

以上例子为映射文件,对于进程间共享内存在 mmap 时第一个参数设置为 -1 实现匿名映射,只创建共享内存不映射到磁盘,见官网mmap。