什么是线段树

线段树(Segment Tree)也叫区间树,其本质上是一种二分搜索树,不同点在于线段树中每个节点不再是存放单纯的元素,而是存放了一个可以表示区间的值,通常是该区间合并后的值。并且每个区间会被平均分为2个子区间,作为它的左右子节点。比如说根节点存放了区间 [1,10],那么就会被分为区间 [1,5] 作为左子节点,区间 [6,10] 作为右子节点。

例如,我们可以将这样一个数组所表示的区间构造成线段树:
数据结构之线段树

并且指定区间合并规则为区间内的元素求和,那么构造出来的线段树表示如下:
数据结构之线段树

  • 从这颗线段树可以看到,由于具有二分搜索树的特性,我们可以快速地在线段树中找到一个区间。注意,这里是指按区间查找,而不是按元素值查找。所以线段树相对难理解的地方就在于每个节点既有区间的概念又有一个元素值。

为什么要使用线段树

关于线段树的一个经典问题就是:区间染色。假设有一面墙,长度为 n,每次选择一段儿墙进行染色。在 m 次操作后,我们可以在 [i, j] 区间内看见多少中颜色?

对于这个问题,我们可以使用一个数组来实现:
数据结构之线段树

对于染色操作(更新区间)我们可以遍历数组找到目标区间进行染色,时间复杂度是 $O(n)$。对于查询操作(查询区间)也是遍历数组即可,同样时间复杂度为 $O(n)$。显然用线性结构来解决这类问题的时间复杂度要更高一些,此时线段树就派上用场了,因为树形结构的时间复杂度通常在 $O(logn)$。

除此之外,线段树的另一个经典问题就是:区间查询。查询一个区间 [i, j] 的最大值和最小值,或者区间数字之和。例如,在实际业务中很常见的基于区间的统计查询:2017年注册用户中消费最高的用户?消费最少的用户?学习时间最长的用户?某个太空区间中天体总量?

对于静态区间数据(区间内的数据不会发生变化)来说,是比较好解决的,但以上所提到的问题都是动态的区间数据(区间内的数据在不断的变化),此时线段树就是一个比较好的选择。

通过以上的介绍,我们能总结出线段树的两个核心操作:

  • 区间更新:更新区间中一个元素或者一个区间的值
  • 区间查询:查询一个区间 [i, j] 的最大值、最小值,或者区间数字之和

线段树基础表示

线段树虽然不像堆那样是一棵完全二叉树,但线段树由于其特性满足平衡二叉树(左右子树高度相差不超过1),所以依然可以使用数组进行表示。我们可以将其看做是一颗满二叉树,空节点就当做叶子节点即可。如下示例:
数据结构之线段树

既然可以用数组来表示一棵线段树,那么如果区间有 n 个元素,此时应该创建多大容量的数组来构建一颗线段树呢?对于这个问题,我们先来看如何求一棵满二叉树的节点:假设这棵树有 h 层,那么这棵树就一共有 $2^h-1$ 个节点(大约是 $2^h$)。对于最后一层($h - 1$ 层)来说,就有 $2^{(h-1)}$ 个节点。因此,最后一层的节点数大致等于前面所有层节点之和。

了解了如何求满二叉树的节点数量后,回到之前的问题,如果区间有 n 个元素,此时应该开多大空间的数组?我们可以分成两种情况:

  • 如果 $n = 2^k$,那么只需要开辟 $2n$ 的数组空间
  • 如果 $n = 2^k + 1$,那么就需要开辟 $4n$ 的数组空间

通常来说,我们的线段树不考虑添加元素,即区间固定(区间内的数据可以是不固定的),那么使用 $4n$ 的静态空间即可。这也是普遍构造线段树时,使用的一个通用值。除非对内存有严格要求,否则一般开辟 $4n$ 的数组空间即可。而且对于内存有要求的情况下,一般也不会采用数组来表示,此时链式结会是更优的选择。

接下来,我们就实现一下线段树的基础结构代码:

package tree;

/**
 * 线段树 - 基于数组的表示实现
 *
 * @author 01
 * @date 2021-01-27
 **/
public class SegmentTree<E> {

    /**
     * 保存原始数组,即需要被构造成线段树的区间
     */
    private E[] data;

    /**
     * 线段树的数组表示
     */
    private E[] tree;

    public SegmentTree(E[] arr) {
        this.data = (E[]) new Object[arr.length];
        System.arraycopy(arr, 0, this.data, 0, arr.length);

        // 开辟 4n 的数组空间用于构造线段树
        this.tree = (E[]) new Object[4 * arr.length];
    }

    public int getSize() {
        return data.length;
    }

    public E get(int index) {
        if (index < 0 || index >= data.length) {
            throw new IllegalArgumentException("Index is illegal");
        }

        return data[index];
    }

    /**
     * 返回完全二叉树的数组表示中,一个索引所表示的元素的左子节点的索引
     */
    private int leftChild(int index) {
        return 2 * index + 1;
    }

