本文主要学习了多线程的分支合并框架。

部分内容来自以下博客:

https://segmentfault.com/a/1190000016781127

https://segmentfault.com/a/1190000016877931

1 简介

JDK1.7版本引入了一套Fork/Join框架。Fork/Join框架的基本思想就是将一个大任务分解(Fork)成一系列子任务,子任务可以继续往下分解,当多个不同的子任务都执行完成后,可以将它们各自的结果合并(Join)成一个大结果,最终合并成大任务的结果。

Fork/Join 框架要完成两件事情:

1)Fork:把一个复杂任务进行分拆

2)Join:把分拆任务的结果进行合并

Fork/Join框架的实现非常复杂,内部大量运用了位操作和无锁算法。

Fork/Join框架内部还涉及到三大核心组件:ForkJoinPool(线程池)、ForkJoinTask(任务)、ForkJoinWorkerThread(工作线程),外加WorkQueue(任务队列)。

2 类和接口

2.1 ForkJoinPool

ForkJoinPool是分支合并池,类似于线程池ThreadPoolExecutor,同样是ExecutorService接口的一个实现类。

ForkJoinPool类的实现:

1 public class ForkJoinPool extends AbstractExecutorService {

在ForkJoinPool类中提供了三个构造方法:

1 public ForkJoinPool();
2 public ForkJoinPool(int parallelism);
3 public ForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, UncaughtExceptionHandler handler, boolean asyncMode);

最终调用的是下面这个私有构造器:

1 private ForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, UncaughtExceptionHandler handler, int mode, String workerNamePrefix);

其参数含义如下:

parallelism:并行级别,默认值为CPU核心数,ForkJoinPool里工作线程数量与该参数有关,但它不表示最大线程数。

factory:工作线程工厂,默认是DefaultForkJoinWorkerThreadFactory,其实就是用来创建ForkJoinWorkerThread工作线程对象。

handler:异常处理器。

mode:调度模式,true表示FIFO_QUEUE,false表示LIFO_QUEUE。

workerNamePrefix:工作线程的名称前缀。

2.2 ForkJoinTask

ForkJoinTask是Future接口的抽象实现类,提供了用于分解任务的fork()方法和用于合并任务的join()方法。

在ThreadPoolExecutor类中,使用线程池执行任务调用的execute()方法中要求传入Runnable接口的实例。但是在ForkJoinPool类中,除了可以传入Runnable接口的实例外,还可以传入ForkJoinTask抽象类的实例,并且传入Runnable接口的实例也会被适配为ForkJoinTask抽象类的实例。

2.3 RecursiveTask

通常情况下使用ForkJoinTask抽象类的实例,并不需要直接继承ForkJoinTask类,只需要继承其子类:

1)RecursiveAction:用于没有返回结果的任务

2)RecursiveTask:用于有返回结果的任务

其中,最常用的还是RecursiveTask类。

2.4 ForkJoinWorkerThread

ForkJoinWorkerThread类是Thread的子类,作为线程池中的工作线程执行任务,其内部维护了一个WorkerQueue类型的双向任务队列。

工作线程在执行任务时,优先处理自身任务队列中的任务(FIFO或者LIFO),当自身队列中的任务为空时,会窃取其他任务队列中的任务(FIFO)。

2.5 WorkerQueue

WorkerQueue类是ForkJoinPool类的一个内部类,代表存储ForkJoinTask实例的双端队列。

在ForkJoinPool类的私有构造方法中,有一个int类型的mode参数,其取值如下:

1 static final int LIFO_QUEUE = 0;
2 static final int FIFO_QUEUE = 1 << 16;

当入参为LIFO_QUEUE时,表示同步,对于工作线程(Worker)自身队列中的任务,采用后进先出(LIFO)的方式执行。

当入参为FIFO_QUEUE时,表示异步,对于工作线程(Worker)自身队列中的任务,采用先进先出(FIFO)的方式执行。

3 实现原理

