优化一:有序数组优化

在合并数组前进行判断:由于左右两个子数组都已经有序,左边数组的最大值就是arr[mid],右边数组的最小值就是arr[mid + 1];如果左边数组的最大值小于等于右边数组的最小值,说明整个区间已经有序,就不用进行合并操作

对于完全有序的数组,递归树中每一层都不需要进行合并,每个节点只需做常数次操作:最后一层叶子节点约有n个,其上一层约有n/2个,再上一层约有n/4个……因此总共需要的操作次数约为n + n/2 + n/4 + … ≈ 2n

因此对于完全有序的数组,加入该优化后归并排序法的时间复杂度会降为O(n);对于一般的无序数组,时间复杂度仍然是O(nlogn)

归并排序法02:归并排序法的优化_i++

import java.util.Arrays;
import java.util.Random;

public class Algorithm {

    /** Algorithm names passed to Verify.testTime, in the order they are benchmarked. */
    private static final String[] ALGORITHMS = {"sort", "sortOptimised1"};

    /**
     * Benchmarks the plain and the "skip merge when already ordered" merge sorts
     * on a random array and on an already-sorted array, for several sizes.
     */
    public static void main(String[] args) {

        int[] scales = {10000, 100000};

        for (int n : scales) {
            Integer[] random = ArrayGenerator.generatorRandomArray(n, n);
            Integer[] sorted = ArrayGenerator.generatorSortedArray(n, n);

            runAll("测试随机数组排序性能", random);
            runAll("测试有序数组排序性能", sorted);
        }
    }

    /**
     * Prints a section title, then times every algorithm in ALGORITHMS on its own copy of {@code source}.
     *
     * @param title  heading printed before the timings
     * @param source input data; never modified, each algorithm sorts a fresh clone
     */
    private static void runAll(String title, Integer[] source) {
        System.out.println(title);
        System.out.println();

        for (String name : ALGORITHMS) {
            Verify.testTime(name, source.clone());
        }

        System.out.println();
    }
}

class MergeSort {

    private MergeSort(){}

    public static<E extends Comparable<E>> void sort(E[] arr){
        sort(arr, 0, arr.length - 1);
    }

    private static<E extends Comparable<E>> void sort(E[] arr, int left, int right){

        if (left >= right){
            return;
        }

        int mid = left + (right - left) / 2;

        sort(arr, left, mid);
        sort(arr, mid + 1, right);

        merge(arr, left, mid, right);
    }

    public static<E extends Comparable<E>> void sortOptimised1(E[] arr){
        sortOptimised1(arr, 0, arr.length - 1);
    }

    private static<E extends Comparable<E>> void sortOptimised1(E[] arr, int left, int right){

        if (left >= right){
            return;
        }

        int mid = left + (right - left) / 2;

        sortOptimised1(arr, left, mid);
        sortOptimised1(arr, mid + 1, right);

        /**
         * 优化一:有序数组优化,合并前先判断一下是否需要合并
         */
        if (arr[mid].compareTo(arr[mid + 1]) > 0) {
            merge(arr, left, mid, right);
        }
    }

    public static<E extends Comparable<E>> void merge(E[] arr, int left, int mid, int right) {

        int i = left;
        int j = mid + 1;

        E[] tem = Arrays.copyOfRange(arr, left, right + 1);

        for (int n = left; n < right + 1; n++) {

            if (i == mid + 1){
                arr[n] = tem[j - left];
                j++;
            }
            else if (j == right + 1) {
                arr[n] = tem[i - left];
                i++;
            }
            else if (tem[i - left].compareTo(tem[j - left]) <= 0) {
                arr[n] = tem[i - left];
                i++;
            }
            else{
                arr[n] = tem[j - left];
                j++;
            }
        }
    }
}

class ArrayGenerator {

    private ArrayGenerator() {}

