Forkjoin框架是在JDK 7里面引入的,适用于将一个大的任务分成N个子任务并行来执行,然后合并每个子任务的结果并返回。来源于Doug Lea大神的forkjoin论文
如下分析基于JDK 8
1.fork/join框架设计
fork/join框架设计如下:
- 一个工作线程池:每个线程是一个标准的heavy thread(ForkJoinWorkerThread的子类),有属于自己的工作队列,执行里面的任务
- 经量级任务ForkJoinTask:所有fork/join的任务都是经量级的也执行类,ForkJoinTask的子类
- 任务队列:一个特殊的入队列和调度原则被工作线程用来管理和执行任务,通过提供的fork/join/isDone方法来触发任务执行,使用coInvoke来获取先fork再join多个任务
- ForkJoinWorkerThreadFactory:一个简单的ForkJoinWorkerThreadFactory来管理线程池,执行给定的任务
只需要关注2个参数:
- 一是线程数量,依赖于CPU的核心数量
- 二是设置合理的任务负载来提升并发性能,算法依赖
fork/join框架任务执行示意图:
Work-Stealing
fork/join框架的核心是轻量级的任务调度。
ForkJoinTask使用了work-stealing调度:
- 每个工作线程在自己的任务队列里维护可执行的任务
- 任务队列是双端队列,同时支持LIFO的push和pop操作,FIFO的poll操作
- 运行指定的线程A的任务生成的子任务,也会加入到线程A中工作队列中
- 工作线程按照LIFO的顺序通过pop方法来获取任务并执行
- 当一个工作线程没有任务可以执行的时候(对应的工作队列为空),会随机的从其它线程stealing任务,使用FIFO规则(意思是执行线程自己的工作列队是LIFO,从头部开始执行;窃取任务是FIFO,从尾部窃取)
- 当一个工作线程遇到join操作的时候,会执行其它任务,如果有,直到目标任务被通知已经完成。然后非阻塞的执行所有任务
- 当一个工作线程没有任务可以执行且窃取任务失败时,回退然后稍后重新尝试窃取。当所有工作线程都处于IDLE状态,则停止窃取并阻塞直到有新的task被加入
work-stealing任务执行示意图:
Deques
基本的数据结构是一个可以扩容的数组,有2个指针
- top 类似于以数组形式表示的栈顶,通过push和pop操作来改变
- base 类似于以数组形式表示的栈底,只能通过poll操作来改变
由于队列可能会被多个线程操作,使用了如下策略来解决同步和内存占用问题
- push和pop操作只能在当前所属的线程里面执行
- 使用Lock来保证每次只有一个线程执行take操作窃取任务
- pop和take操作只有当队列非空的时候才能执行(在代码实现里面,用是CAS来保证队列操作的原子性的,解决只有一个任务,被2个线程同时获取的场景)
top和base被定义为volatile,来保证可见性
在work-stealing框架里面,工作线程对完全不知道程序里面的同步要求,他们只简单的生成、push、pop、poll来管理任务队列状态以及执行任务。当没有足够的任务时,且无法成功窃取任务时,重试窃取任务的操作会导致线程变慢,因为加锁的原因
Java里面对于此种场景的处理工具比较弱,但是在一般使用情况下是可接受的。当一个线程窃取任务失败时,会降低线程的优先级,调用Thread.yield然后再重试,然后向ThreadGroup注册非激活状态,当所有其它线程都 变成非激活状态时,则都阻塞,等待额外新的任务
2.fork/join框架实现说明
2.1 ForkJoinTask
ForkJoinTask类定义如下
public abstract class ForkJoinTask<V> implements Future<V>, Serializable
代表执行在ForkJoinPool里面的一个任务,比正常的线程更经量,大量的任务及其子任务会被ForkJoinPool里面管理的少量线程持有并执行,代价是有很多用法限制
一般使用其子类来代表具体的任务
- CountedCompleter 将结果合并从任务分解计算的过程中分离出来,只要每个子任务能保存状态,那么通过合并逻辑,就可以在任何时候将结果合并
- RecursiveAction 没有返回值的递归任务
- RecursiveTask 有返回值的递归任务
必须实现compute接口,在该方法里面定义任务fork和join的逻辑
2.2 ForkJoinPool
ForkJoinPool 继承自AbstractExecutorService,与其它的ExecutorService主要的不同点是在工作队列里面使用了work-stealing算法,来保证每个线程的负载,充分发挥多核处理器的能力。
对外提供了commonPool()来获取ForkJoinPool对象,对于大部分应用场景,建议使用此方法来获取ForkJoinPool对象。
对于非fork/join任务,和ThreadPoolExecutor的使用方法类似,,支持ForkJoinTask/Runnable/Callable 3种对象,后面的2中都会转换成ForkJoinTask;对于fork/join任务,主要使用fork方法来分解任务,用join方法来合并计算结果
提交任务 | Call from non-fork/join clients | Call from within fork/join computations |
Arrange async execution | execute(ForkJoinTask) | ForkJoinTask.fork() |
Await and obtain result | invoke(ForkJoinTask) | ForkJoinTask.invoke() |
Arrange exec and obtain Future | submit(ForkJoinTask) | ForkJoinTask.fork() (ForkJoinTasks are Futures) |
里面的关键属性
// Instance fields
volatile long ctl; // main pool control
volatile int runState; // lockable status
final int config; // parallelism, mode
int indexSeed; // to generate worker index
volatile WorkQueue[] workQueues; // main registry
final ForkJoinWorkerThreadFactory factory;
final UncaughtExceptionHandler ueh; // per-worker UEH
final String workerNamePrefix; // to create worker name string
volatile AtomicLong stealCounter; // also used as sync monitor
以如下代码为例说明调用流程,完整代码见ForkJoinTest
Integer sum = ForkJoinPool.commonPool().invoke(new FaTask(list));
private static class FaTask extends RecursiveTask<Integer> {
// task 编号,方便打印的时候分析
public int index;
private static final int THRESHOLD = 3;
private List<Integer> mList;
private FaTask(List<Integer> list) {
mList = list;
index = taskIndex.getAndIncrement();
}
@Override
protected Integer compute() {
Integer result;
System.out.println(Thread.currentThread().getName() + " run " + this.toString());
if (mList.size() <= THRESHOLD) {
result = calc();
} else {
int mid = mList.size() / 2;
FaTask leftTask = new FaTask(mList.subList(0, mid));
FaTask rightTask = new FaTask(mList.subList(mid, mList.size()));
System.out.println(Thread.currentThread().getName() + " task = " + index + " split to task " + leftTask.index + " and task " + rightTask.index);
leftTask.fork();
rightTask.fork();
System.out.println(Thread.currentThread().getName() + " task = " + index + " wait for task " + leftTask.index + " finish ,status " + this.isDone());
Integer sumleft = leftTask.join();
System.out.println(Thread.currentThread().getName() + " task = " + index + " wait for task " + rightTask.index + " finish ,status " + this.isDone());
Integer sumright = rightTask.join();
result = sumleft+sumright;
}
System.out.println(Thread.currentThread().getName() + " finish task " + index + " ; result = " + result + ",status " + this.isDone());
return result;
}
运行如果:
main run task 0 :[1,2,3,4,5,6,7,8,9,]
main task = 0 split to task 1 and task 2
main task = 0 wait for task 1 finish ,status false
ForkJoinPool.commonPool-worker-9 run task 1 :[1,2,3,4,]
ForkJoinPool.commonPool-worker-2 run task 2 :[5,6,7,8,9,]
ForkJoinPool.commonPool-worker-9 task = 1 split to task 3 and task 4
ForkJoinPool.commonPool-worker-2 task = 2 split to task 5 and task 6
ForkJoinPool.commonPool-worker-9 task = 1 wait for task 3 finish ,status false
ForkJoinPool.commonPool-worker-11 run task 3 :[1,2,]
ForkJoinPool.commonPool-worker-4 run task 5 :[5,6,]
ForkJoinPool.commonPool-worker-13 run task 6 :[7,8,9,]
ForkJoinPool.commonPool-worker-2 task = 2 wait for task 5 finish ,status false
ForkJoinPool.commonPool-worker-6 run task 4 :[3,4,]
ForkJoinPool.commonPool-worker-6 finish task 4 ; result = 7,status false
ForkJoinPool.commonPool-worker-11 finish task 3 ; result = 3,status false
ForkJoinPool.commonPool-worker-4 finish task 5 ; result = 11,status false
ForkJoinPool.commonPool-worker-9 task = 1 wait for task 4 finish ,status false
ForkJoinPool.commonPool-worker-13 finish task 6 ; result = 24,status false
ForkJoinPool.commonPool-worker-9 finish task 1 ; result = 10,status false
ForkJoinPool.commonPool-worker-2 task = 2 wait for task 6 finish ,status false
main task = 0 wait for task 2 finish ,status false
ForkJoinPool.commonPool-worker-2 finish task 2 ; result = 35,status false
main finish task 0 ; result = 45,status false
Sum is = 45
2.3 WorkerQueue
WorkerQueue
// Instance fields
volatile int scanState; // versioned, <0: inactive; odd:scanning
int stackPred; // pool stack (ctl) predecessor
int nsteals; // number of steals
int hint; // randomization and stealer index hint
int config; // pool index and mode
volatile int qlock; // 1: locked, < 0: terminate; else 0
volatile int base; // index of next slot for poll
int top; // index of next slot for push
ForkJoinTask<?>[] array; // the elements (initially unallocated)
final ForkJoinPool pool; // the containing pool (may be null)
final ForkJoinWorkerThread owner; // owning thread or null if shared
volatile Thread parker; // == owner during call to park; else null
volatile ForkJoinTask<?> currentJoin; // task being joined in awaitJoin
volatile ForkJoinTask<?> currentSteal; // mainly used by helpStealer
通过ForkJoinPool.commonPool()方法来获取ForkJoinPool对象,默认并发线程是
Runtime.getRuntime().availableProcessors() - 1
主任务的调用链如下:
invoke -->
1–> externalPush --> externalSubmit --> new WorkerQueue and signalWork --> tryAddWorker --> createWorker --> runWorker 最后externalPush方法返回
2–> task.join --> doJoin --> doExec --> compute 进入我们定义的compute 方法 --> externalAwaitDone 等待其它任务完成后返回
提交任务是调用的invoke方法
public <T> T invoke(ForkJoinTask<T> task) {
if (task == null)
throw new NullPointerException();
externalPush(task);
return task.join();
}
核心是externalPush方法,把任务加入工作队列里面,然后等待返回结果
final void externalPush(ForkJoinTask<?> task) {
WorkQueue[] ws; WorkQueue q; int m;
int r = ThreadLocalRandom.getProbe();
int rs = runState;
// workQueues非空且能获取锁
if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
(q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 &&
U.compareAndSwapInt(q, QLOCK, 0, 1)) {
ForkJoinTask<?>[] a; int am, n, s;
if ((a = q.array) != null &&
(am = a.length - 1) > (n = (s = q.top) - q.base)) {
int j = ((am & s) << ASHIFT) + ABASE;
U.putOrderedObject(a, j, task);
U.putOrderedInt(q, QTOP, s + 1);
U.putIntVolatile(q, QLOCK, 0);
if (n <= 1)
signalWork(ws, q);
return;
}
U.compareAndSwapInt(q, QLOCK, 1, 0);
}
externalSubmit(task);
}
使用invoke方法的时候,workQueues为空,走到externalSubmit方法里面,在里面会创建一个WorkQueue,然后尝试提交任务,提交成功后,调用signalWork
signalWork方法,如果工作线程很少的时候,尝试创建或激活一个工作线程。tryAddWorker --> createWorker ,在createWorker方法里面会创建一个ForkJoinWorkerThread工作线程,然后启动线程,最后externalPush方法返回。此时还没有开始执行任务,只是创建了一个线程并把任务添加到工作队列了
继续看invoke方法里面的task.join
public final V join() {
int s;
if ((s = doJoin() & DONE_MASK) != NORMAL)
reportException(s);
return getRawResult();
}
继续调用doJoin方法,然后得到结果
private int doJoin() {
int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
return (s = status) < 0 ? s :
((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
(w = (wt = (ForkJoinWorkerThread)t).workQueue).
tryUnpush(this) && (s = doExec()) < 0 ? s :
wt.pool.awaitJoin(w, this, 0L) :
externalAwaitDone();
}
// 上面的方法嵌套了3层三元运算符,看着比较晕,把方法稍微改造一下
private int doJoin() {
int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
s = status;
if (s < 0) {
return s;
}
t = Thread.currentThread();
int result;
if(t instanceof ForkJoinWorkerThread ) {
wt = (ForkJoinWorkerThread)t).workQueue;
// 尝试出队列并执行任务,doExec --> exec --> compute
if (wt.tryUnpush(this) && (s = doExec()) < 0 ) {
result = s;
} else {
result = wt.pool.awaitJoin(w, this, 0L) ;
}
} else {
result = externalAwaitDone();
}
return result;
}
此时线程的status为0(线程创建的初始化状态为0),线程只是创建了,状态还没有变化,Thread.currentThread()是一个ForkJoinWorkerThread对象,走tryUnpush和doExec,走定制的computer方法,然后会有如下打印
main task = 0 split to task 1 and task 2
再后面就是创建新的ForkJoinTask对象,调用对应的fork方法,由于当前线程是main,不是ForkJoinWorkerThread,所以走ForkJoinPool.common.externalPush,
public final ForkJoinTask<V> fork() {
Thread t;
if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
((ForkJoinWorkerThread)t).workQueue.push(this);
else
ForkJoinPool.common.externalPush(this);
return this;
}
流程就和上面的类似,创建一个的线程,重新走task0的流程
ForkJoinPool.commonPool-worker-2 run task 1 :[1,2,3,4,]
ForkJoinPool.commonPool-worker-2 task = 1 split to task 3 and task 4
如下打印之后,调用task.join方法
ForkJoinPool.commonPool-worker-2 task = 1 wait for task 3 finish ,status false
等待task3执行完,
ForkJoinPool.commonPool-worker-11 run task 3 :[1,2,]
task3执行完后,对应的join方法返回,task4的执行流程类似,只是在不同的线程里面
ForkJoinPool.commonPool-worker-11 finish task 3 ; result = 3,status false
task3和task4 完成后,task1里面的2个子任务都从Join方法返回,task1完成,并返回执行结果
ForkJoinPool.commonPool-worker-6 finish task 4 ; result = 7,status false ForkJoinPool.commonPool-worker-11 finish task 3 ; result = 3,status false
ForkJoinPool.commonPool-worker-9 finish task 1 ; result = 10,status false
同理,最开始的任务task0分成task1和task2,等2个任务都完成后,task0返回最终的结果
ForkJoinPool.commonPool-worker-2 finish task 2 ; result = 35,status false
ForkJoinPool.commonPool-worker-9 finish task 1 ; result = 10,status false
main finish task 0 ; result = 45,status false
Sum is = 45
【参考】