3.1 提交任务

使用ForkJoinPool的submit方法提交任务得到ForkJoinTask对象:

1 public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
2     if (task == null)
3         throw new NullPointerException();
4     externalPush(task);
5     return task;
6 }

继续查看externalPush方法:

 1 final void externalPush(ForkJoinTask<?> task) {
 2     WorkQueue[] ws; WorkQueue q; int m;
 3     int r = ThreadLocalRandom.getProbe();
 4     int rs = runState;
 5     if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
 6         (q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 &&
 7         U.compareAndSwapInt(q, QLOCK, 0, 1)) {
 8         ForkJoinTask<?>[] a; int am, n, s;
 9         if ((a = q.array) != null &&
10             (am = a.length - 1) > (n = (s = q.top) - q.base)) {
11             int j = ((am & s) << ASHIFT) + ABASE;
12             U.putOrderedObject(a, j, task);
13             U.putOrderedInt(q, QTOP, s + 1);
14             U.putIntVolatile(q, QLOCK, 0);
15             if (n <= 1)
16                 signalWork(ws, q);
17             return;
18         }
19         U.compareAndSwapInt(q, QLOCK, 1, 0);
20     }
21     externalSubmit(task);
22 }

该方法包含两个部分:

1)尝试将任务添加到任务队列,添加后则创建或激活一个工作线程,在此过程中使用了CAS保证线程安全。

2)添加队列失败,则调用externalSubmit方法初始化队列,并将任务加入到队列。

3.2 分解任务

3.2.1 创建或唤醒工作线程

调用ForkJoinTask的fork方法完成任务分解:

1 public final ForkJoinTask<V> fork() {
2     Thread t;
3     if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)// 调用线程为工作线程
4         ((ForkJoinWorkerThread)t).workQueue.push(this);// 将任务添加到自身队列
5     else
6         ForkJoinPool.common.externalPush(this);// 调用ForkJoinPool的externalPush方法
7     return this;
8 }

该方法包含两个部分:

1)调用线程为工作线程,将任务添加到自身队列。

2)调用线程为其他外部线程,继续调用ForkJoinPool的externalPush方法,尝试将任务添加到任务队列并激活工作线程。

继续查看push方法,添加任务到自身队列:

 1 final void push(ForkJoinTask<?> task) {
 2     ForkJoinTask<?>[] a; ForkJoinPool p;
 3     int b = base, s = top, n;
 4     if ((a = array) != null) {    // ignore if queue removed
 5         int m = a.length - 1;     // fenced write for task visibility
 6         U.putOrderedObject(a, ((m & s) << ASHIFT) + ABASE, task);
 7         U.putOrderedInt(this, QTOP, s + 1);
 8         if ((n = s - b) <= 1) {
 9             if ((p = pool) != null)
10                 p.signalWork(p.workQueues, this);// 唤醒或创建工作线程
11         }
12         else if (n >= m)
13             growArray();// 扩容
14     }
15 }

该方法包含两个部分:

1)判断是否需要扩容,不需要扩容则唤醒或创建工作线程。

2)需要扩容,则进行扩容操作。

继续查看signalWork方法,创建或唤醒工作线程:

 1 final void signalWork(WorkQueue[] ws, WorkQueue q) {
 2     long c; int sp, i; WorkQueue v; Thread p;
 3     while ((c = ctl) < 0L) {                       // too few active
 4         if ((sp = (int)c) == 0) {                  // 没有空闲工作进程
 5             if ((c & ADD_WORKER) != 0L)            // 工作进程太少
 6                 tryAddWorker(c);// 增加工作进程
 7             break;
 8         }
 9         // 有工作进程,唤醒
10         if (ws == null)                            // unstarted/terminated
11             break;
12         if (ws.length <= (i = sp & SMASK))         // terminated
13             break;
14         if ((v = ws[i]) == null)                   // terminating
15             break;
16         int vs = (sp + SS_SEQ) & ~INACTIVE;        // next scanState
17         int d = sp - v.scanState;                  // screen CAS
18         long nc = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & v.stackPred);
19         if (d == 0 && U.compareAndSwapLong(this, CTL, c, nc)) {
20             v.scanState = vs;                      // activate v
21             if ((p = v.parker) != null)
22                 U.unpark(p);
23             break;
24         }
25         if (q != null && q.base == q.top)          // no more work
26             break;
27     }
28 }

