树状数组

算法训练营 树状数组 (Binary Indexed Tree(B.I.T), Fenwick Tree) 是一个查询和修改复杂度都为 log(n) 的数据结构。
「前缀和查询」与「单点更新」
直接前驱:c[i] 的直接前驱为 c[i - lowbid(i)],即 c[i] 左侧紧邻的子树的根。
直接后继:c[i] 的直接前驱为 c[i + lowbid(i)],即 c[i] 的父结点。
前驱:c[i] 左侧所有子树的根。
后继:c[i] 的所有祖先。

树状数组(Binary Indexed Tree (B.I.T))_List

★307. 区域和检索 - 数组可修改

「单点修改」和「区间查询」。[L,R] 区间和 = R 的前缀和减去 L - 1 的前缀和。

class NumArray {
    BIT bit;
    int[] nums;
    public NumArray(int[] nums) {
        int n = nums.length;
        this.nums = nums;
        bit = new BIT(n);
        for (int i = 0; i < n; i++){
            bit.update(i + 1, nums[i]);
        }
    }
    
    public void update(int i, int val) {
        bit.update(i + 1, val - nums[i]);
        nums[i] = val;        
    }
    
    public int sumRange(int left, int right) {
        return bit.query(right + 1) - bit.query(left);
    }
}
// 数状数组模板
class BIT {
    private int[] tree;
    private int n;

    public BIT(int n) {
        this.n = n;
        tree = new int[n + 1];
    }
    
    // 单点更新 
    public void update(int i, int delta) {            
        for (; i <= n; i += lowbit(i))  tree[i] += delta;
    }
    
    // 区间查询 前缀和
    public int query(int i) {
        int sum = 0;
        for (; i > 0; i -= lowbit(i)) sum += tree[i];
        return sum;
    }

    int lowbit(int x) {
        return x & (-x);
    }
}

时间复杂度:add 操作和 query 的复杂度都是 O(logn),因此构建数组的复杂度为 O(nlogn)。整体复杂度为 O(nlogn)
空间复杂度:O(n)

1649. 通过指令创建有序数组

class Solution {
    public int createSortedArray(int[] instructions) {
        int max_ = 1;
        for (int x : instructions) if (x > max_) max_ = x;
        int mod = (int) 1e9 + 7, n = instructions.length;
        long ans = 0;

        BIT bit = new BIT(max_);
        for (int i = 0; i < n; i++) {
            int x = instructions[i];
            int left = bit.query(x - 1);
            int mid = bit.query(x);
            int right = i - mid;
            ans += Math.min(left, right);
            bit.update(x, 1);
        }
        return (int) (ans % mod);
    }
}

▲1409. 查询带键的排列

class Solution {
    public int[] processQueries(int[] queries, int m) {
        int n = queries.length;
        BIT bit = new BIT(m + n); // 前面空出 n 个位置
        
        int[] pos = new int[m + 1];
        for (int i = 1; i <= m; ++i) {
            pos[i] = n + i; // 右移 n
            bit.update(n + i, 1);
        }
        
        int[] ans = new int[n];
        for (int i = 0; i < n; ++i) {
            int cur = pos[queries[i]];
            bit.update(cur, -1); // 移动
            ans[i] = bit.query(cur); // 位置 = 当前位置前面有几个数(前缀和统计 1 的个数)
            cur = n - i; // 重新定位
            pos[queries[i]] = cur;
            bit.update(cur, 1);
        }
        return ans;
    }
}

★315. 计算右侧小于当前元素的个数

「离散化」:把原序列的值域映射到一个连续的整数区间,并保证它们的偏序关系不变。

  • 逆序遍历 nums 读取排名;
  • 先查询严格小于当前排名的「前缀和」,即严格小于当前排名的元素的个数,「前缀和查询」;
  • 给「当前排名」加 1,「单点更新」。
