Java多线程(十三) Fork / Join

本篇博客是学习Fork / Join框架后的总结笔记,知识点来源于《Java并发编程的艺术》一书。


  • Java多线程(十三) Fork / Join
  • Fork / Join 框架
  • 工作窃取算法
  • Fork / Join 设计
  • 使用 Fork / Join


Fork / Join 框架

Fork/Join框架是Java 7提供的一个用于并行执行任务的框架,是一个把大任务分割成若干个小任务,最终汇总每个小任务结果后得到大任务结果的框架。这种思想类似于分而治之的思想。我们再通过Fork和Join这两个单词来理解一下Fork/Join框架。Fork就是把一个大任务切分为若干子任务并行的执行,Join就是合并这些子任务的执行结果,最后得到这个大任务的结果。比如计算1+2+…+1000,可以分割成10个子任务,每个子任务分别对1000个数进行求和,最终汇总这10个子任务的结果。

fork介绍 java fork函数 java_fork介绍 java

工作窃取算法

工作窃取(work-stealing)算法是指某个线程从其他队列里窃取任务来执行。如果我们需要做一个比较大的任务,可以把这个任务分割为若千互不依赖的子任务,为了减少线程间的竞争,把这些子任务分别放到不同的队列里,并为每个队列创建一个单独的线程来执行队列里的任务,线程和队列对应。比如A线程负责处理A队列里的任务。但是,有的线程会先把自己队列里的任务干完,而其他线程对应的队列里还有任务等待处理。干完活的线程与其等着,不如去帮其他线程干活,于是它就去其他线程的队列里窃取一个任务来执行。而在这时它们会访问同一个队列,所以为了减少窃取任务线程和被窃取任务线程之间的竞争,通常会使用双端队列,被窃取任务线程永远从双端队列的头部拿任务执行,而窃取任务的线程永远从双端队列的尾部拿任务执行。

fork介绍 java fork函数 java_java_02


工作窃取算法的优点很明显,它能够充分利用线程进行并行计算,减少线程之间的竞争,但是它也是存在缺点的,在某些情况下使用工作窃取算法还是存在竞争,例如,在双端队列中只存在一个任务。并且因为该算法会创建多个线程和多个双端队列,因此会带来更多的系统资源消耗。

Fork / Join 设计

通过上面搞明白了 Fork / Join 框架的原理后,就能很清楚它需要两个步骤了:

  1. 分割任务Fork:把大任务分割成子任务,如果分割的子任务还是很大,可以继续分割,直到子任务足够下
  2. 执行结果并合并结果Join:分割的子任务被存储在双端队列中,然后启动线程分别从双端队列获取任务执行。子任务执行完的结果都统一存储在一个队列中,启动一个线程从队列中拿数据,然后合并这些数据。

Fork / Join框架使用两个类完成上面两件事:

  1. ForkJoinTask:要使用ForkJoin框架就首先需要创建一个ForkJoin任务。它提供在任务中执行fork() 和 join()操作的机制,我们通常只需要继承他的子类即可,Fork/Join框架提供了以下两个子类:
RecursiveTask:用于没有返回结果的任务 	
RecursiveAction:用于有返回结果的任务
  1. ForkJoinPool:ForkJoinTask需要通过ForkJoinPool来执行。

使用 Fork / Join

让我们通过一个简单的例子来使用ForkJoin框架:计算1+2+…+1000000000的结果。使用ForkJoin框架首先要考虑到的是如何分割任务,如果希望每个子任务最多执行1000个数的相加,那么我们设置分割的阈值是1000。因为是有结果的任务,所以必须继承Recursive Task,实现代码如下。

class MyTask extends RecursiveTask<Long> {


    Long start;
    Long end;
    Long threshold = 1000000000l;

    public MyTask(Long start, Long end)
    {
        this.start = start;
        this.end = end;
    }
    @Override
    protected Long compute() {
        if((end - start)<threshold)
        {
            Long sum = 0l;
            for (Long i=start;i<=end;i++)
            {
                sum+=i;
            }
            return sum;
        }
        else {
            Long middle = (start + end) / 2;
            MyTask left = new MyTask(start, middle);
            MyTask right = new MyTask(middle + 1, end);
            left.fork();
            right.fork();
            Long leftResult = left.join();
            Long rightResult = right.join();
            return leftResult + rightResult;
        }
    }
}

public class MyForkJoin {

    public static void main(String[] args) throws ExecutionException, InterruptedException {
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        MyTask task = new MyTask(1l,10_0000_0000l);
        long startTime = System.currentTimeMillis();
        Future<Long> result = forkJoinPool.submit(task);
        System.out.println(result.get());
        long endTime = System.currentTimeMillis();
        System.out.println("运行时间:"+(endTime-startTime)+"ms");
    }
}

fork介绍 java fork函数 java_子任务_03