继续查看tryAddWorker方法:

 1 private void tryAddWorker(long c) {
 2     boolean add = false;
 3     do {
 4         // 设置活跃工作线程数和总工作线程数
 5         long nc = ((AC_MASK & (c + AC_UNIT)) |
 6                    (TC_MASK & (c + TC_UNIT)));
 7         if (ctl == c) {
 8             int rs, stop;                 // check if terminating
 9             if ((stop = (rs = lockRunState()) & STOP) == 0)
10                 add = U.compareAndSwapLong(this, CTL, c, nc);
11             unlockRunState(rs, rs & ~RSLOCK);
12             if (stop != 0)
13                 break;
14             if (add) {
15                 // 创建工作线程
16                 createWorker();
17                 break;
18             }
19         }
20     } while (((c = ctl) & ADD_WORKER) != 0L && (int)c == 0);
21 }

继续查看createWorker方法:

 1 private boolean createWorker() {
 2     ForkJoinWorkerThreadFactory fac = factory;
 3     Throwable ex = null;
 4     ForkJoinWorkerThread wt = null;
 5     try {
 6         // 使用线程池工厂创建线程
 7         if (fac != null && (wt = fac.newThread(this)) != null) {
 8             // 启动线程
 9             wt.start();
10             return true;
11         }
12     } catch (Throwable rex) {
13         ex = rex;
14     }
15     // 出现异常,注销该工作线程
16     deregisterWorker(wt, ex);
17     return false;
18 }

3.2.2 启动任务

ForkJoinWorkerThread在执行start方法后,会执行run方法:

 1 public void run() {
 2     if (workQueue.array == null) { // only run once
 3         Throwable exception = null;
 4         try {
 5             onStart();
 6             pool.runWorker(workQueue);
 7         } catch (Throwable ex) {
 8             exception = ex;
 9         } finally {
10             try {
11                 onTermination(exception);
12             } catch (Throwable ex) {
13                 if (exception == null)
14                     exception = ex;
15             } finally {
16                 pool.deregisterWorker(this, exception);
17             }
18         }
19     }
20 }

在run方法内部调用了ForkJoinPool对象的runWorker方法:

 1 final void runWorker(WorkQueue w) {
 2     w.growArray();                   // 初始化任务队列
 3     int seed = w.hint;               // initially holds randomization hint
 4     int r = (seed == 0) ? 1 : seed;  // avoid 0 for xorShift
 5     for (ForkJoinTask<?> t;;) {
 6         if ((t = scan(w, r)) != null)// 尝试获取任务
 7             w.runTask(t);// 执行任务
 8         else if (!awaitWork(w, r))// 获取失败,加入等待任务队列
 9             break;// 等待失败,跳出方法并注销工作线程
10         r ^= r << 13; r ^= r >>> 17; r ^= r << 5; // xorshift
11     }
12 }

3.2.3 窃取任务