class Solution {
    public List<Integer> countSmaller(int[] nums) {
        List<Integer> res = new ArrayList();
        int n = nums.length;
        discrete(nums);
        BIT bit = new BIT(n);
        for (int i = n - 1; i >= 0; i--) {             
            int j = nums[i];
            bit.update(j + 1, 1);    
            res.add(bit.query(j));            
        }
        Collections.reverse(res);
        return res;
    }
// 离散化 改变了原数组,偏序关系不变。
    void discrete(int[] nums) {
        int n = nums.length;
        int[] tmp = Arrays.copyOf(nums,  n);
        Arrays.sort(tmp);
        for (int i = 0; i < n; i++) {
            nums[i] = Arrays.binarySearch(tmp, nums[i]);
        }
    }
}

1395. 统计作战单位数

class Solution {    
    public int numTeams(int[] rating) {
        int n = rating.length;
        discrete(rating);
        int ans = 0;
        BIT bit = new BIT(n);
        for (int i = 0; i < n; i++) {
            int x = rating[i];            
            int frontSmall = bit.query(x); // 前面比x小的个数
            int frontLarge = i - frontSmall; // 前面比x大的个数 
            int backSmall = x - frontSmall;
            int backLarge = n - 1 - i - backSmall;
            ans += frontSmall * backLarge + frontLarge * backSmall;
            bit.update(x + 1, 1); // 对应 tree 需要 + 1
        }
        return ans;
    }
}

327. 区间和的个数

对于每一个前缀和 acc,需要知道在这个前缀和之前有多少个前缀和属于 [acc - upper, acc - lower] 。
如果 x 属于这个集合中,那么一定有 acc - x 在 [lower, upper] 中。
在求出这个集合中的数量之后,将 acc 加入树状数组中。

用树状数组来记录什么?
用树状数组来记录前缀和,但前缀和的分布太极端,需要离散化。
将 acc、acc - upper、acc - lower 全部离散化。

class Solution {
    public int countRangeSum(int[] nums, int lower, int upper) {
        int n = nums.length;
        // 前缀和数组
        long[] prefix = new long[n + 1];        
        for (int i = 0; i < n; i++) {
            prefix[i + 1] = prefix[i] + nums[i];
        }
        // 离散化 去重排序
        Set<Long> set = new TreeSet();
        for (long x : prefix) {
            set.add(x - lower);
            set.add(x - upper);
            set.add(x);
        }
        HashMap<Long, Integer> map = new HashMap();
        int rank = 1; // 编号
        for (long x : set) map.put(x, rank++);

        // 树状数组
        BIT bit = new BIT(rank);
        int ans = 0;
        for (long x : prefix) {
            int high = map.get(x - lower);
            int low = map.get(x - upper);
            rank = map.get(x);
            ans += bit.query(high) - bit.query(low - 1);
            bit.update(rank, 1);
        }
        return ans;
    }
}

2426. 满足不等式的数对数目

class Solution {
    public long numberOfPairs(int[] nums1, int[] nums2, int diff) {
        // ai <= diff + aj
        int N = 60000, n = nums1.length;
        BIT bit = new BIT(N);
        long count = 0;
        for (int i = 0; i < n; i++) {
            int x = nums1[i] - nums2[i] + 20002; // + 偏移量, x > 0
            count += bit.query(x + diff);
            bit.update(x, 1);
        }
        return count;
    }
}

2250. 统计包含每个点的矩形数目

from sortedcontainers import SortedList

# 一个维度排序,有序容器维护另一个维度
class Solution:
    def countRectangles(self, rect: List[List[int]], points: List[List[int]]) -> List[int]:
        m = len(rect)
        n = len(points)
        rect.sort()
        index = list(range(n))
        index.sort(key=lambda i:-points[i][0])

        sl = SortedList()
        res, right = [0] * n, m - 1
        for i in index:
            x, y = points[i]
            while right >= 0 and rect[right][0] >= x:
                sl.add(rect[right][1])
                right -= 1
            res[i] = len(sl) - sl.bisect_left(y)
        return res
