树状数组
算法训练营 树状数组 (Binary Indexed Tree(B.I.T), Fenwick Tree) 是一个查询和修改复杂度都为 log(n) 的数据结构。
「前缀和查询」与「单点更新」
直接前驱:c[i] 的直接前驱为 c[i - lowbid(i)],即 c[i] 左侧紧邻的子树的根。
直接后继:c[i] 的直接前驱为 c[i + lowbid(i)],即 c[i] 的父结点。
前驱:c[i] 左侧所有子树的根。
后继:c[i] 的所有祖先。
★307. 区域和检索 - 数组可修改
「单点修改」和「区间查询」。[L,R] 区间和 = R 的前缀和减去 L - 1 的前缀和。
class NumArray {
BIT bit;
int[] nums;
public NumArray(int[] nums) {
int n = nums.length;
this.nums = nums;
bit = new BIT(n);
for (int i = 0; i < n; i++){
bit.update(i + 1, nums[i]);
}
}
public void update(int i, int val) {
bit.update(i + 1, val - nums[i]);
nums[i] = val;
}
public int sumRange(int left, int right) {
return bit.query(right + 1) - bit.query(left);
}
}
// 数状数组模板
class BIT {
private int[] tree;
private int n;
public BIT(int n) {
this.n = n;
tree = new int[n + 1];
}
// 单点更新
public void update(int i, int delta) {
for (; i <= n; i += lowbit(i)) tree[i] += delta;
}
// 区间查询 前缀和
public int query(int i) {
int sum = 0;
for (; i > 0; i -= lowbit(i)) sum += tree[i];
return sum;
}
int lowbit(int x) {
return x & (-x);
}
}
时间复杂度:add 操作和 query 的复杂度都是 O(logn),因此构建数组的复杂度为 O(nlogn)。整体复杂度为 O(nlogn)
空间复杂度:O(n)
1649. 通过指令创建有序数组
class Solution {
public int createSortedArray(int[] instructions) {
int max_ = 1;
for (int x : instructions) if (x > max_) max_ = x;
int mod = (int) 1e9 + 7, n = instructions.length;
long ans = 0;
BIT bit = new BIT(max_);
for (int i = 0; i < n; i++) {
int x = instructions[i];
int left = bit.query(x - 1);
int mid = bit.query(x);
int right = i - mid;
ans += Math.min(left, right);
bit.update(x, 1);
}
return (int) (ans % mod);
}
}
▲1409. 查询带键的排列
class Solution {
public int[] processQueries(int[] queries, int m) {
int n = queries.length;
BIT bit = new BIT(m + n); // 前面空出 n 个位置
int[] pos = new int[m + 1];
for (int i = 1; i <= m; ++i) {
pos[i] = n + i; // 右移 n
bit.update(n + i, 1);
}
int[] ans = new int[n];
for (int i = 0; i < n; ++i) {
int cur = pos[queries[i]];
bit.update(cur, -1); // 移动
ans[i] = bit.query(cur); // 位置 = 当前位置前面有几个数(前缀和统计 1 的个数)
cur = n - i; // 重新定位
pos[queries[i]] = cur;
bit.update(cur, 1);
}
return ans;
}
}
★315. 计算右侧小于当前元素的个数
「离散化」:把原序列的值域映射到一个连续的整数区间,并保证它们的偏序关系不变。
- 逆序遍历 nums 读取排名;
- 先查询严格小于当前排名的「前缀和」,即严格小于当前排名的元素的个数,「前缀和查询」;
- 给「当前排名」加 1,「单点更新」。
class Solution {
public List<Integer> countSmaller(int[] nums) {
List<Integer> res = new ArrayList();
int n = nums.length;
discrete(nums);
BIT bit = new BIT(n);
for (int i = n - 1; i >= 0; i--) {
int j = nums[i];
bit.update(j + 1, 1);
res.add(bit.query(j));
}
Collections.reverse(res);
return res;
}
// 离散化 改变了原数组,偏序关系不变。
void discrete(int[] nums) {
int n = nums.length;
int[] tmp = Arrays.copyOf(nums, n);
Arrays.sort(tmp);
for (int i = 0; i < n; i++) {
nums[i] = Arrays.binarySearch(tmp, nums[i]);
}
}
}
1395. 统计作战单位数
class Solution {
public int numTeams(int[] rating) {
int n = rating.length;
discrete(rating);
int ans = 0;
BIT bit = new BIT(n);
for (int i = 0; i < n; i++) {
int x = rating[i];
int frontSmall = bit.query(x); // 前面比x小的个数
int frontLarge = i - frontSmall; // 前面比x大的个数
int backSmall = x - frontSmall;
int backLarge = n - 1 - i - backSmall;
ans += frontSmall * backLarge + frontLarge * backSmall;
bit.update(x + 1, 1); // 对应 tree 需要 + 1
}
return ans;
}
}
327. 区间和的个数
对于每一个前缀和 acc,需要知道在这个前缀和之前有多少个前缀和属于 [acc - upper, acc - lower] 。
如果 x 属于这个集合中,那么一定有 acc - x 在 [lower, upper] 中。
在求出这个集合中的数量之后,将 acc 加入树状数组中。
用树状数组来记录什么?
用树状数组来记录前缀和,但前缀和的分布太极端,需要离散化。
将 acc、acc - upper、acc - lower 全部离散化。
class Solution {
public int countRangeSum(int[] nums, int lower, int upper) {
int n = nums.length;
// 前缀和数组
long[] prefix = new long[n + 1];
for (int i = 0; i < n; i++) {
prefix[i + 1] = prefix[i] + nums[i];
}
// 离散化 去重排序
Set<Long> set = new TreeSet();
for (long x : prefix) {
set.add(x - lower);
set.add(x - upper);
set.add(x);
}
HashMap<Long, Integer> map = new HashMap();
int rank = 1; // 编号
for (long x : set) map.put(x, rank++);
// 树状数组
BIT bit = new BIT(rank);
int ans = 0;
for (long x : prefix) {
int high = map.get(x - lower);
int low = map.get(x - upper);
rank = map.get(x);
ans += bit.query(high) - bit.query(low - 1);
bit.update(rank, 1);
}
return ans;
}
}
2426. 满足不等式的数对数目
class Solution {
public long numberOfPairs(int[] nums1, int[] nums2, int diff) {
// ai <= diff + aj
int N = 60000, n = nums1.length;
BIT bit = new BIT(N);
long count = 0;
for (int i = 0; i < n; i++) {
int x = nums1[i] - nums2[i] + 20002; // + 偏移量, x > 0
count += bit.query(x + diff);
bit.update(x, 1);
}
return count;
}
}
2250. 统计包含每个点的矩形数目
from sortedcontainers import SortedList
# 一个维度排序,有序容器维护另一个维度
class Solution:
def countRectangles(self, rect: List[List[int]], points: List[List[int]]) -> List[int]:
m = len(rect)
n = len(points)
rect.sort()
index = list(range(n))
index.sort(key=lambda i:-points[i][0])
sl = SortedList()
res, right = [0] * n, m - 1
for i in index:
x, y = points[i]
while right >= 0 and rect[right][0] >= x:
sl.add(rect[right][1])
right -= 1
res[i] = len(sl) - sl.bisect_left(y)
return res
class Solution {
public int[] countRectangles(int[][] rec, int[][] points) {
int n = points.length, m = rec.length, N = 100;
Arrays.sort(rec, (a, b) -> a[0] - b[0]);
Integer[] index = new Integer[n];
Arrays.setAll(index, i -> i);
Arrays.sort(index, (i, j) -> points[j][0] - points[i][0]);
BIT bit = new BIT(N);
int[] res = new int[n];
int i = m - 1;
for (int idx : index){
int x = points[idx][0];
int y = points[idx][1];
while (i >= 0 && x <= rec[i][0])
bit.update(rec[i--][1], 1);
res[idx] = bit.query(N) - bit.query(y - 1);
}
return res;
}
}
493. 翻转对
327 题,对前缀和数组的每一个元素 preSum[i],找出所有位于 i 左侧的下标 j 的数量,要求 j 满足 preSum[j]∈[preSum[i]−upper, preSum[i]−lower]。
逆序遍历,对每一个元素 nums[i],查询小于它的个数,更新 2 * nums[i]。即,对数组中的每一个元素 sum[i],找出位于 i 左侧,且满足 nums[j] > 2⋅nums[i] 的下标 j。
二者都是要对数组中的每一个元素,统计「在它左侧,且取值位于某个区间」的元素数量。两个问题唯一的区别仅仅在于取值区间的不同。
由于数组中整数的范围可能很大,需要利用哈希表将所有可能出现的整数,映射到连续的整数区间内。
class Solution {
// 二分查找插入
public int reversePairs(int[] nums) {
List<Long> list = new ArrayList<>();
int ans = 0;
for (int i = nums.length - 1; i >= 0; i--) {
long x = nums[i], y = 2 * x;
ans += binSearch(list, x); // < x 的个数
list.add(binSearch(list, y), y);
}
return ans;
}
private int binSearch(List<Long> list, long target) {
int l = 0, r = list.size();
while (l < r) {
int mid = l + r >> 1;
if (list.get(mid) < target) l = mid + 1;
else r = mid;
}
return l;
}
}
// 数状数组
class Solution {
public int reversePairs(int[] nums) {
// 离散化 可能出现的数全部编号
Set<Long> set = new TreeSet();
for(int x : nums){
set.add(x * 1L);
set.add(x * 2L);
}
int idx = 1;
Map<Long, Integer> map = new HashMap();
for (long x : set) map.put(x, idx++);
BIT bit = new BIT(idx);
int n = nums.length, ans = 0;
// 数组的元素不参与更新,只是用来查找。
for(int i = n - 1; i >= 0; i--){
int pos = map.get(nums[i] * 1L);
ans += bit.query(pos); // < pos 的前缀和,即个数。
bit.update(map.get(2L * nums[i]) + 1, 1); // +1
}
return ans;
}
}
2407. 最长递增子序列 II
dp[v] 以值 v 为结尾的子序列的最大长度
dp[v] = max(max(dp[v - k], dp[v - k + 1], … dp[v - 1]) + 1, dp[v])
用树状数组 tree[N] 维护 dp[i], i = [v - k, v - 1] 的最大值, 令 val[i] = dp[i] 即可。
class Solution {
int[] tree, dp;
public int lengthOfLIS(int[] nums, int k) {
// int n = IntStream.of(nums).max().getAsInt() + 1;
int max = 1;
for (int x : nums) if (x > max) max = x;
int n = max + 1; // n = 100001;
tree = new int[n];
dp = new int[n];
int ans = 1;
for (int x : nums) {
int v = query(Math.max(1, x - k), x - 1);
dp[x] = v + 1;
// ans = Math.max(ans, v + 1);
update(x);
}
// return ans;
return query(1, n - 1);
}
private void update(int x) {
for (int i = x; i < dp.length; i += lowBit(i)) {
tree[i] = Math.max(tree[i], dp[x]);
}
}
private int query(int l, int r) {
int ans = 0;
while(r >= l){
ans = Math.max(dp[r], ans);
r--;
for(; r - lowBit(r) >= l; r -= lowBit(r)){
ans = Math.max(tree[r], ans);
}
}
return ans;
}
private int lowBit(int x){
return x & -x;
}
}
1964. 找出到每个位置为止最长的有效障碍赛跑路线
class Solution {
public int[] longestObstacleCourseAtEachPosition(int[] obstacles) {
int n = obstacles.length;
int[] ans = new int[n];
List<Integer> list = new ArrayList();
for (int i = 0; i < n; i++) {
int x = obstacles[i];
int m = list.size();
if (m == 0 || x >= list.get(m - 1)) {
list.add(x);
ans[i] = m + 1;
} else {
int j = binSearch(list, x);
ans[i] = j + 1;
list.set(j, x);
}
}
return ans;
}
private int binSearch(List<Integer> list, int target) {
int l = 0, r = list.size();
while (l < r) {
int mid = l + r >> 1;
if (list.get(mid) <= target) l = mid + 1;
else r = mid;
}
return l;
}
}
class Solution {
int[] tree;
int n;
public int[] longestObstacleCourseAtEachPosition(int[] obstacles) {
this.n = obstacles.length;
this.tree = new int[n + 1];
int[] temp = new int[n];
System.arraycopy(obstacles, 0, temp, 0, n);
Arrays.sort(obstacles);
for (int i = 0; i < n; i++)
temp[i] = Arrays.binarySearch(obstacles, temp[i]);
int[] ans = new int[n];
for (int i = 0; i < n; i++) {
ans[i] = query(temp[i] + 1) + 1;
add(temp[i] + 1, ans[i]);
}
return ans;
}
public void add(int x, int y) {
while (x <= n) {
tree[x] = Math.max(tree[x], y);
x += lowbit(x);
}
}
public int query(int x) {
int sum = 0;
while (x > 0) {
sum = Math.max(sum, tree[x]);
x -= lowbit(x);
}
return sum;
}
public int lowbit(int x) {
return x & (-x);
}
}
▲1505. 最多 K 次交换相邻数位后得到的最小整数
class Solution {
public String minInteger(String num, int k) {
int n = num.length();
char[] arr = num.toCharArray();
StringBuilder sb = new StringBuilder();
Queue<Integer>[] q = new Queue[10];
for (int i = 0; i < 10; i++) q[i] = new LinkedList<>();
for (int i = 0; i < n; i++) q[arr[i] - '0'].add(i);
BIT bit = new BIT(n);
while (k > 0) {
boolean flag = true;
for (int j = 0; j < 10; j++) {
if (q[j].isEmpty()) continue;
int idx = q[j].peek();
// "716423" 1 -> 1 步(1 的下标),2 -> 4(下标) - 1(4 前已经移动的个数),3 -> 5 - 2
// 一个数不考虑前面的移到,下标 idx 表示移到开头需要的步数;需要减去小于 idx 已移动的个数。
int cost = idx - bit.query(idx);
if (cost <= k) {
q[j].poll();
arr[idx] = '*';
bit.update(idx + 1, 1);
k -= cost;
sb.append((char) (j + '0'));
flag = false;
break;
}
}
if (flag) break;
}
for (int i = 0; i < n; i++){
if (arr[i] != '*') sb.append(arr[i]);
}
return sb.toString();
}
}
673. 最长递增子序列的个数
218. 天际线问题
2179. 统计数组中好三元组数目
1、哈希表记录 nums2 每个数的位置
2、遍历 nums1,当前数字作为三元组中间数字。
3、第一个数字同时在左侧;第三个数字同时在右侧。
class Solution {
public long goodTriplets(int[] nums1, int[] nums2) {
int n = nums1.length;
long res = 0;
BIT bit = new BIT(n);
// 记录 nums2 元素的位置
Map<Integer, Integer> map = new HashMap<>();
for (int i = 0; i < n; i++) map.put(nums2[i], i);
// 遍历 nums1, x = nums1[i],获取 x 在 nums2 中的位置 j
for (int i = 0; i < n; i++) {
int j = map.get(nums1[i]);
// nums1 中的前 i 个元素,在 nums2 中,落在 j 前面的元素个数 left,落在 j 后面的个数为 i - left。同时在 x 右侧的个数为 n - 1 - j - (i - left)
int left = bit.query(j);
int right = n - 1 + left - i - j;
res += (long) left * right;
// 标记 j 已经遍历过
bit.update(j + 1, 1);
}
return res;
}
}
2286. 以组为单位订音乐会的门票
2424. 最长上传前缀
class LUPrefix {
boolean[] vis;
int x = 0;
public LUPrefix(int n) {
vis = new boolean[n + 2];
}
public void upload(int video) {
vis[video] = true;
}
public int longest() {
// 均摊复杂度 O(1)
while (vis[x + 1]) x++;
return x;
}
}
class LUPrefix {
Uf uf;
boolean[] vis;
public LUPrefix(int n) {
uf = new Uf(n + 1);
vis = new boolean[n + 2];
}
public void upload(int video) {
vis[video] = true;
if (vis[video - 1]) uf.union(video, video - 1);
if (vis[video + 1]) uf.union(video + 1, video);
}
public int longest() {
if (!vis[1]) return 0;
return uf.getval();
}
}
class Uf{
int[] parent, size;
Uf(int n){
parent = new int[n + 1];
Arrays.setAll(parent, i -> i);
size = new int[n + 1];
Arrays.fill(size, 1);
}
void union(int x, int y){
x = find(x);
y = find(y);
if (x == y) return;
parent[x] = y;
size[y] += size[x];
}
int find(int x){
if (x != parent[x]) parent[x] = find(parent[x]);
return parent[x];
}
int getval(){
return size[find(1)];
}
}
2193. 得到回文串的最少操作次数
class Solution {
public int minMovesToMakePalindrome(String s) {
StringBuilder sb = new StringBuilder(s);
int n = s.length(), res = 0;
while (n > 2){
int j = sb.indexOf(sb.charAt(n - 1) + "");
if (j == n - 1) res += n-- / 2;
else {
res += j;
sb.deleteCharAt(j);
n -= 2;
}
}
return res;
}
}
406. 根据身高重建队列
class Solution {
public int[][] reconstructQueue(int[][] people) {
// 身高 ↓ 个数 ↑
Arrays.sort(people, (a, b) -> a[0] == b[0] ? a[1] - b[1] : b[0] - a[0]);
List<int[]> res = new LinkedList();
// 先处理高个,不影响矮个的插入点。
for (int[] p : people){
res.add(p[1], p);
}
return res.toArray(new int[0][]);
}
}