使用scan方法窃取任务:

 1 private ForkJoinTask<?> scan(WorkQueue w, int r) {
 2     WorkQueue[] ws; int m;
 3     if ((ws = workQueues) != null && (m = ws.length - 1) > 0 && w != null) {
 4         int ss = w.scanState;                     // initially non-negative
 5         for (int origin = r & m, k = origin, oldSum = 0, checkSum = 0;;) {
 6             WorkQueue q; ForkJoinTask<?>[] a; ForkJoinTask<?> t;
 7             int b, n; long c;
 8             if ((q = ws[k]) != null) {// 定位任务队列
 9                 if ((n = (b = q.base) - q.top) < 0 &&
10                     (a = q.array) != null) {      // non-empty
11                     long i = (((a.length - 1) & b) << ASHIFT) + ABASE;
12                     if ((t = ((ForkJoinTask<?>)
13                               U.getObjectVolatile(a, i))) != null &&
14                         q.base == b) {
15                         if (ss >= 0) {
16                             if (U.compareAndSwapObject(a, i, t, null)) {
17                                 q.base = b + 1;
18                                 if (n < -1)       // signal others
19                                     signalWork(ws, q);// 创建获唤醒工作线程执行任务
20                                 return t;
21                             }
22                         }
23                         else if (oldSum == 0 &&   // try to activate
24                                  w.scanState < 0)
25                             tryRelease(c = ctl, ws[m & (int)c], AC_UNIT);// 唤醒栈顶工作线程
26                     }
27                     if (ss < 0)                   // refresh
28                         ss = w.scanState;
29                     r ^= r << 1; r ^= r >>> 3; r ^= r << 10;
30                     origin = k = r & m;           // move and rescan
31                     oldSum = checkSum = 0;
32                     continue;
33                 }
34                 checkSum += b;
35             }
36             // 已扫描全部工作线程,但并未找到任务
37             if ((k = (k + 1) & m) == origin) {    // continue until stable
38                 if ((ss >= 0 || (ss == (ss = w.scanState))) &&
39                     oldSum == (oldSum = checkSum)) {
40                     if (ss < 0 || w.qlock < 0)    // already inactive
41                         break;
42                     int ns = ss | INACTIVE;       // 尝试对当前工作线程灭活
43                     long nc = ((SP_MASK & ns) |
44                                (UC_MASK & ((c = ctl) - AC_UNIT)));
45                     w.stackPred = (int)c;         // hold prev stack top
46                     U.putInt(w, QSCANSTATE, ns);
47                     if (U.compareAndSwapLong(this, CTL, c, nc))
48                         ss = ns;
49                     else
50                         w.scanState = ss;         // back out
51                 }
52                 checkSum = 0;
53             }
54         }
55     }
56     return null;
57 }

3.2.4 执行任务

窃取到任务后,调用runTask方法执行任务:

 1 final void runTask(ForkJoinTask<?> task) {
 2     if (task != null) {
 3         scanState &= ~SCANNING; // mark as busy
 4         (currentSteal = task).doExec();// 执行任务
 5         U.putOrderedObject(this, QCURRENTSTEAL, null); // release for GC
 6         execLocalTasks();// 执行本地任务
 7         ForkJoinWorkerThread thread = owner;
 8         if (++nsteals < 0)      // collect on overflow
 9             transferStealCount(pool);// 增加窃取任务数
10         scanState |= SCANNING;
11         if (thread != null)
12             thread.afterTopLevelExec();// 执行钩子函数
13     }
14 }

3.2.5 阻塞等待