class Solution {
    public int[] countRectangles(int[][] rec, int[][] points) {
        int n = points.length, m = rec.length, N = 100;
      
        Arrays.sort(rec, (a, b) -> a[0] - b[0]);
        Integer[] index = new Integer[n];
        Arrays.setAll(index, i -> i);
        Arrays.sort(index, (i, j) -> points[j][0] - points[i][0]);   
        
        BIT bit = new BIT(N); 
        int[] res = new int[n];
        int i = m - 1;
        for (int idx : index){
            int x = points[idx][0];
            int y = points[idx][1];
           
            while (i >= 0 && x <= rec[i][0])
                bit.update(rec[i--][1], 1);
                
            res[idx] = bit.query(N) - bit.query(y - 1);   
        }
        return res;
    }
}

493. 翻转对

327 题,对前缀和数组的每一个元素 preSum[i],找出所有位于 i 左侧的下标 j 的数量,要求 j 满足 preSum[j]∈[preSum[i]−upper, preSum[i]−lower]。

逆序遍历,对每一个元素 nums[i],查询小于它的个数,更新 2 * nums[i]。即,对数组中的每一个元素 sum[i],找出位于 i 左侧,且满足 nums[j] > 2⋅nums[i] 的下标 j。

二者都是要对数组中的每一个元素,统计「在它左侧,且取值位于某个区间」的元素数量。两个问题唯一的区别仅仅在于取值区间的不同。

由于数组中整数的范围可能很大,需要利用哈希表将所有可能出现的整数,映射到连续的整数区间内。

class Solution {
	// 二分查找插入
    public int reversePairs(int[] nums) {
        List<Long> list = new ArrayList<>(); 
        int ans = 0;
        for (int i = nums.length - 1; i >= 0; i--) {
            long x = nums[i], y = 2 * x;
            ans += binSearch(list, x); // < x 的个数
            list.add(binSearch(list, y), y);
        }
        return ans;
    }

    private int binSearch(List<Long> list, long target) {
        int l = 0, r = list.size();
        while (l < r) {
            int mid = l + r >> 1;
            if (list.get(mid) < target) l = mid + 1;
            else r = mid;
        }
        return l;
    }
}

// 数状数组
class Solution {
    public int reversePairs(int[] nums) {    
    	// 离散化 可能出现的数全部编号
        Set<Long> set = new TreeSet();
        for(int x : nums){
            set.add(x * 1L);
            set.add(x * 2L);
        }
        int idx = 1;
        Map<Long, Integer> map = new HashMap();
        for (long x : set) map.put(x, idx++);
        
        BIT bit = new BIT(idx);
        int n = nums.length, ans = 0;
        // 数组的元素不参与更新,只是用来查找。
        for(int i = n - 1; i >= 0; i--){
            int pos = map.get(nums[i] * 1L); 
            ans += bit.query(pos); // < pos 的前缀和,即个数。 
            bit.update(map.get(2L * nums[i]) + 1, 1); // +1
        }
        return ans;
    }
}

2407. 最长递增子序列 II

dp[v] 以值 v 为结尾的子序列的最大长度
dp[v] = max(max(dp[v - k], dp[v - k + 1], … dp[v - 1]) + 1, dp[v])
用树状数组 tree[N] 维护 dp[i], i = [v - k, v - 1] 的最大值, 令 val[i] = dp[i] 即可。

class Solution {
    int[] tree, dp;
    public int lengthOfLIS(int[] nums, int k) {
        // int n = IntStream.of(nums).max().getAsInt() + 1;
        int max = 1;
        for (int x : nums) if (x > max) max = x;
        int n = max + 1; // n = 100001;
        tree = new int[n];
        dp = new int[n];
        int ans = 1;
        for (int x : nums) {
            int v = query(Math.max(1, x - k), x - 1);
            dp[x] = v + 1;
            // ans = Math.max(ans, v + 1);
            update(x);
        }
        // return ans;
        return query(1, n - 1);
    }