    /**
     * Creates an array of {@code n} pseudo-random integers, each in [0, maxBound).
     *
     * @param n        number of elements
     * @param maxBound exclusive upper bound for every value; must be positive when n > 0
     * @return a freshly allocated array of length n
     */
    public static Integer[] generatorRandomArray(Integer n, Integer maxBound) {
        Integer[] values = new Integer[n];
        Random rng = new Random();

        int index = 0;
        while (index < values.length) {
            values[index++] = rng.nextInt(maxBound);
        }
        return values;
    }

    /**
     * Creates the already-sorted array {0, 1, ..., n - 1}.
     *
     * @param n        number of elements
     * @param maxBound not used by this implementation
     * @return a freshly allocated array of length n
     */
    public static Integer[] generatorSortedArray(Integer n, Integer maxBound) {
        Integer[] values = new Integer[n];
        for (int value = 0; value < values.length; value++) {
            values[value] = value;
        }
        return values;
    }
}

class Verify {

    private Verify (){}

    /**
     * Checks whether {@code arr} is in non-decreasing order.
     *
     * @param arr array to inspect; elements must be non-null
     * @return true if every element compares less than or equal to its successor
     */
    public static<E extends Comparable<E>> boolean isSorted(E[] arr){

        for (int i = 0; i < arr.length - 1; i++) {
            if (arr[i].compareTo(arr[i + 1]) > 0) {
                return false;
            }
        }

        return true;
    }

    /**
     * Sorts {@code arr} in place with the named algorithm, verifies the result and prints
     * the elapsed wall-clock time in seconds.
     *
     * @param AlgorithmName "sort" or "sortOptimised1"
     * @param arr           array to sort in place
     * @throws IllegalArgumentException if AlgorithmName is not one of the supported names
     * @throws RuntimeException         if the array is not sorted afterwards
     */
    public static<E extends Comparable<E>> void testTime(String AlgorithmName, E[] arr) {

        long startTime = System.nanoTime();

        switch (AlgorithmName) {
            case "sort":
                MergeSort.sort(arr);
                break;
            case "sortOptimised1":
                MergeSort.sortOptimised1(arr);
                break;
            default:
                // Reject unknown names: silently skipping the sort would let an already-sorted
                // input pass verification and report a meaningless time.
                throw new IllegalArgumentException("未知的算法名称:" + AlgorithmName);
        }

        long endTime = System.nanoTime();

        if (!Verify.isSorted(arr)){
            throw new RuntimeException(AlgorithmName + "算法排序失败!");
        }

        System.out.println(String.format("%s算法,测试用例为%d,执行时间:%f秒", AlgorithmName, arr.length, (endTime - startTime) / 1000000000.0));
    }
}

优化二:使用插入排序法优化(优化效果不稳定;插入排序本身是稳定排序,因此排序结果的稳定性不受影响)

虽然归并排序法的时间复杂度(O(nlogn))低于插入排序法(O(n²)),但是merge()方法包含大量的if判断和赋值语句,常数项较大;在数据规模较小时,插入排序法的性能反而可能更好

因此,可以在arr[left, right]区间元素较少时(下面的代码中为right - left <= 15,即不超过16个元素),直接调用插入排序法对该区间进行排序,不再继续递归

import java.util.Arrays;
import java.util.Random;

public class Algorithm {

    /** Algorithm names passed to Verify.testTime, in the order they are benchmarked. */
    private static final String[] ALGORITHMS = {"InsertionSort", "sort", "sortOptimised1", "sortOptimised2"};

    /**
     * Benchmarks insertion sort against the merge-sort variants on a random array and on an
     * already-sorted array, for several sizes.
     */
    public static void main(String[] args) {

        int[] scales = {10000, 100000};

        for (int n : scales) {
            Integer[] random = ArrayGenerator.generatorRandomArray(n, n);
            Integer[] sorted = ArrayGenerator.generatorSortedArray(n, n);

            runAll("测试随机数组排序性能", random);
            runAll("测试有序数组排序性能", sorted);
        }
    }

