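Sorting algorithms that split into subproblems, such as merge sort and quicksort, can be sped up with multiple threads. A while ago I ran a small experiment comparing the speed of my own MergeSort::parallelSort and QuickSort::parallelQuickSort with the library methods Arrays::sort and Arrays::parallelSort. I randomly generated 10 million values to sort and used -Xmx and -Xms to make the JVM heap as large as possible. Because every sort needs its own copy of the 10 million values, I also used -XX:+PrintGCDetails to print GC details and the pauses GC causes. It was quite fun.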
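The heap limit set with -Xmx and the GC activity that -XX:+PrintGCDetails logs can also be read from inside the program through the standard java.lang.management API. A minimal sketch, assuming per-collector counts and times are enough; the class name GcProbe, the 1-million-element size and the seed are illustrative choices, not part of the experiment:

```java
import java.lang.management.GarbageCollectorMXBean;
import java.lang.management.ManagementFactory;
import java.util.Arrays;
import java.util.Random;

class GcProbe {

    /** Total number of collections run by all collectors so far. */
    static long totalGcCount() {
        long total = 0;
        for (GarbageCollectorMXBean gc : ManagementFactory.getGarbageCollectorMXBeans()) {
            long count = gc.getCollectionCount();
            if (count > 0)              // -1 means "undefined" for this collector
                total += count;
        }
        return total;
    }

    /** Total accumulated collection time in milliseconds. */
    static long totalGcMillis() {
        long total = 0;
        for (GarbageCollectorMXBean gc : ManagementFactory.getGarbageCollectorMXBeans()) {
            long time = gc.getCollectionTime();
            if (time > 0)               // -1 means "undefined" for this collector
                total += time;
        }
        return total;
    }

    public static void main(String[] args) {
        // Upper bound on the heap, roughly what -Xmx was set to
        System.out.println("max heap (MB): " + Runtime.getRuntime().maxMemory() / (1024 * 1024));

        Integer[] data = new Integer[1_000_000];
        Random random = new Random(42);
        for (int i = 0; i < data.length; ++i)
            data[i] = random.nextInt();

        long countBefore = totalGcCount();
        long millisBefore = totalGcMillis();
        Arrays.parallelSort(data);
        System.out.println("GCs during sort: " + (totalGcCount() - countBefore)
                + ", GC time (ms): " + (totalGcMillis() - millisBefore));
    }
}
```

Taking the difference of the counters before and after a sort gives the per-sort GC numbers without parsing the log.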

1. Algorithm code and test code

The MergeSort code:

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.TimeUnit;

/**
 * Created by violetMoon on 2016/4/13.
 */
public class MergeSort {

    public static <T extends Comparable<? super T>> void sort(T[] datas) {
        sort(datas, 0, datas.length - 1);
    }

    public static <T extends Comparable<? super T>> void sort(T[] datas, int low, int high) {
        Object[] bufs = new Object[datas.length];
        sort(datas, low, high, bufs);
    }

    public static <T extends Comparable<? super T>> void parallelSort(T[] datas) {
        parallelSort(datas, 0, datas.length - 1);
    }

    public static <T extends Comparable<? super T>> void parallelSort(T[] datas, int low, int high) {
        ForkJoinPool pool = new ForkJoinPool();
        Object[] bufs = new Object[datas.length];
        MergeSortAction<T> action = new MergeSortAction<>(datas, low, high, bufs);
        pool.invoke(action);
        pool.shutdown();
        try {
            pool.awaitTermination(Long.MAX_VALUE, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        return;
    }

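    // Sequential merge sort that reuses one shared buffer for every merge
    static <T extends Comparable<? super T>> void sort(T[] datas, int low, int high, Object[] bufs) {
        if (low >= high)
            return;

        int center = (low + high) / 2;
        // Pass bufs down: calling the 3-argument sort here would allocate a new
        // buffer of datas.length elements on every recursive call
        sort(datas, low, center, bufs);
        sort(datas, center + 1, high, bufs);
        merge(datas, low, center, high, bufs);
    }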

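    // Merges the sorted runs datas[leftPos..center] and datas[center+1..rightPos] through bufs
    @SuppressWarnings("unchecked")
    private static <T extends Comparable<? super T>> void merge(T[] datas, int leftPos, int center, int rightPos,
                                                                Object[] bufs) {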
        int cur = leftPos;
        int leftCur = leftPos;
        int rightCur = center + 1;
        while (leftCur <= center && rightCur <= rightPos) {
            if (datas[leftCur].compareTo(datas[rightCur]) <= 0)  // <= keeps equal elements in order (stable)
                bufs[cur++] = datas[leftCur++];
            else
                bufs[cur++] = datas[rightCur++];
        }
        while (leftCur <= center)
            bufs[cur++] = datas[leftCur++];
        while (rightCur <= rightPos)
            bufs[cur++] = datas[rightCur++];

        for (int i = leftPos; i <= rightPos; ++i)
            datas[i] = (T) bufs[i];
    }

    static class MergeSortAction<T extends Comparable<? super T>> extends RecursiveAction {

        private T[] mDatas;
        private int mLow;
        private int mHigh;
        private Object[] mBuf;

        public MergeSortAction(T[] datas, int low, int high, Object[] buf) {
            mDatas = datas;
            mLow = low;
            mHigh = high;
            mBuf = buf;
        }

        @Override
        protected void compute() {
            if (mLow >= mHigh)
                return;

            int center = (mLow + mHigh) / 2;
            // Ranges of at most 3 elements are sorted directly with compare-and-swap
            if (mLow + 2 >= mHigh) {
                if (mDatas[mLow].compareTo(mDatas[center]) > 0)
                    swap(mDatas, mLow, center);
                if (mDatas[center].compareTo(mDatas[mHigh]) > 0) {
                    swap(mDatas, center, mHigh);
                    if (mDatas[mLow].compareTo(mDatas[center]) > 0)
                        swap(mDatas, mLow, center);
                }
                return;
            }
            invokeAll(new MergeSortAction<>(mDatas, mLow, center, mBuf), new MergeSortAction<T>(mDatas, center + 1,
                    mHigh, mBuf));
            merge(mDatas, mLow, center, mHigh, mBuf);
        }

        private void swap(Object[] datas, int a, int b) {
            Object tmp = datas[a];
            datas[a] = datas[b];
            datas[b] = tmp;
        }
    }
}

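Merge sort splits the problem into independent subproblems, so it is a good fit for parallel sorting with ForkJoinPool, which makes full use of a multicore CPU. The default ForkJoinPool constructor sets the parallelism (the target number of worker threads) to the number of CPU cores, i.e. Runtime.getRuntime().availableProcessors(). Improved versions of merge sort apparently do not need the extra O(n) buffer, but that hardly matters here: allocating a 10-million-element array takes only 20 to 30 ms.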
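To check the default pool size on a given machine, a tiny sketch (the class name PoolParallelism is made up for illustration) compares the parallelism of a pool built with the no-argument constructor, as in parallelSort above, with availableProcessors(). It also prints the parallelism of the shared common pool, which by default is one less than the core count (minimum 1) unless overridden with the java.util.concurrent.ForkJoinPool.common.parallelism system property.

```java
import java.util.concurrent.ForkJoinPool;

class PoolParallelism {
    public static void main(String[] args) {
        ForkJoinPool pool = new ForkJoinPool();   // same constructor as MergeSort.parallelSort
        try {
            System.out.println("availableProcessors = " + Runtime.getRuntime().availableProcessors());
            System.out.println("new ForkJoinPool() parallelism = " + pool.getParallelism());
            // The common pool is sized separately from pools created with new
            System.out.println("commonPool parallelism = " + ForkJoinPool.getCommonPoolParallelism());
        } finally {
            pool.shutdown();
        }
    }
}
```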

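The QuickSort code is below. When I originally wrote it, I gave the interface an async parameter that chooses ascending or descending order (despite the name, true means ascending and false means descending).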

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.TimeUnit;

/**
 * Created by violetMoon on 2016/4/11.
 */
public class QuickSort {

    public static <T extends Comparable<? super T>> void quickSort(T[] datas) {
        quickSort(datas, true);
    }

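    // async == true sorts ascending, false sorts descending (it selects the order, not asynchronous execution)
    public static <T extends Comparable<? super T>> void quickSort(T[] datas, boolean async) {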
        quickSort(datas, 0, datas.length - 1, async);
    }

    public static <T extends Comparable<? super T>> void parallelQuickSort(T[] datas) {
        ForkJoinPool pool = new ForkJoinPool();
        QuickSortAction<T> action = new QuickSortAction<>(datas, 0, datas.length - 1, true);
        pool.invoke(action);
        pool.shutdown();
        try {
            pool.awaitTermination(Long.MAX_VALUE, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        return;
    }

    public static <T extends Comparable<? super T>> void quickSort(T[] datas, int low, int high) {
        quickSort(datas, low, high, true);
    }

    public static <T extends Comparable<? super T>> void quickSort(T[] datas, int low, int high, boolean async) {
        if (low >= high)
            return;

        int pivot;
        if (async)
            pivot = partitionAsync(datas, low, high);
        else
            pivot = partitionDesc(datas, low, high);

        quickSort(datas, low, pivot - 1, async);
        quickSort(datas, pivot + 1, high, async);
    }

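    /**
     * Quickselect: returns the element of rank n (0-based) within datas[low..high].
     * datas is modified (partially reordered), so callers must allow that.
     *
     * @param datas the array to search; it is reordered in place
     * @param low   first index of the range
     * @param high  last index of the range (inclusive)
     * @param n     0-based rank within [low, high]
     * @param <T>   element type
     * @return the element that would be at index low + n if datas[low..high] were sorted ascending
     */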
    public static <T extends Comparable<? super T>> T selectNth(T[] datas, int low, int high, int n) {
        if (low == high)
            return datas[low];
        int pivot = partitionAsync(datas, low, high);
        int distance = pivot - low - n;
        if (distance == 0) {
            return datas[pivot];
        } else if (distance > 0) {
            return selectNth(datas, low, pivot - 1, n);
        } else {
            return selectNth(datas, pivot + 1, high, -distance - 1);
        }
    }

    private static <T extends Comparable<? super T>> int partitionAsync(T[] datas, int low, int high) {
        T key = datas[low];
        while (low < high) {
            while (low < high && datas[high].compareTo(key) >= 0)
                high--;
            datas[low] = datas[high];
            while (low < high && datas[low].compareTo(key) <= 0)
                low++;
            datas[high] = datas[low];
        }
        datas[low] = key;
        return low;
    }

    private static <T extends Comparable<? super T>> int partitionDesc(T[] datas, int low, int high) {
        T key = datas[low];
        while (low < high) {
            while (low < high && datas[high].compareTo(key) <= 0)
                high--;
            datas[low] = datas[high];
            while (low < high && datas[low].compareTo(key) >= 0)
                low++;
            datas[high] = datas[low];
        }
        datas[low] = key;
        return low;
    }

    static class QuickSortAction<T extends Comparable<? super T>> extends RecursiveAction {

        private T[] datas;
        private int low;
        private int high;
        private boolean async;

        public QuickSortAction(T[] datas, int low, int high, boolean async) {
            this.datas = datas;
            this.low = low;
            this.high = high;
            this.async = async;
        }

        @Override
        protected void compute() {
            if (low >= high)
                return;

            // Ranges of at most 3 elements are sorted directly with compare-and-swap
            if (low + 2 >= high) {
                int center = (low + high) / 2;
                if (async) {
                    if (datas[low].compareTo(datas[center]) > 0)
                        swap(datas, low, center);
                    if (datas[center].compareTo(datas[high]) > 0) {
                        swap(datas, center, high);
                        if (datas[low].compareTo(datas[center]) > 0)
                            swap(datas, low, center);
                    }
                } else {
                    if (datas[low].compareTo(datas[center]) < 0)
                        swap(datas, low, center);
                    if (datas[center].compareTo(datas[high]) < 0) {
                        swap(datas, center, high);
                        if (datas[low].compareTo(datas[center]) < 0)
                            swap(datas, low, center);
                    }
                }
                return;
            }

            int pivot;
            if (async)
                pivot = partitionAsync(datas, low, high);
            else
                pivot = partitionDesc(datas, low, high);

            QuickSortAction<T> sortLeft = new QuickSortAction<>(datas, low, pivot - 1, async);
            QuickSortAction<T> sortRight = new QuickSortAction<>(datas, pivot + 1, high, async);
            invokeAll(sortLeft, sortRight);
        }

        private static void swap(Object[] datas, int a, int b) {
            Object tmp = datas[a];
            datas[a] = datas[b];
            datas[b] = tmp;
        }
    }
}

It also uses ForkJoinPool.
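Instead of a boolean flag plus two near-identical partition methods (partitionAsync and partitionDesc), the sort order could be passed in as a java.util.Comparator. The sketch below is a variant written for illustration, not the code above: ComparatorQuickSort is a hypothetical name, it is sequential only, and it recurses into the smaller partition and loops on the larger one so the stack depth stays logarithmic. The partition step is the same "hole" scheme as partitionAsync.

```java
import java.util.Arrays;
import java.util.Comparator;

class ComparatorQuickSort {

    static <T> void sort(T[] a, Comparator<? super T> cmp) {
        sort(a, 0, a.length - 1, cmp);
    }

    static <T> void sort(T[] a, int low, int high, Comparator<? super T> cmp) {
        while (low < high) {
            int p = partition(a, low, high, cmp);
            // Recurse into the smaller side, loop on the larger one
            if (p - low < high - p) {
                sort(a, low, p - 1, cmp);
                low = p + 1;
            } else {
                sort(a, p + 1, high, cmp);
                high = p - 1;
            }
        }
    }

    // Same hole-filling partition as partitionAsync, with every comparison going through cmp
    private static <T> int partition(T[] a, int low, int high, Comparator<? super T> cmp) {
        T key = a[low];
        while (low < high) {
            while (low < high && cmp.compare(a[high], key) >= 0)
                high--;
            a[low] = a[high];
            while (low < high && cmp.compare(a[low], key) <= 0)
                low++;
            a[high] = a[low];
        }
        a[low] = key;
        return low;
    }

    public static void main(String[] args) {
        Integer[] asc = {5, 3, 9, 1, 7};
        Integer[] desc = asc.clone();
        sort(asc, Comparator.naturalOrder());
        sort(desc, Comparator.reverseOrder());
        System.out.println(Arrays.toString(asc));   // [1, 3, 5, 7, 9]
        System.out.println(Arrays.toString(desc));  // [9, 7, 5, 3, 1]
    }
}
```

With this shape, Comparator.reverseOrder() takes over the role of partitionDesc; the same change could be applied to QuickSortAction.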

Test code:

import java.util.Arrays;
import java.util.Random;
import java.util.function.Consumer;

// Enclosing class added so the file compiles on its own; the name SortTest is a placeholder
public class SortTest {

    static Integer[] getRandomData(int length) {
        Integer[] datas = new Integer[length];
        Random random = new Random(System.currentTimeMillis());
        for (int i = 0; i < length; ++i)
            datas[i] = random.nextInt();
        return datas;
    }

    // Sorts datas with function, prints the elapsed time and checks the result
    static <T extends Comparable<? super T>> void evaluteSortMethod(Consumer<T[]> function, T[] datas, String funcName) {
        long beginNanos = System.nanoTime();   // monotonic clock for elapsed time
        function.accept(datas);
        long usedMillis = (System.nanoTime() - beginNanos) / 1_000_000;
        System.out.println(funcName + " took " + (usedMillis / 1000) + " s " + (usedMillis % 1000) + " ms to sort "
                + datas.length + " items");
        if (!isSorted(datas))
            System.err.println(funcName + " sorted incorrectly");
    }

    // Same as above, but returns the elapsed milliseconds instead of printing them
    static <T extends Comparable<? super T>> long evaluteSortMethod(Consumer<T[]> function, T[] datas) {
        long beginNanos = System.nanoTime();
        function.accept(datas);
        long usedMillis = (System.nanoTime() - beginNanos) / 1_000_000;
        if (!isSorted(datas))
            throw new IllegalStateException("sorting algorithm produced an unsorted result");
        return usedMillis;
    }

    static void test(int testDataLen) {
        Integer[] src = getRandomData(testDataLen);

        // Each method gets its own copy of the same unsorted input
        Integer[] srcCopy = src.clone();
        Integer[] srcCopy2 = src.clone();
        Integer[] srcCopy3 = src.clone();
        Integer[] srcCopy4 = src.clone();

        //evaluteSortMethod(MergeSort::sort, srcCopy, "MergeSort::sort");
        evaluteSortMethod(MergeSort::parallelSort, src, "MergeSort::parallelSort");
        evaluteSortMethod(Arrays::sort, srcCopy2, "Arrays::sort");
        evaluteSortMethod(Arrays::parallelSort, srcCopy, "Arrays::parallelSort");
        evaluteSortMethod(QuickSort::quickSort, srcCopy3, "QuickSort::quickSort");
        evaluteSortMethod(QuickSort::parallelQuickSort, srcCopy4, "QuickSort::parallelQuickSort");

        // Measure the average time of a sorting method
        /*long totalMillis = 0;
        int testTimes = 1;
        for (int i = 0; i < testTimes; ++i) {
            Integer[] copy = src.clone();
            totalMillis += evaluteSortMethod(QuickSort::parallelQuickSort, copy);
            System.out.println("test times " + i);
            // assumes a static AtomicInteger count in QuickSortAction, which the listing above does not include
            System.out.println(QuickSort.QuickSortAction.count.get());
        }

        System.out.println((double) totalMillis / testTimes);*/
    }

    static <T extends Comparable<? super T>> boolean isSorted(T[] datas) {
        for (int i = 0; i < datas.length - 1; ++i)
            if (datas[i].compareTo(datas[i + 1]) > 0)
                return false;
        return true;
    }

    public static void main(String[] args) {
        test(10);   // small run first, exercising every method once before the big test
        System.out.println("-- runtime environment --");
        System.getProperties().list(System.out);
        System.out.println("available processor num=" + Runtime.getRuntime().availableProcessors());
        System.out.println("-------------");
        test(10000000);
    }
}

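Here I used JDK 8 lambda expressions (method references such as MergeSort::parallelSort passed as a Consumer<T[]>) to keep the code short.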
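The method references handed to evaluteSortMethod are compact lambdas: where a Consumer<Integer[]> is expected, Arrays::sort behaves like datas -> Arrays.sort(datas). A small sketch; MethodRefDemo and sortedCopy are names made up for illustration:

```java
import java.util.Arrays;
import java.util.function.Consumer;

class MethodRefDemo {

    // Applies a sorter to a copy of the input so the caller's array stays unsorted
    static Integer[] sortedCopy(Consumer<Integer[]> sorter, Integer[] input) {
        Integer[] copy = input.clone();
        sorter.accept(copy);
        return copy;
    }

    public static void main(String[] args) {
        Integer[] input = {3, 1, 2};

        Consumer<Integer[]> byRef = Arrays::sort;                    // method reference
        Consumer<Integer[]> byLambda = datas -> Arrays.sort(datas);  // the equivalent lambda
        Consumer<Integer[]> parallel = Arrays::parallelSort;         // resolves to parallelSort(T[])

        System.out.println(Arrays.toString(sortedCopy(byRef, input)));     // [1, 2, 3]
        System.out.println(Arrays.toString(sortedCopy(byLambda, input)));  // [1, 2, 3]
        System.out.println(Arrays.toString(sortedCopy(parallel, input)));  // [1, 2, 3]
        System.out.println(Arrays.toString(input));                        // [3, 1, 2]
    }
}
```

The compiler picks the overload that fits the target type, which is why the overloaded names Arrays::sort and Arrays::parallelSort can be passed directly.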

2. Experimental environment

8-core CPU, 4 GB RAM

3. Results

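The experiment showed that the non-parallel merge sort is very slow, not even in the same order of magnitude as the other methods (perhaps a problem in how I wrote the algorithm?). A likely culprit: in the version measured, the sequential sort(datas, low, high, bufs) sorted its two halves by calling the three-argument sort overload, and that overload allocates a new buffer the size of the entire array on every call. I first recorded times by running the program once per measurement: QuickSort::quickSort and Arrays::sort averaged about 4 s, MergeSort::parallelSort about 3 s, Arrays::parallelSort about 2.1 s, and QuickSort::parallelQuickSort about 1.7 s.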
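Timing a single run with System.currentTimeMillis includes JIT warm-up and can be noisy. A sketch of a slightly more careful measurement, assuming a few discarded warm-up runs are enough; SortTimer and meanMillis are hypothetical names, and a harness such as JMH is the more rigorous option:

```java
import java.util.Arrays;
import java.util.Random;
import java.util.function.Consumer;

class SortTimer {

    /**
     * Sorts a fresh copy of src `runs` times and returns the mean time in milliseconds.
     * The first `warmups` runs are discarded so JIT compilation is not counted.
     */
    static double meanMillis(Consumer<Integer[]> sorter, Integer[] src, int warmups, int runs) {
        for (int i = 0; i < warmups; ++i)
            sorter.accept(src.clone());
        long totalNanos = 0;
        for (int i = 0; i < runs; ++i) {
            Integer[] copy = src.clone();        // every run sorts the same unsorted input
            long begin = System.nanoTime();      // monotonic clock, unlike currentTimeMillis
            sorter.accept(copy);
            totalNanos += System.nanoTime() - begin;
        }
        return totalNanos / 1e6 / runs;
    }

    public static void main(String[] args) {
        Integer[] src = new Integer[1_000_000];
        Random random = new Random(42);
        for (int i = 0; i < src.length; ++i)
            src[i] = random.nextInt();
        System.out.printf("Arrays::sort         %.1f ms%n", meanMillis(Arrays::sort, src, 3, 5));
        System.out.printf("Arrays::parallelSort %.1f ms%n", meanMillis(Arrays::parallelSort, src, 3, 5));
    }
}
```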

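Then I used the commented-out loop in test() to measure the average times of Arrays::parallelSort and QuickSort::parallelQuickSort: Arrays::parallelSort took about 1.3 s and QuickSort::parallelQuickSort about 1.6 s. Could the difference be because parallelQuickSort creates a new ForkJoinPool on every call? The GC output also showed that with a 4 GB JVM heap, Arrays::parallelSort triggered on average one GC every 10 runs, while QuickSort::parallelQuickSort triggered one every 3 runs. The memory QuickSort::parallelQuickSort allocates goes mainly to thread method stacks during recursion and to QuickSortAction objects (roughly 8 million of them...). Any of these could be the bottleneck. I haven't studied Arrays::parallelSort closely yet; I'll come back to it when I have time.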
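Two of the suspected overheads above, the roughly 8 million task objects and the new ForkJoinPool on every call, are commonly addressed with a sequential cutoff and by reusing ForkJoinPool.commonPool(). The sketch below applies both to merge sort. CutoffMergeSort, SortTask and THRESHOLD are hypothetical names, and the value 8192 is borrowed from the granularity that OpenJDK's Arrays.parallelSort uses (MIN_ARRAY_SORT_GRAN) rather than tuned; whether it closes the gap with Arrays::parallelSort would have to be measured.

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

class CutoffMergeSort {

    // Assumed cutoff: ranges of at most this many elements are sorted sequentially
    static final int THRESHOLD = 8192;

    static <T extends Comparable<? super T>> void parallelSort(T[] datas) {
        Object[] buf = new Object[datas.length];
        // Reuse the shared common pool instead of creating a new ForkJoinPool per call
        ForkJoinPool.commonPool().invoke(new SortTask<>(datas, 0, datas.length, buf));
    }

    static class SortTask<T extends Comparable<? super T>> extends RecursiveAction {
        private final T[] datas;
        private final int from;        // inclusive
        private final int to;          // exclusive
        private final Object[] buf;

        SortTask(T[] datas, int from, int to, Object[] buf) {
            this.datas = datas;
            this.from = from;
            this.to = to;
            this.buf = buf;
        }

        @Override
        protected void compute() {
            if (to - from <= THRESHOLD) {
                Arrays.sort(datas, from, to);   // small range: no further tasks
                return;
            }
            int mid = (from + to) >>> 1;
            invokeAll(new SortTask<>(datas, from, mid, buf), new SortTask<>(datas, mid, to, buf));
            merge(mid);
        }

        // Merges the sorted halves [from, mid) and [mid, to) through buf
        @SuppressWarnings("unchecked")
        private void merge(int mid) {
            int i = from, j = mid, k = from;
            while (i < mid && j < to)
                buf[k++] = datas[i].compareTo(datas[j]) <= 0 ? datas[i++] : datas[j++];
            while (i < mid)
                buf[k++] = datas[i++];
            while (j < to)
                buf[k++] = datas[j++];
            for (k = from; k < to; ++k)
                datas[k] = (T) buf[k];
        }
    }
}
```

With a cutoff of 8192, sorting 10 million elements creates on the order of a few thousand tasks instead of millions.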