    /**
     * 返回完全二叉树的数组表示中,一个索引所表示的元素的右子节点的索引
     */
    private int rightChild(int index) {
        return 2 * index + 2;
    }
}

创建线段树

在本小节中,我们来根据之前实现的基础代码,完成创建线段树逻辑的编写。需要说明一下的是,在本例中,线段树每个节点所存储的元素是区间合并后的值。具体的实现代码如下:

/**
 * 用户自定义的区间合并逻辑
 */
private final Merger<E> merger;

public SegmentTree(E[] arr, Merger<E> merger) {
    this.merger = merger;
    this.data = (E[]) new Object[arr.length];
    System.arraycopy(arr, 0, this.data, 0, arr.length);

    // 开辟 4n 的数组空间用于构建线段树
    this.tree = (E[]) new Object[4 * arr.length];
    // 构建线段树,传入根节点索引,以及区间的左右端点
    buildSegmentTree(0, 0, data.length - 1);
}

/**
 * 在treeIndex的位置创建表示区间[left...right]的线段树
 */
private void buildSegmentTree(int treeIndex, int left, int right) {
    // 区间中只有一个元素,代表递归到底了
    if (left == right) {
        tree[treeIndex] = data[left];
        return;
    }

    int leftTreeIndex = leftChild(treeIndex);
    int rightTreeIndex = rightChild(treeIndex);
    // 计算中间点,需要避免整型溢出
    int mid = left + (right - left) / 2;
    // 构建左子树
    buildSegmentTree(leftTreeIndex, left, mid);
    // 构建右子树
    buildSegmentTree(rightTreeIndex, mid + 1, right);

    // 对于两个区间的合并规则是与业务相关的,所以要调用用户自定义的逻辑来完成
    tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}

/**
 * 遍历打印树中节点中值信息。
 *
 * @return String
 */
@Override
public String toString() {
    StringBuilder res = new StringBuilder();
    res.append('[');
    for (int i = 0; i < tree.length; i++) {
        if (tree[i] != null) {
            res.append(tree[i]);
        } else {
            res.append("null");
        }

        if (i != tree.length - 1) {
            res.append(", ");
        }
    }
    res.append(']');

    return res.toString();
}
  • 在线段树中根节点存储的数据,实际就是左右两个子节点数据的合并(递归即可),而具体如何合并是由业务决定的。例如,可以是求和,也可以是求最大值或最小值。另外,这里没有通过一个对象来表示节点中的左右区间,而是通过方法参数的形式表示了这个区间,数组中只存储区间合并后的值。

用户传入的 Merger 是一个接口,其定义如下:

package tree;

/**
 * 合并器接口
 *
 * @author 01
 * @date 2021-01-27
 **/
public interface Merger<E> {

    /**
     * 用户自定义的区间合并逻辑
     *
     * @param a 区间a
     * @param b 区间b
     * @return 合并后的结果
     */
    E merge(E a, E b);
}

最后,我们来编写一个简单的测试用例进行一下测试:

package tree;

/**
 * 测试SegmentTree
 *
 * @author 01
 */
public class SegmentTreeTests {

    public static void main(String[] args) {
        Integer[] nums = {-2, 0, 3, -5, 2, -1};
        SegmentTree<Integer> segTree = new SegmentTree<>(
                nums, Integer::sum // 对两个区间中的值进行求和
        );
        System.out.println(segTree);
    }
}

输出结果如下:

[-3, 1, -4, -2, 3, -3, -1, -2, 0, null, null, -5, 2, null, null, null, null, null, null, null, null, null, null, null]
  • 可以看到,线段树的根节点是 -3 ,因为对整个数组的求和结果就是 -3 。左子节点为 1,因为 -2 + 0 + 3 = 1。右子节点为 -4,同理,因为 -5 + 2 + -1 = -4,其余以此类推。结果符合预期,证明我们实现的线段树没有问题。

线段树中的区间查询

例如,我们要对如下这棵线段树查询 [2, 5] 这个区间:
数据结构之线段树

由于我们之前传入的 Merger 实现的是求和逻辑,那么这相当于查询2 ~ 5区间所有元素的和。从根节点开始往下,我们知道分割位置,左节点查询 [2, 3],右节点查询 [4, 5],找到两个节点之后合并就可以了。

具体的实现代码如下:

/**
 * 查询区间[queryLeft, queryRight]的值,如[2, 5]
 */
public E query(int queryLeft, int queryRight) {
    if (queryLeft < 0 || queryLeft >= data.length ||
            queryRight < 0 || queryRight >= data.length ||
            queryLeft > queryRight) {
        throw new IllegalArgumentException("Index is illegal");
    }

    return query(0, 0,
            data.length - 1, queryLeft, queryRight);
}

/**
 * 在以treeIndex为根的线段树中[left...right]的范围里,搜索区间[queryLeft...queryRight]的值
 */
