Forkjoin框架是在JDK 7里面引入的,适用于将一个大的任务分成N个子任务并行来执行,然后合并每个子任务的结果并返回。来源于Doug Lea大神forkjoin论文

如下分析基于JDK 8

1.fork/join框架设计

fork/join框架设计如下:

  • 一个工作线程池:每个线程是一个标准的heavy thread(ForkJoinWorkerThread的子类),有属于自己的工作队列,执行里面的任务
  • 经量级任务ForkJoinTask:所有fork/join的任务都是经量级的也执行类,ForkJoinTask的子类
  • 任务队列:一个特殊的入队列和调度原则被工作线程用来管理和执行任务,通过提供的fork/join/isDone方法来触发任务执行,使用coInvoke来获取先fork再join多个任务
  • ForkJoinWorkerThreadFactory:一个简单的ForkJoinWorkerThreadFactory来管理线程池,执行给定的任务
    只需要关注2个参数:
  • 一是线程数量,依赖于CPU的核心数量
  • 二是设置合理的任务负载来提升并发性能,算法依赖

fork/join框架任务执行示意图:

java 多并发框架 java常用并发框架_forkjointask

Work-Stealing

fork/join框架的核心是轻量级的任务调度。

ForkJoinTask使用了work-stealing调度:

  • 每个工作线程在自己的任务队列里维护可执行的任务
  • 任务队列是双端队列,同时支持LIFO的push和pop操作,FIFO的poll操作
  • 运行指定的线程A的任务生成的子任务,也会加入到线程A中工作队列中
  • 工作线程按照LIFO的顺序通过pop方法来获取任务并执行
  • 当一个工作线程没有任务可以执行的时候(对应的工作队列为空),会随机的从其它线程stealing任务,使用FIFO规则(意思是执行线程自己的工作列队是LIFO,从头部开始执行;窃取任务是FIFO,从尾部窃取)
  • 当一个工作线程遇到join操作的时候,会执行其它任务,如果有,直到目标任务被通知已经完成。然后非阻塞的执行所有任务
  • 当一个工作线程没有任务可以执行且窃取任务失败时,回退然后稍后重新尝试窃取。当所有工作线程都处于IDLE状态,则停止窃取并阻塞直到有新的task被加入

work-stealing任务执行示意图:

java 多并发框架 java常用并发框架_forkjointask_02

Deques

基本的数据结构是一个可以扩容的数组,有2个指针

  • top 类似于以数组形式表示的栈顶,通过push和pop操作来改变
  • base 类似于以数组形式表示的栈底,只能通过poll操作来改变

由于队列可能会被多个线程操作,使用了如下策略来解决同步和内存占用问题

  • push和pop操作只能在当前所属的线程里面执行
  • 使用Lock来保证每次只有一个线程执行take操作窃取任务
  • pop和take操作只有当队列非空的时候才能执行(在代码实现里面,用是CAS来保证队列操作的原子性的,解决只有一个任务,被2个线程同时获取的场景)

top和base被定义为volatile,来保证可见性

在work-stealing框架里面,工作线程对完全不知道程序里面的同步要求,他们只简单的生成、push、pop、poll来管理任务队列状态以及执行任务。当没有足够的任务时,且无法成功窃取任务时,重试窃取任务的操作会导致线程变慢,因为加锁的原因

Java里面对于此种场景的处理工具比较弱,但是在一般使用情况下是可接受的。当一个线程窃取任务失败时,会降低线程的优先级,调用Thread.yield然后再重试,然后向ThreadGroup注册非激活状态,当所有其它线程都 变成非激活状态时,则都阻塞,等待额外新的任务

2.fork/join框架实现说明

2.1 ForkJoinTask

ForkJoinTask类定义如下

public abstract class ForkJoinTask<V> implements Future<V>, Serializable

代表执行在ForkJoinPool里面的一个任务,比正常的线程更经量,大量的任务及其子任务会被ForkJoinPool里面管理的少量线程持有并执行,代价是有很多用法限制