    /**
     * Prints a section title, then times every algorithm in ALGORITHMS on its own copy of {@code source}.
     *
     * @param title  heading printed before the timings
     * @param source input data; never modified, each algorithm sorts a fresh clone
     */
    private static void runAll(String title, Integer[] source) {
        System.out.println(title);
        System.out.println();

        for (String name : ALGORITHMS) {
            Verify.testTime(name, source.clone());
        }

        System.out.println();
    }
}

class MergeSort {

    /**
     * Ranges with right - left <= this value (i.e. at most 16 elements) are sorted with
     * insertion sort instead of being split further.
     */
    private static final int INSERTION_SORT_THRESHOLD = 15;

    private MergeSort(){}

    /**
     * Sorts the whole array in ascending order with a plain top-down merge sort.
     *
     * @param arr array to sort in place; elements must be non-null
     */
    public static<E extends Comparable<E>> void sort(E[] arr){
        sort(arr, 0, arr.length - 1);
    }

    /** Recursively sorts the inclusive range arr[left..right]. */
    private static<E extends Comparable<E>> void sort(E[] arr, int left, int right){

        if (left >= right){
            return;
        }

        int mid = left + (right - left) / 2;

        sort(arr, left, mid);
        sort(arr, mid + 1, right);

        merge(arr, left, mid, right);
    }

    /**
     * Optimisation 1: merge sort that skips the merge step whenever the two sorted halves
     * are already in order.
     *
     * @param arr array to sort in place; elements must be non-null
     */
    public static<E extends Comparable<E>> void sortOptimised1(E[] arr){
        sortOptimised1(arr, 0, arr.length - 1);
    }

    /** Recursively sorts arr[left..right], merging only when the halves are out of order. */
    private static<E extends Comparable<E>> void sortOptimised1(E[] arr, int left, int right){

        if (left >= right){
            return;
        }

        int mid = left + (right - left) / 2;

        sortOptimised1(arr, left, mid);
        sortOptimised1(arr, mid + 1, right);

        // Both halves are sorted: arr[mid] is the left maximum, arr[mid + 1] the right minimum.
        if (arr[mid].compareTo(arr[mid + 1]) > 0) {
            merge(arr, left, mid, right);
        }
    }

    /**
     * Optimisation 2: optimisation 1 plus switching to InsertionSort for small ranges.
     *
     * @param arr array to sort in place; elements must be non-null
     */
    public static<E extends Comparable<E>> void sortOptimised2(E[] arr){
        sortOptimised2(arr, 0, arr.length - 1);
    }

    /** Recursively sorts arr[left..right], delegating small ranges to insertion sort. */
    private static<E extends Comparable<E>> void sortOptimised2(E[] arr, int left, int right){

        if (right - left <= INSERTION_SORT_THRESHOLD) {
            InsertionSort.sort(arr, left, right);
            return;
        }

        int mid = left + (right - left) / 2;

        // Recurse into sortOptimised2 itself so the insertion-sort cutoff applies at every
        // level of the recursion, not only when the whole input is small.
        sortOptimised2(arr, left, mid);
        sortOptimised2(arr, mid + 1, right);

        if (arr[mid].compareTo(arr[mid + 1]) > 0) {
            merge(arr, left, mid, right);
        }
    }

    /**
     * Merges the sorted ranges arr[left..mid] and arr[mid+1..right] into one sorted range.
     * On ties the left element is taken first, so the merge is stable. A fresh copy of
     * arr[left..right] is allocated on every call.
     *
     * @param arr   array holding both ranges
     * @param left  first index of the left range
     * @param mid   last index of the left range
     * @param right last index of the right range
     */
    public static<E extends Comparable<E>> void merge(E[] arr, int left, int mid, int right) {

        int i = left;
        int j = mid + 1;

        // tem[k] holds arr[left + k], hence the "- left" offsets below.
        E[] tem = Arrays.copyOfRange(arr, left, right + 1);

        for (int n = left; n < right + 1; n++) {

            if (i == mid + 1){
                arr[n] = tem[j - left];
                j++;
            }
            else if (j == right + 1) {
                arr[n] = tem[i - left];
                i++;
            }
            else if (tem[i - left].compareTo(tem[j - left]) <= 0) {
                arr[n] = tem[i - left];
                i++;
            }
            else{
                arr[n] = tem[j - left];
                j++;
            }
        }
    }
}