如何未窃取到任务,会调用awaitWork方法等待获取任务:

 1 private boolean awaitWork(WorkQueue w, int r) {
 2     if (w == null || w.qlock < 0)                 // w is terminating
 3         return false;
 4     for (int pred = w.stackPred, spins = SPINS, ss;;) {
 5         if ((ss = w.scanState) >= 0)
 6             break;
 7         else if (spins > 0) {
 8             r ^= r << 6; r ^= r >>> 21; r ^= r << 7;
 9             if (r >= 0 && --spins == 0) {         // randomize spins
10                 WorkQueue v; WorkQueue[] ws; int s, j; AtomicLong sc;
11                 if (pred != 0 && (ws = workQueues) != null &&
12                     (j = pred & SMASK) < ws.length &&
13                     (v = ws[j]) != null &&        // see if pred parking
14                     (v.parker == null || v.scanState >= 0))
15                     spins = SPINS;                // continue spinning
16             }
17         }
18         else if (w.qlock < 0)                     // recheck after spins
19             return false;
20         else if (!Thread.interrupted()) {
21             long c, prevctl, parkTime, deadline;
22             int ac = (int)((c = ctl) >> AC_SHIFT) + (config & SMASK);
23             if ((ac <= 0 && tryTerminate(false, false)) ||
24                 (runState & STOP) != 0)           // pool terminating
25                 return false;
26             if (ac <= 0 && ss == (int)c) {        // is last waiter
27                 prevctl = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & pred);
28                 int t = (short)(c >>> TC_SHIFT);  // shrink excess spares
29                 if (t > 2 && U.compareAndSwapLong(this, CTL, c, prevctl))
30                     return false;                 // else use timed wait
31                 parkTime = IDLE_TIMEOUT * ((t >= 0) ? 1 : 1 - t);
32                 deadline = System.nanoTime() + parkTime - TIMEOUT_SLOP;
33             }
34             else
35                 prevctl = parkTime = deadline = 0L;
36             Thread wt = Thread.currentThread();
37             U.putObject(wt, PARKBLOCKER, this);   // emulate LockSupport
38             w.parker = wt;
39             if (w.scanState < 0 && ctl == c)      // recheck before park
40                 U.park(false, parkTime);
41             U.putOrderedObject(w, QPARKER, null);
42             U.putObject(wt, PARKBLOCKER, null);
43             if (w.scanState >= 0)
44                 break;
45             if (parkTime != 0L && ctl == c &&
46                 deadline - System.nanoTime() <= 0L &&
47                 U.compareAndSwapLong(this, CTL, c, prevctl))
48                 return false;                     // shrink pool
49         }
50     }
51     return true;
52 }

3.3 合并任务

使用ForkJoinTask的join方法可以获取任务的执行结果:

1 public final V join() {
2     int s;
3     if ((s = doJoin() & DONE_MASK) != NORMAL)
4         reportException(s);
5     return getRawResult();
6 }

查看doJoin方法:

1 private int doJoin() {
2     int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
3     return (s = status) < 0 ? s :
4         ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
5         (w = (wt = (ForkJoinWorkerThread)t).workQueue).
6         tryUnpush(this) && (s = doExec()) < 0 ? s :
7         wt.pool.awaitJoin(w, this, 0L) :
8         externalAwaitDone();
9 }
4 使用

4.1 计算多个整数的和

任务类定义,因为需要返回结果,所以继承RecursiveTask,并覆写compute方法。

任务的拆分通过ForkJoinTask的fork方法执行,join方法用于等待任务执行后返回。

 1 class SumTask extends RecursiveTask<Integer> {
 2     private static final int THRESHOLD = 10;// 拆分阈值
 3     private int begin;// 拆分开始值
 4     private int end;// 拆分结束值
 5     public SumTask(int begin, int end) {
 6         this.begin = begin;
 7         this.end = end;
 8     }
 9     @Override
10     protected Integer compute() {
11         Integer value = 0;
12         if (end - begin <= THRESHOLD) {// 小于阈值,直接计算
13             for (int i = begin; i <= end; i++) {
14                 value += i;
15             }
16         } else {// 大于阈值,递归计算
17             int middle = (begin + end) / 2;
18             SumTask beginTask = new SumTask(begin, middle);
19             SumTask endTask = new SumTask(middle + 1, end);
20             beginTask.fork();
21             endTask.fork();
22             value = beginTask.join() + endTask.join();
23         }
24         return value;
25     }
26 }
27 public class DemoTest {
28     public static void main(String[] args) {
29         SumTask sumTask = new SumTask(1, 100);
30         ForkJoinPool pool = new ForkJoinPool();
31         try {
32             ForkJoinTask<Integer> task = pool.submit(sumTask);
33             System.out.println(task.get());
34         } catch (Exception e) {
35             e.printStackTrace();
36         } finally {
37             pool.shutdown();
38         }
39     }
40 }

最终结果是5050。