    private void update(int x) {
        for (int i = x; i < dp.length; i += lowBit(i)) {
            tree[i] = Math.max(tree[i], dp[x]);
        }
    }

    private int query(int l, int r) {
        int ans = 0;
        while(r >= l){
            ans = Math.max(dp[r], ans);
            r--;
            for(; r - lowBit(r) >= l; r -= lowBit(r)){
                ans = Math.max(tree[r], ans);
            }
        }
        return ans;
    }

    private int lowBit(int x){
        return x & -x;
    }
}

1964. 找出到每个位置为止最长的有效障碍赛跑路线

class Solution {
    public int[] longestObstacleCourseAtEachPosition(int[] obstacles) {
        int n = obstacles.length;
        int[] ans = new int[n];
        List<Integer> list = new ArrayList();
        for (int i = 0; i < n; i++) {
            int x = obstacles[i];
            int m = list.size();
            if (m == 0 || x >= list.get(m - 1)) {
                list.add(x);
                ans[i] = m + 1;
            } else {
                int j = binSearch(list, x);
                ans[i] = j + 1;
                list.set(j, x);
            }
        }
        return ans;
    }

    private int binSearch(List<Integer> list, int target) {
        int l = 0, r = list.size();
        while (l < r) {
            int mid = l + r >> 1;
            if (list.get(mid) <= target) l = mid + 1;
            else r = mid;
        }
        return l;
    }
}
class Solution {
    int[] tree;
    int n;
    public int[] longestObstacleCourseAtEachPosition(int[] obstacles) {
        this.n = obstacles.length;
        this.tree = new int[n + 1];
        int[] temp = new int[n];
        System.arraycopy(obstacles, 0, temp, 0, n);
        Arrays.sort(obstacles);
        for (int i = 0; i < n; i++) 
            temp[i] = Arrays.binarySearch(obstacles, temp[i]);
        int[] ans = new int[n];
        for (int i = 0; i < n; i++) {
            ans[i] = query(temp[i] + 1) + 1;
            add(temp[i] + 1, ans[i]);
        }
        return ans;
    }

    public void add(int x, int y) {
        while (x <= n) {
            tree[x] = Math.max(tree[x], y);
            x += lowbit(x);
        }
    }

    public int query(int x) {
        int sum = 0;
        while (x > 0) {
            sum = Math.max(sum, tree[x]);
            x -= lowbit(x);
        }
        return sum;
    }

    public int lowbit(int x) {
        return x & (-x);
    }
}

▲1505. 最多 K 次交换相邻数位后得到的最小整数

class Solution {
    public String minInteger(String num, int k) {
        int n = num.length();    
        char[] arr = num.toCharArray(); 
        StringBuilder sb = new StringBuilder();
        Queue<Integer>[] q = new Queue[10];
        for (int i = 0; i < 10; i++) q[i] = new LinkedList<>();
        for (int i = 0; i < n; i++) q[arr[i] - '0'].add(i);
        BIT bit = new BIT(n);        
        while (k > 0) {
            boolean flag = true;
            for (int j = 0; j < 10; j++) {
                if (q[j].isEmpty()) continue;
                int idx = q[j].peek();
                // "716423" 1 -> 1 步(1 的下标),2 -> 4(下标) - 1(4 前已经移动的个数),3 -> 5 - 2   
                // 一个数不考虑前面的移到,下标 idx 表示移到开头需要的步数;需要减去小于 idx 已移动的个数。             
                int cost = idx - bit.query(idx);
                if (cost <= k) {
                    q[j].poll();
                    arr[idx] = '*';
                    bit.update(idx + 1, 1);
                    k -= cost;
                    sb.append((char) (j + '0'));
                    flag = false;
                    break;
                }
            }
            if (flag) break;
        }
        for (int i = 0; i < n; i++){
            if (arr[i] != '*') sb.append(arr[i]);
        }
        return sb.toString();
    }
}