class InsertionSort {

    private InsertionSort() {}

    public static <E extends Comparable> void sort(E[] arr) {

        for (int i = 1; i < arr.length; i++) {

            E tem = arr[i];

            int j;

            for (j = i; j > 0 && tem.compareTo(arr[j - 1]) < 0; j--) {

                arr[j] = arr[j - 1];
            }

            arr[j] = tem;
        }
    }

    /**
     * 增加分区间插入排序
     */
    public static <E extends Comparable> void sort(E[] arr, int left, int right) {

        for (int i = left + 1; i < right + 1; i++) {

            E tem = arr[i];

            int j;

            for (j = i; j > left && tem.compareTo(arr[j - 1]) < 0; j--) {

                arr[j] = arr[j - 1];
            }

            arr[j] = tem;
        }
    }
}

class ArrayGenerator {

    private ArrayGenerator() {}

    /**
     * Creates an array of {@code n} pseudo-random integers, each in [0, maxBound).
     *
     * @param n        number of elements
     * @param maxBound exclusive upper bound for every value; must be positive when n > 0
     * @return a freshly allocated array of length n
     */
    public static Integer[] generatorRandomArray(Integer n, Integer maxBound) {
        Integer[] values = new Integer[n];
        Random rng = new Random();

        int index = 0;
        while (index < values.length) {
            values[index++] = rng.nextInt(maxBound);
        }
        return values;
    }

    /**
     * Creates the already-sorted array {0, 1, ..., n - 1}.
     *
     * @param n        number of elements
     * @param maxBound not used by this implementation
     * @return a freshly allocated array of length n
     */
    public static Integer[] generatorSortedArray(Integer n, Integer maxBound) {
        Integer[] values = new Integer[n];
        for (int value = 0; value < values.length; value++) {
            values[value] = value;
        }
        return values;
    }
}

class Verify {

    private Verify (){}

    /**
     * Checks whether {@code arr} is in non-decreasing order.
     *
     * @param arr array to inspect; elements must be non-null
     * @return true if every element compares less than or equal to its successor
     */
    public static<E extends Comparable<E>> boolean isSorted(E[] arr){

        for (int i = 0; i < arr.length - 1; i++) {
            if (arr[i].compareTo(arr[i + 1]) > 0) {
                return false;
            }
        }

        return true;
    }

    /**
     * Sorts {@code arr} in place with the named algorithm, verifies the result and prints
     * the elapsed wall-clock time in seconds.
     *
     * @param AlgorithmName "sort", "sortOptimised1", "sortOptimised2" or "InsertionSort"
     * @param arr           array to sort in place
     * @throws IllegalArgumentException if AlgorithmName is not one of the supported names
     * @throws RuntimeException         if the array is not sorted afterwards
     */
    public static<E extends Comparable<E>> void testTime(String AlgorithmName, E[] arr) {

        long startTime = System.nanoTime();

        switch (AlgorithmName) {
            case "sort":
                MergeSort.sort(arr);
                break;
            case "sortOptimised1":
                MergeSort.sortOptimised1(arr);
                break;
            case "sortOptimised2":
                MergeSort.sortOptimised2(arr);
                break;
            case "InsertionSort":
                InsertionSort.sort(arr);
                break;
            default:
                // Reject unknown names: silently skipping the sort would let an already-sorted
                // input pass verification and report a meaningless time.
                throw new IllegalArgumentException("未知的算法名称:" + AlgorithmName);
        }

        long endTime = System.nanoTime();

        if (!Verify.isSorted(arr)){
            throw new RuntimeException(AlgorithmName + "算法排序失败!");
        }

        System.out.println(String.format("%s算法,测试用例为%d,执行时间:%f秒", AlgorithmName, arr.length, (endTime - startTime) / 1000000000.0));
    }
}