一般使用其子类来代表具体的任务

  • CountedCompleter 将结果合并从任务分解计算的过程中分离出来,只要每个子任务能保存状态,那么通过合并逻辑,就可以在任何时候将结果合并
  • RecursiveAction 没有返回值的递归任务
  • RecursiveTask 有返回值的递归任务

必须实现compute接口,在该方法里面定义任务fork和join的逻辑

2.2 ForkJoinPool

ForkJoinPool 继承自AbstractExecutorService,与其它的ExecutorService主要的不同点是在工作队列里面使用了work-stealing算法,来保证每个线程的负载,充分发挥多核处理器的能力。

对外提供了commonPool()来获取ForkJoinPool对象,对于大部分应用场景,建议使用此方法来获取ForkJoinPool对象。

对于非fork/join任务,和ThreadPoolExecutor的使用方法类似,,支持ForkJoinTask/Runnable/Callable 3种对象,后面的2中都会转换成ForkJoinTask;对于fork/join任务,主要使用fork方法来分解任务,用join方法来合并计算结果

提交任务

Call from non-fork/join clients

Call from within fork/join computations

Arrange async execution

execute(ForkJoinTask)

ForkJoinTask.fork()

Await and obtain result

invoke(ForkJoinTask)

ForkJoinTask.invoke()

Arrange exec and obtain Future

submit(ForkJoinTask)

ForkJoinTask.fork() (ForkJoinTasks are Futures)

里面的关键属性

// Instance fields
    volatile long ctl;                   // main pool control
    volatile int runState;               // lockable status
    final int config;                    // parallelism, mode
    int indexSeed;                       // to generate worker index
    volatile WorkQueue[] workQueues;     // main registry
    final ForkJoinWorkerThreadFactory factory;
    final UncaughtExceptionHandler ueh;  // per-worker UEH
    final String workerNamePrefix;       // to create worker name string
    volatile AtomicLong stealCounter;    // also used as sync monitor

以如下代码为例说明调用流程,完整代码见ForkJoinTest