673. 最长递增子序列的个数

218. 天际线问题

2179. 统计数组中好三元组数目

1、哈希表记录 nums2 每个数的位置
2、遍历 nums1,当前数字作为三元组中间数字。
3、第一个数字同时在左侧;第三个数字同时在右侧。

class Solution {
    public long goodTriplets(int[] nums1, int[] nums2) {
        int n = nums1.length;
        long res = 0;
        BIT bit = new BIT(n);
        // 记录 nums2 元素的位置
        Map<Integer, Integer> map = new HashMap<>();
        for (int i = 0; i < n; i++) map.put(nums2[i], i);

        // 遍历 nums1, x = nums1[i],获取 x 在 nums2 中的位置 j 
        for (int i = 0; i < n; i++) {
            int j = map.get(nums1[i]);
            // nums1 中的前 i 个元素,在 nums2 中,落在 j 前面的元素个数 left,落在 j 后面的个数为 i - left。同时在 x 右侧的个数为 n - 1 - j - (i - left)
            int left = bit.query(j);
            int right = n - 1 + left - i - j;
            res += (long) left * right;
            // 标记 j 已经遍历过
            bit.update(j + 1, 1);
        }
        return res;
    }
}

2286. 以组为单位订音乐会的门票

2424. 最长上传前缀

class LUPrefix {
    boolean[] vis;
    int x = 0;
    public LUPrefix(int n) {
        vis = new boolean[n + 2];
    }
    
    public void upload(int video) {
        vis[video] = true;
    }
    
    public int longest() {
        // 均摊复杂度 O(1)
        while (vis[x + 1]) x++;
        return x;
    }
}

class LUPrefix {
    Uf uf;
    boolean[] vis;
    public LUPrefix(int n) {
        uf = new Uf(n + 1);
        vis = new boolean[n + 2];
    }
    
    public void upload(int video) {
        vis[video] = true;
        if (vis[video - 1]) uf.union(video, video - 1);
        if (vis[video + 1]) uf.union(video + 1, video);
    }
    
    public int longest() {
        if (!vis[1]) return 0;
        return uf.getval();
    }
}

class Uf{
    int[] parent, size;
    Uf(int n){
        parent = new int[n + 1];
        Arrays.setAll(parent, i -> i);
        size = new int[n + 1];
        Arrays.fill(size, 1);
    }
    void union(int x, int y){
        x = find(x);
        y = find(y);
        if (x == y) return;
        parent[x] = y;
        size[y] += size[x];
    }
    int find(int x){
        if (x != parent[x]) parent[x] = find(parent[x]);
        return parent[x];
    }
    int getval(){
        return size[find(1)];
    }    
}

2193. 得到回文串的最少操作次数

class Solution {
    public int minMovesToMakePalindrome(String s) {        
        StringBuilder sb = new StringBuilder(s);        
        int n = s.length(), res = 0;
        while (n > 2){
            int j = sb.indexOf(sb.charAt(n - 1) + "");
            if (j == n - 1) res += n-- / 2;               
            else {
                res += j;
                sb.deleteCharAt(j);                
                n -= 2;
            }
        } 
        return res;        
    }
}

406. 根据身高重建队列

class Solution {
    public int[][] reconstructQueue(int[][] people) {
        // 身高 ↓ 个数 ↑
        Arrays.sort(people, (a, b) -> a[0] == b[0] ? a[1] - b[1] : b[0] - a[0]);
        List<int[]> res = new LinkedList();
        // 先处理高个,不影响矮个的插入点。
        for (int[] p : people){
            res.add(p[1], p);
        }         
        return res.toArray(new int[0][]); 
    }
}