优化三:内存优化

merge()方法在每次调用时,都会新开辟一个数组来复制待合并的区间;当数据规模很大时,频繁地开辟新数组(以及随之而来的垃圾回收)会损失不少性能

因此在排序之前只创建一次与原数组等长的副本temp,每次merge()方法都复用temp来进行复制和读取,避免了反复开辟新数组带来的开销

import java.util.Arrays;
import java.util.Random;

public class Algorithm {

    /** Algorithm names passed to Verify.testTime, in the order they are benchmarked. */
    private static final String[] ALGORITHMS = {"sortOptimised1", "sortOptimised2"};

    /**
     * Benchmarks the allocating merge sort against the shared-scratch-array variant on a
     * random array and on an already-sorted array, for several sizes.
     */
    public static void main(String[] args) {

        int[] scales = {10000, 500000};

        for (int n : scales) {
            Integer[] random = ArrayGenerator.generatorRandomArray(n, n);
            Integer[] sorted = ArrayGenerator.generatorSortedArray(n, n);

            runAll("测试随机数组排序性能", random);
            runAll("测试有序数组排序性能", sorted);
        }
    }

    /**
     * Prints a section title, then times every algorithm in ALGORITHMS on its own copy of {@code source}.
     *
     * @param title  heading printed before the timings
     * @param source input data; never modified, each algorithm sorts a fresh clone
     */
    private static void runAll(String title, Integer[] source) {
        System.out.println(title);
        System.out.println();

        for (String name : ALGORITHMS) {
            Verify.testTime(name, source.clone());
        }

        System.out.println();
    }
}

class MergeSort {

    private MergeSort(){}

    /**
     * 优化一:有序数组优化,合并前先判断一下是否需要合并
     */
    public static<E extends Comparable<E>> void sortOptimised1(E[] arr){

        sortOptimised1(arr, 0, arr.length - 1);
    }

    private static<E extends Comparable<E>> void sortOptimised1(E[] arr, int left, int right){

        if (left >= right){
            return;
        }

        int mid = left + (right - left) / 2;

        sortOptimised1(arr, left, mid);
        sortOptimised1(arr, mid + 1, right);


        if (arr[mid].compareTo(arr[mid + 1]) > 0) {
            merge(arr, left, mid, right);
        }
    }

    /**
     * 优化三:内存操作优化
     */
    public static<E extends Comparable<E>> void sortOptimised2(E[] arr){

        /**
         * 提前将arr数组保存一个副本,这样就不用每次调用merge()方法,都重新开辟空间新建一个数组了
         */
        E[] temp = Arrays.copyOf(arr, arr.length);

        sortOptimised2(arr, 0, arr.length - 1, temp);
    }

    private static<E extends Comparable<E>> void sortOptimised2(E[] arr, int left, int right, E[] temp){

        if (left >= right){
            return;
        }

        int mid = left + (right - left) / 2;

        sortOptimised2(arr, left, mid, temp);
        sortOptimised2(arr, mid + 1, right, temp);

        if (arr[mid].compareTo(arr[mid + 1]) > 0) {
            mergeOptimised(arr, left, mid, right, temp);
        }
    }

    public static<E extends Comparable<E>> void merge(E[] arr, int left, int mid, int right) {

        int i = left;
        int j = mid + 1;

        E[] tem = Arrays.copyOfRange(arr, left, right + 1);

        for (int n = left; n < right + 1; n++) {

            if (i == mid + 1){
                arr[n] = tem[j - left];
                j++;
            }
            else if (j == right + 1) {
                arr[n] = tem[i - left];
                i++;
            }
            else if (tem[i - left].compareTo(tem[j - left]) <= 0) {
                arr[n] = tem[i - left];
                i++;
            }
            else{
                arr[n] = tem[j - left];
                j++;
            }
        }
    }