Integer sum = ForkJoinPool.commonPool().invoke(new FaTask(list));

    private static class FaTask extends RecursiveTask<Integer> {

        // task 编号,方便打印的时候分析
        public int index;

        private static final int THRESHOLD = 3;
        private List<Integer> mList;

        private FaTask(List<Integer> list) {
            mList = list;
            index = taskIndex.getAndIncrement();
        }

        @Override
        protected Integer compute() {
            Integer result;
            System.out.println(Thread.currentThread().getName() + " run " + this.toString());

            if (mList.size() <= THRESHOLD) {
                result = calc();
            } else {
                int mid = mList.size() / 2;

                FaTask leftTask = new FaTask(mList.subList(0, mid));
                FaTask rightTask = new FaTask(mList.subList(mid, mList.size()));

                System.out.println(Thread.currentThread().getName() + " task = " + index + " split to task " + leftTask.index + " and task " + rightTask.index);

                leftTask.fork();
                rightTask.fork();

                System.out.println(Thread.currentThread().getName() + " task = " + index + " wait for task " + leftTask.index + " finish ,status " + this.isDone());
                Integer sumleft = leftTask.join();
                System.out.println(Thread.currentThread().getName() + " task = " + index + " wait for task " + rightTask.index + " finish  ,status " + this.isDone());
                Integer sumright = rightTask.join();

                result = sumleft+sumright;
            }

            System.out.println(Thread.currentThread().getName() + " finish task " + index + " ; result = " + result + ",status " + this.isDone());

            return result;

        }

运行如果:

main run task 0 :[1,2,3,4,5,6,7,8,9,]
main task = 0 split to task 1 and task 2
main task = 0 wait for task 1 finish ,status false
ForkJoinPool.commonPool-worker-9 run task 1 :[1,2,3,4,]
ForkJoinPool.commonPool-worker-2 run task 2 :[5,6,7,8,9,]
ForkJoinPool.commonPool-worker-9 task = 1 split to task 3 and task 4
ForkJoinPool.commonPool-worker-2 task = 2 split to task 5 and task 6
ForkJoinPool.commonPool-worker-9 task = 1 wait for task 3 finish ,status false
ForkJoinPool.commonPool-worker-11 run task 3 :[1,2,]
ForkJoinPool.commonPool-worker-4 run task 5 :[5,6,]
ForkJoinPool.commonPool-worker-13 run task 6 :[7,8,9,]
ForkJoinPool.commonPool-worker-2 task = 2 wait for task 5 finish ,status false
ForkJoinPool.commonPool-worker-6 run task 4 :[3,4,]
ForkJoinPool.commonPool-worker-6 finish task 4 ; result = 7,status false
ForkJoinPool.commonPool-worker-11 finish task 3 ; result = 3,status false
ForkJoinPool.commonPool-worker-4 finish task 5 ; result = 11,status false
ForkJoinPool.commonPool-worker-9 task = 1 wait for task 4 finish  ,status false
ForkJoinPool.commonPool-worker-13 finish task 6 ; result = 24,status false
ForkJoinPool.commonPool-worker-9 finish task 1 ; result = 10,status false
ForkJoinPool.commonPool-worker-2 task = 2 wait for task 6 finish  ,status false
main task = 0 wait for task 2 finish  ,status false
ForkJoinPool.commonPool-worker-2 finish task 2 ; result = 35,status false
main finish task 0 ; result = 45,status false
Sum is = 45
2.3 WorkerQueue

WorkerQueue

// Instance fields
        volatile int scanState;    // versioned, <0: inactive; odd:scanning
        int stackPred;             // pool stack (ctl) predecessor
        int nsteals;               // number of steals
        int hint;                  // randomization and stealer index hint
        int config;                // pool index and mode
        volatile int qlock;        // 1: locked, < 0: terminate; else 0
        volatile int base;         // index of next slot for poll
        int top;                   // index of next slot for push
        ForkJoinTask<?>[] array;   // the elements (initially unallocated)
        final ForkJoinPool pool;   // the containing pool (may be null)
        final ForkJoinWorkerThread owner; // owning thread or null if shared
        volatile Thread parker;    // == owner during call to park; else null
        volatile ForkJoinTask<?> currentJoin;  // task being joined in awaitJoin
        volatile ForkJoinTask<?> currentSteal; // mainly used by helpStealer

通过ForkJoinPool.commonPool()方法来获取ForkJoinPool对象,默认并发线程是

Runtime.getRuntime().availableProcessors() - 1

主任务的调用链如下:

invoke -->
1–> externalPush --> externalSubmit --> new WorkerQueue and signalWork --> tryAddWorker --> createWorker --> runWorker 最后externalPush方法返回
2–> task.join --> doJoin --> doExec --> compute 进入我们定义的compute 方法 --> externalAwaitDone 等待其它任务完成后返回

提交任务是调用的invoke方法

public <T> T invoke(ForkJoinTask<T> task) {
        if (task == null)
            throw new NullPointerException();
        externalPush(task);
        return task.join();
    }

核心是externalPush方法,把任务加入工作队列里面,然后等待返回结果

final void externalPush(ForkJoinTask<?> task) {
        WorkQueue[] ws; WorkQueue q; int m;
        int r = ThreadLocalRandom.getProbe();
        int rs = runState;
        // workQueues非空且能获取锁
        if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
            (q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 &&
            U.compareAndSwapInt(q, QLOCK, 0, 1)) {
            ForkJoinTask<?>[] a; int am, n, s;
            if ((a = q.array) != null &&
                (am = a.length - 1) > (n = (s = q.top) - q.base)) {
                int j = ((am & s) << ASHIFT) + ABASE;
                U.putOrderedObject(a, j, task);
                U.putOrderedInt(q, QTOP, s + 1);
                U.putIntVolatile(q, QLOCK, 0);
                if (n <= 1)
                    signalWork(ws, q);
                return;
            }
            U.compareAndSwapInt(q, QLOCK, 1, 0);
        }
        externalSubmit(task);
    }

使用invoke方法的时候,workQueues为空,走到externalSubmit方法里面,在里面会创建一个WorkQueue,然后尝试提交任务,提交成功后,调用signalWork

signalWork方法,如果工作线程很少的时候,尝试创建或激活一个工作线程。tryAddWorker --> createWorker ,在createWorker方法里面会创建一个ForkJoinWorkerThread工作线程,然后启动线程,最后externalPush方法返回。此时还没有开始执行任务,只是创建了一个线程并把任务添加到工作队列了

继续看invoke方法里面的task.join

public final V join() {
        int s;
        if ((s = doJoin() & DONE_MASK) != NORMAL)
            reportException(s);
        return getRawResult();
    }

继续调用doJoin方法,然后得到结果

private int doJoin() {
        int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
        return (s = status) < 0 ? s :
            ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
            (w = (wt = (ForkJoinWorkerThread)t).workQueue).
            tryUnpush(this) && (s = doExec()) < 0 ? s :
            wt.pool.awaitJoin(w, this, 0L) :
            externalAwaitDone();
    }
    
    // 上面的方法嵌套了3层三元运算符,看着比较晕,把方法稍微改造一下
    private int doJoin() {
        int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;

        s = status;
        if (s < 0) {
            return s;
        }
        t =  Thread.currentThread();

        int result;
        if(t instanceof ForkJoinWorkerThread ) {
            wt = (ForkJoinWorkerThread)t).workQueue;
            // 尝试出队列并执行任务,doExec --> exec --> compute 
            if (wt.tryUnpush(this) && (s = doExec()) < 0 ) {
                result = s;
            } else {
                result = wt.pool.awaitJoin(w, this, 0L) ;
            }
        } else {
            result = externalAwaitDone();
        }
        
        return result;
    }

