高级数据结构（Ⅲ）线段树（Segment Tree）：线段树的原理、树的创建、单点修改、区间查找、完整代码及测试

## 高级数据结构（Ⅲ）线段树（Segment Tree）

### 线段树的原理

/**
 * Skeleton of an array-backed segment tree: node storage plus the three operations that the
 * following sections implement (build, point update, range query).
 */
class SegmentTree {
    // Node storage, one array per instance, sized from the input (4n slots always suffice for
    // n leaves). A static fixed-size array would be shared by every instance and cap the input.
    private int[] tree;
    int[] arr;

    SegmentTree() {
        this.tree = new int[0];
    }

    SegmentTree(int[] arr) {
        this.arr = arr;
        this.tree = new int[arr == null ? 0 : 4 * arr.length];
    }

    // Build the tree.
    public void buildTree() {}

    // Point update: modify one element and refresh the tree.
    public void updateTree() {}

    // Range query.
    public void queryTree() {}
}

### 树的创建

• 若low == high，此时令tree[node] = arr[low]，并终止递归
• 否则，将区间二分，分别递归构建左区间 [low, mid]（左孩子 lnode = 2 * node + 1）和右区间 [mid + 1, high]（右孩子 rnode = 2 * node + 2），最后令 tree[node] = tree[lnode] + tree[rnode]

//创建树    public void buildTree() {        this.buildTree(0, 0, arr.length - 1);    }    private void buildTree(int node, int low, int high) {        if(low == high) {            tree[node] = arr[low];            return;        }        int mid = low + (high - low) / 2;        int lnode = 2 * node + 1;        int rnode = 2 * node + 2;        buildTree(lnode, low, mid);        buildTree(rnode, mid + 1, high);        tree[node] = tree[lnode] + tree[rnode];    }

### 单点修改

//单点修改更新树    public void updateTree(int index, int val) {        this.updateTree(0, 0, arr.length - 1, index, val);    }    private void updateTree(int node, int low, int high, int index, int val) {        if(low == high && low == index) {            arr[index] = val;            tree[node] = val;            return;        }        int mid = low + (high - low) / 2;        int lnode = 2 * node + 1;        int rnode = 2 * node + 2;        if(index >= low && index <= mid) {            updateTree(lnode, low, mid, index, val);        }else {            updateTree(rnode, mid + 1, high, index, val);        }        tree[node] = tree[lnode] + tree[rnode];    }

### 区间查找

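• 若当前区间 [low, high] 与查找区间 [L, R] 不相交（low > R 或 high < L），说明已超出查找范围，返回 0
• 若 [low, high] 完全处于区间 [L, R] 内，直接返回当前结点的值 tree[node]
• 否则，将区间二分，分别在左区间 [low, mid] 和右区间 [mid + 1, high] 中递归查找，并返回两部分结果之和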

//区间查找    public int queryTree(int L, int R) {        return this.queryTree(0, 0, arr.length - 1, L, R);    }    private int queryTree(int node, int low,     int high, int L, int R) {        if(low > R || high < L) {            return 0;        }else if(low >= L && high <= R) {            return tree[node];        }        int mid = low + (high - low) / 2;        int lnode = 2 * node + 1;        int rnode = 2 * node + 2;        int sumleft  = queryTree(lnode, low, mid, L, R);        int sumRight = queryTree(rnode, mid + 1, high, L, R);        return  sumleft + sumRight;    }

### 完整代码及测试

/**
 * Array-backed segment tree supporting range-sum queries and point updates.
 *
 * <p>Nodes use 0-based heap indexing: node {@code 0} covers {@code arr[0..n-1]} and the children
 * of node {@code i} are {@code 2i + 1} (left half) and {@code 2i + 2} (right half).
 */
class SegmentTree {
    // Node sums, one array per instance. This used to be a static array of fixed length 1000,
    // so every SegmentTree shared (and overwrote) the same nodes and large inputs overflowed it.
    private int[] tree;
    // Source values; updateTree writes through to this array.
    int[] arr;

    SegmentTree() {
        this.tree = new int[0];
    }

    SegmentTree(int[] arr) {
        this.arr = arr;
        this.tree = new int[arr == null ? 0 : capacityFor(arr.length)];
    }

    /** Number of node slots that always suffices for a segment tree over {@code n} leaves. */
    private static int capacityFor(int n) {
        return 4 * n;
    }

    /** Builds the tree from {@code arr}. An empty array yields an empty tree. */
    public void buildTree() {
        // arr is package-visible and may have been replaced since construction, so re-check size.
        if (tree.length < capacityFor(arr.length)) {
            tree = new int[capacityFor(arr.length)];
        }
        if (arr.length == 0) {
            return;
        }
        buildTree(0, 0, arr.length - 1);
    }

    // Fills tree[node], which covers arr[low..high], and every node beneath it.
    private void buildTree(int node, int low, int high) {
        if (low == high) {
            tree[node] = arr[low];
            return;
        }
        int mid = low + (high - low) / 2; // overflow-safe midpoint
        int lnode = 2 * node + 1;
        int rnode = 2 * node + 2;
        buildTree(lnode, low, mid);
        buildTree(rnode, mid + 1, high);
        tree[node] = tree[lnode] + tree[rnode];
    }

    /**
     * Sets {@code arr[index] = val} and refreshes the sums on the path from that leaf to the root.
     *
     * @throws IndexOutOfBoundsException if {@code index} is outside {@code [0, arr.length)}
     */
    public void updateTree(int index, int val) {
        // Without this check an out-of-range index recursed forever (StackOverflowError).
        if (index < 0 || index >= arr.length) {
            throw new IndexOutOfBoundsException(
                    "index " + index + " out of range [0, " + arr.length + ")");
        }
        updateTree(0, 0, arr.length - 1, index, val);
    }

    // Invariant: low <= index <= high, so the leaf reached is exactly arr[index].
    private void updateTree(int node, int low, int high, int index, int val) {
        if (low == high) {
            arr[index] = val;
            tree[node] = val;
            return;
        }
        int mid = low + (high - low) / 2;
        int lnode = 2 * node + 1;
        int rnode = 2 * node + 2;
        if (index <= mid) {
            updateTree(lnode, low, mid, index, val);
        } else {
            updateTree(rnode, mid + 1, high, index, val);
        }
        tree[node] = tree[lnode] + tree[rnode];
    }

    /**
     * Returns the sum of {@code arr[L..R]} (inclusive). Parts of the range outside the array
     * contribute 0, so a range that misses the array entirely (or an empty array) yields 0.
     */
    public int queryTree(int L, int R) {
        if (arr.length == 0) {
            return 0;
        }
        return queryTree(0, 0, arr.length - 1, L, R);
    }

    private int queryTree(int node, int low, int high, int L, int R) {
        if (low > R || high < L) {
            return 0; // segment is disjoint from [L, R]
        }
        if (low >= L && high <= R) {
            return tree[node]; // segment lies fully inside [L, R]
        }
        int mid = low + (high - low) / 2;
        int sumLeft = queryTree(2 * node + 1, low, mid, L, R);
        int sumRight = queryTree(2 * node + 2, mid + 1, high, L, R);
        return sumLeft + sumRight;
    }

    /**
     * Prints the node array in index (level) order, space-separated, followed by a newline.
     * The number of slots printed is the size of the smallest perfect binary tree that holds
     * this segment tree (15 for 7 elements); unused slots print as 0.
     */
    public void printTree() {
        int size = Math.min(printedSlots(arr.length), tree.length);
        for (int i = 0; i < size; i++) {
            System.out.print(tree[i] + " ");
        }
        System.out.println();
    }

    // 2^(h+1) - 1 with h = ceil(log2 n), the tree height; every used node index lies below it.
    private static int printedSlots(int n) {
        if (n == 0) {
            return 0;
        }
        int height = 32 - Integer.numberOfLeadingZeros(n - 1);
        return (1 << (height + 1)) - 1;
    }
}

/** Demo: build, query and update the array-backed segment tree. */
public class SegmentTreeTest {
    public static void main(String[] args) {
        int[] arr = {6, 4, 7, 5, 8, 3, 9};
        SegmentTree st = new SegmentTree(arr);

        // Build the segment tree.
        st.buildTree();
        st.printTree();
        //>>>42 22 20 10 12 11 9 6 4 7 5 8 3 0 0

        // Query the range [3, 6].
        int sum = st.queryTree(3, 6);
        System.out.println(sum);
        //>>>25

        // Point update: set arr[4] = 1.
        st.updateTree(4, 1);
        st.printTree();
        //>>>35 22 13 10 12 4 9 6 4 7 5 1 3 0 0
    }
}

import java.util.ArrayDeque;
import java.util.Deque;

/** A node of the pointer-based segment tree: the sum of its range plus its two children. */
class SegNode {
    int val;
    SegNode lnode;
    SegNode rnode;

    SegNode() {}

    SegNode(int val) {
        this.val = val;
    }
}

/**
 * Pointer-based segment tree for range-sum queries and point updates. Leaves have null
 * children; {@code root} is null when the source array is empty.
 */
class SegTree {
    SegNode root;
    int[] arr;

    SegTree() {}

    SegTree(int[] arr) {
        this.arr = arr;
        this.bulidTree();
    }

    /** Correctly spelled alias for {@link #bulidTree()}. */
    public void buildTree() {
        bulidTree();
    }

    /** (Re)builds the tree from {@code arr}; the misspelled name is kept for existing callers. */
    public void bulidTree() {
        // An empty array has no segments; recursing would read arr[0] and throw.
        root = arr.length == 0 ? null : buildTree(0, arr.length - 1);
    }

    // Returns a new subtree covering arr[low..high].
    private SegNode buildTree(int low, int high) {
        if (low == high) {
            return new SegNode(arr[low]);
        }
        SegNode node = new SegNode();
        int mid = low + (high - low) / 2; // overflow-safe midpoint
        node.lnode = buildTree(low, mid);
        node.rnode = buildTree(mid + 1, high);
        node.val = node.lnode.val + node.rnode.val;
        return node;
    }

    /**
     * Sets {@code arr[index] = val} and refreshes the sums on the path from that leaf to the root.
     *
     * @throws IndexOutOfBoundsException if {@code index} is outside {@code [0, arr.length)}
     */
    public void updateTree(int index, int val) {
        // Without this check an out-of-range index walked past a leaf and threw NPE.
        if (index < 0 || index >= arr.length) {
            throw new IndexOutOfBoundsException(
                    "index " + index + " out of range [0, " + arr.length + ")");
        }
        updateTree(root, 0, arr.length - 1, index, val);
    }

    // Invariant: low <= index <= high, so the leaf reached is exactly arr[index].
    // Nodes are updated in place, so nothing needs to be returned.
    private void updateTree(SegNode node, int low, int high, int index, int val) {
        if (low == high) {
            arr[index] = val;
            node.val = val;
            return;
        }
        int mid = low + (high - low) / 2;
        if (index <= mid) {
            updateTree(node.lnode, low, mid, index, val);
        } else {
            updateTree(node.rnode, mid + 1, high, index, val);
        }
        node.val = node.lnode.val + node.rnode.val;
    }

    /**
     * Returns the sum of {@code arr[L..R]} (inclusive). Parts of the range outside the array
     * contribute 0, so an empty tree always yields 0.
     */
    public int queryTree(int L, int R) {
        if (root == null) {
            return 0;
        }
        return queryTree(root, 0, arr.length - 1, L, R);
    }

    private int queryTree(SegNode node, int low, int high, int L, int R) {
        if (low > R || high < L) {
            return 0; // segment is disjoint from [L, R]
        }
        if (low >= L && high <= R) {
            return node.val; // segment lies fully inside [L, R]
        }
        int mid = low + (high - low) / 2;
        int sumLeft = queryTree(node.lnode, low, mid, L, R);
        int sumRight = queryTree(node.rnode, mid + 1, high, L, R);
        return sumLeft + sumRight;
    }

    /** Prints node values in level order, each followed by a space (no trailing newline). */
    public void printTree() {
        // ArrayDeque rejects null elements, so an empty tree must be handled before offering root.
        if (root == null) {
            return;
        }
        Deque<SegNode> queue = new ArrayDeque<SegNode>();
        queue.offer(root);
        while (!queue.isEmpty()) {
            SegNode node = queue.poll();
            System.out.print(node.val + " ");
            if (node.lnode != null) {
                queue.offer(node.lnode);
            }
            if (node.rnode != null) {
                queue.offer(node.rnode);
            }
        }
    }
}

/** Demo: build, query and update the pointer-based segment tree. */
public class SegmentTreeNodeTest {
    public static void main(String[] args) {
        int[] arr = {6, 4, 7, 5, 8, 3, 9};

        // Build the segment tree.
        SegTree st = new SegTree(arr);
        st.printTree();
        System.out.println("");
        //>>>42 22 20 10 12 11 9 6 4 7 5 8 3

        // Query the range [3, 6].
        int sum = st.queryTree(3, 6);
        System.out.println(sum);
        //>>>25

        // Point update: set arr[4] = 1.
        st.updateTree(4, 1);
        st.printTree();
        System.out.println("");
        //>>>35 22 13 10 12 4 9 6 4 7 5 1 3
    }
}