    /**
     * 优化merge()方法,复用temp数组,节省空间
     */
    public static<E extends Comparable<E>> void mergeOptimised(E[] arr, int left, int mid, int right, E[] temp) {

        int i = left;
        int j = mid + 1;

        /**
         * System.arraycopy()方法将传过来的排好序的两个分数组在相同位置赋值给副本数组temp,因此索引范围一致没有偏移
         */
        System.arraycopy(arr, left, temp, left, right - left + 1);

        for (int n = left; n < right + 1; n++) {

            if (i == mid + 1){
                arr[n] = temp[j];
                j++;
            }
            else if (j == right + 1) {
                arr[n] = temp[i];
                i++;
            }
            else if (temp[i].compareTo(temp[j]) <= 0) {
                arr[n] = temp[i];
                i++;
            }
            else{
                arr[n] = temp[j];
                j++;
            }
        }
    }
}

class ArrayGenerator {

    private ArrayGenerator() {}

    /**
     * Creates an array of {@code n} pseudo-random integers, each in [0, maxBound).
     *
     * @param n        number of elements
     * @param maxBound exclusive upper bound for every value; must be positive when n > 0
     * @return a freshly allocated array of length n
     */
    public static Integer[] generatorRandomArray(Integer n, Integer maxBound) {
        Integer[] values = new Integer[n];
        Random rng = new Random();

        int index = 0;
        while (index < values.length) {
            values[index++] = rng.nextInt(maxBound);
        }
        return values;
    }

    /**
     * Creates the already-sorted array {0, 1, ..., n - 1}.
     *
     * @param n        number of elements
     * @param maxBound not used by this implementation
     * @return a freshly allocated array of length n
     */
    public static Integer[] generatorSortedArray(Integer n, Integer maxBound) {
        Integer[] values = new Integer[n];
        for (int value = 0; value < values.length; value++) {
            values[value] = value;
        }
        return values;
    }
}

class Verify {

    private Verify (){}

    /**
     * Checks whether {@code arr} is in non-decreasing order.
     *
     * @param arr array to inspect; elements must be non-null
     * @return true if every element compares less than or equal to its successor
     */
    public static<E extends Comparable<E>> boolean isSorted(E[] arr){

        for (int i = 0; i < arr.length - 1; i++) {
            if (arr[i].compareTo(arr[i + 1]) > 0) {
                return false;
            }
        }

        return true;
    }

    /**
     * Sorts {@code arr} in place with the named algorithm, verifies the result and prints
     * the elapsed wall-clock time in seconds.
     *
     * @param AlgorithmName "sortOptimised1" or "sortOptimised2"
     * @param arr           array to sort in place
     * @throws IllegalArgumentException if AlgorithmName is not one of the supported names
     * @throws RuntimeException         if the array is not sorted afterwards
     */
    public static<E extends Comparable<E>> void testTime(String AlgorithmName, E[] arr) {

        long startTime = System.nanoTime();

        switch (AlgorithmName) {
            case "sortOptimised1":
                MergeSort.sortOptimised1(arr);
                break;
            case "sortOptimised2":
                MergeSort.sortOptimised2(arr);
                break;
            default:
                // Reject unknown names: silently skipping the sort would let an already-sorted
                // input pass verification and report a meaningless time.
                throw new IllegalArgumentException("未知的算法名称:" + AlgorithmName);
        }

        long endTime = System.nanoTime();

        if (!Verify.isSorted(arr)){
            throw new RuntimeException(AlgorithmName + "算法排序失败!");
        }

        System.out.println(String.format("%s算法,测试用例为%d,执行时间:%f秒", AlgorithmName, arr.length, (endTime - startTime) / 1000000000.0));
    }
}