private E query(int treeIndex, int left, int right,
                int queryLeft, int queryRight) {
    // 找到了目标区间
    if (left == queryLeft && right == queryRight) {
        return tree[treeIndex];
    }

    int leftTreeIndex = leftChild(treeIndex);
    int rightTreeIndex = rightChild(treeIndex);
    // 计算中间点,需要避免整型溢出
    int mid = left + (right - left) / 2;

    if (queryLeft >= mid + 1) {
        // 目标区间不在左子树中,查找右子树
        return query(rightTreeIndex, mid + 1, right, queryLeft, queryRight);
    } else if (queryRight <= mid) {
        // 目标区间不在右子树中,查找左子树
        return query(leftTreeIndex, left, mid, queryLeft, queryRight);
    }

    // 目标区间一部分在右子树中,一部分在左子树中,则两个子树都需要找
    E leftResult = query(leftTreeIndex, left, mid, queryLeft, mid);
    E rightResult = query(rightTreeIndex, mid + 1, right, mid + 1, queryRight);

    // 找到目标区间的值,将其合并后返回
    return merger.merge(leftResult, rightResult);
}

进行一个简单的测试:

public static void main(String[] args) {
    Integer[] nums = {-2, 0, 3, -5, 2, -1};
    SegmentTree<Integer> segTree = new SegmentTree<>(
            nums, Integer::sum // 对两个区间中的值进行求和
    );

    System.out.println(segTree.query(0,2));
    System.out.println(segTree.query(2,5));
    System.out.println(segTree.query(0,5));
}

输出结果如下:

1
-1
-3

线段树中的更新操作

我们使用线段树来解决区间相关的问题,主要是针对区间内的数据是动态变化的情况,如果是静态区间一般不需要用到线段树。所以在本小节,我们就来实现线段树中的更新操作。

实际上线段树中的更新操作,本质上是在二分查找。因为根据线段树的特性,待更新的目标节点肯定是一个叶子节点,我们只需要找到这个叶子节点并进行更新即可。我们查找待更新节点的依据是数组的索引,而数组的索引是从 0 ~ n 有序的,所以在一个有序的区间中查找某个特定的值,妥妥的就是二分查找了。

知道了我们在更新线段树中某个节点时,要找的这个待更新节点是一个叶子节点,并且找到这个叶子节点的过程本质上是一个二分查找,那么这个思路就很清晰了。

首先,将找到叶子节点的条件作为递归的退出条件。然后计算中间点,并将线段树数组划分为 [left...mid][mid+1...right] 两个区间。接着判断要找的数组索引落在哪个区间,就继续往哪个区间递归查找。最后,将区间的值进行合并。如此一来,就完成了目标节点的更新操作。

具体的实现代码如下:

/**
 * 将index位置的值,更新为e
 */
public void set(int index, E e) {
    if (index < 0 || index >= data.length) {
        throw new IllegalArgumentException("Index is illegal");
    }

    data[index] = e;
    set(0, 0, data.length - 1, index, e);
}

/**
 * 在以treeIndex为根的线段树中更新index的值为e
 */
private void set(int treeIndex, int left, int right, int index, E e) {
    // 找到了叶子节点
    if (left == right) {
        // 进行更新
        tree[treeIndex] = e;
        return;
    }

    int mid = left + (right - left) / 2;
    // 将线段树数组划分为[left...mid]和[mid+1...right]两个区间
    int leftTreeIndex = leftChild(treeIndex);
    int rightTreeIndex = rightChild(treeIndex);
    if (index >= mid + 1) {
        // index在右子树
        set(rightTreeIndex, mid + 1, right, index, e);
    } else {
        // index在左子树
        set(leftTreeIndex, left, mid, index, e);
    }

    tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}

Leetcode上线段树相关的问题

在本文的最后,我们来使用自己实现的线段树解决一个Leetcode上的307号问题:

该问题的主要需求是更新数组下标对应的值,以及查询数组中某个区间内的元素总和。像这种对区间内数据有更新需求的,会使得区间内数据动态变化的,就很适合使用线段树来解决。具体的实现代码如下:

package tree.solution;

import tree.SegmentTree;

/**
 * Leetcode 307. Range Sum Query - Mutable
 * https://leetcode.com/problems/range-sum-query-mutable/description/
 */
class NumArray {

    private SegmentTree<Integer> segTree;

    public NumArray(int[] nums) {
        if (nums.length != 0) {
            Integer[] data = new Integer[nums.length];
            for (int i = 0; i < nums.length; i++) {
                data[i] = nums[i];
            }
            segTree = new SegmentTree<>(data, Integer::sum);
        }
    }

    public void update(int i, int val) {
        if (segTree == null) {
            throw new IllegalArgumentException("Error");
        }
        segTree.set(i, val);
    }

    public int sumRange(int i, int j) {
        if (segTree == null) {
            throw new IllegalArgumentException("Error");
        }

        return segTree.query(i, j);
    }
}