此时线程的status为0(线程创建的初始化状态为0),线程只是创建了,状态还没有变化,Thread.currentThread()是一个ForkJoinWorkerThread对象,走tryUnpush和doExec,走定制的computer方法,然后会有如下打印

main task = 0 split to task 1 and task 2

再后面就是创建新的ForkJoinTask对象,调用对应的fork方法,由于当前线程是main,不是ForkJoinWorkerThread,所以走ForkJoinPool.common.externalPush,

public final ForkJoinTask<V> fork() {
        Thread t;
        if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
            ((ForkJoinWorkerThread)t).workQueue.push(this);
        else
            ForkJoinPool.common.externalPush(this);
        return this;
    }

流程就和上面的类似,创建一个的线程,重新走task0的流程

ForkJoinPool.commonPool-worker-2 run task 1 :[1,2,3,4,]
ForkJoinPool.commonPool-worker-2 task = 1 split to task 3 and task 4

如下打印之后,调用task.join方法

ForkJoinPool.commonPool-worker-2 task = 1 wait for task 3 finish ,status false

等待task3执行完,

ForkJoinPool.commonPool-worker-11 run task 3 :[1,2,]

task3执行完后,对应的join方法返回,task4的执行流程类似,只是在不同的线程里面

ForkJoinPool.commonPool-worker-11 finish task 3 ; result = 3,status false

task3和task4 完成后,task1里面的2个子任务都从Join方法返回,task1完成,并返回执行结果

ForkJoinPool.commonPool-worker-6 finish task 4 ; result = 7,status false ForkJoinPool.commonPool-worker-11 finish task 3 ; result = 3,status false
ForkJoinPool.commonPool-worker-9 finish task 1 ; result = 10,status false

同理,最开始的任务task0分成task1和task2,等2个任务都完成后,task0返回最终的结果

ForkJoinPool.commonPool-worker-2 finish task 2 ; result = 35,status false
ForkJoinPool.commonPool-worker-9 finish task 1 ; result = 10,status false
main finish task 0 ; result = 45,status false
Sum is = 45

【参考】