目录
- 1. 最大异或结点
- 1. 问题描述
- 2. 输入格式
- 3. 输出格式
- 4. 样例输入
- 5. 样例输出
- 6. 样例说明
- 7. 评测用例规模与约定
- 2. 解题思路
- 1. 解题思路
- 2. AC_Code
1. 最大异或结点
1. 问题描述
小蓝有一棵树,树中包含 个结点,编号为 ,其中每个结点上都有一个整数 。他可以从树中任意选择两个不直接相连的结点 并获得分数 ,其中
请问小蓝可以获得的最大分数是多少?
2. 输入格式
输入的第一行包含一个整数 ,表示有
第二行包含 个整数 ,相邻整数之间使用一个空格分隔。
第三行包含 个整数 ,相邻整数之间使用一个空格分隔,其中第 个整数表示 的父结点编号, 表示结点
3. 输出格式
输出一行包含一个整数表示答案。
4. 样例输入
5
1 0 5 3 4
-1 0 1 0 1
5. 样例输出
7
6. 样例说明
选择编号为 和 的结点,,他们的值异或后的结果为
7. 评测用例规模与约定
对于 的评测用例,;
对于所有评测用例,。
2. 解题思路
1. 解题思路
- 暴力做法
直接枚举所有可能选择的组合,即枚举选择的 和 。同时需要判断 和
枚举的复杂度为 ,判断相邻的复杂度为 ,整体复杂度为 ,无法通过本题。
- 满分做法
考虑优化。对于需要选择两个元素 和 的题目,常见的套路是枚举 ,并从剩余元素中选择最优元素作为 。在本题中,当 确定时,我们需要从剩余元素中找到最优元素 使得 最大,这实际上是一个 字典树的典型应用。如果你对 字典树还不太熟悉,可以通过 01字典树 学习。
问题在于,当我们枚举 时,字典树中不能包含
一个直观的想法是当我们枚举到 时,先将字典树中所有
思考:这样操作的复杂度是否可行?
显然是可行的。因为本题给定的是一棵树,对于每条边而言,假设其两端的点为 和 ,当我们枚举 时, 会产生一次删除和插入;枚举到 时, 会产生一次删除和插入。由于一棵树只有 条边,总共产生的删除和插入操作为 次。忽略常数,这部分复杂度视为 。
考虑到字典树每次插入、删除、查询的复杂度均为 ,其中 表示值域的最大值,整体复杂度为 ,可以通过本题。
2. AC_Code
- C++
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define sz(s) ((int)s.size())
class Node {
public:
array<Node *, 2> children{};
int cnt = 0;
};
class Trie {
static const int HIGH_BIT = 31;
public:
Node *root = new Node();
void insert(ll val) {
Node *cur = root;
for (int i = HIGH_BIT; i >= 0; i--) {
int bit = (val >> i) & 1;
if (cur->children[bit] == nullptr) {
cur->children[bit] = new Node();
}
cur = cur->children[bit];
cur->cnt++;
}
}
void remove(ll val) {
Node *cur = root;
for (int i = HIGH_BIT; i >= 0; i--) {
cur = cur->children[(val >> i) & 1];
cur->cnt--;
}
}
int max_xor(ll val) {
Node *cur = root;
int ans = 0;
for (int i = HIGH_BIT; i >= 0; i--) {
int bit = (val >> i) & 1;
if (cur->children[bit ^ 1] && cur->children[bit ^ 1]->cnt) {
ans |= 1 << i;
bit ^= 1;
}
cur = cur->children[bit];
}
return ans;
}
int min_xor(ll val) {
Node *cur = root;
int ans = 0;
for (int i = HIGH_BIT; i >= 0; i--) {
int bit = (val >> i) & 1;
if (cur->children[bit] && cur->children[bit]->cnt) {
cur = cur->children[bit];
} else {
ans |= 1 << i;
cur = cur->children[bit ^ 1];
}
}
return ans;
}
};
void solve() {
int n;
cin >> n;
vector<int> a(n);
Trie tr{};
for (int i = 0; i < n; ++i) {
cin >> a[i];
tr.insert(a[i]);
}
vector<vector<int>> adj(n);
for (int i = 0; i < n; ++i) {
int f;
cin >> f;
if (f != -1) {
adj[i].push_back(f);
adj[f].push_back(i);
}
}
int ans = 0;
for (int i = 0; i < n; ++i) {
for (auto v : adj[i]) {
tr.remove(a[v]);
}
ans = max(ans, tr.max_xor(a[i]));
for (auto v : adj[i]) {
tr.insert(a[v]);
}
}
cout << ans << '\n';
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
cout << setiosflags(ios::fixed) << setprecision(10);
int t = 1;
while (t--) {
solve();
}
return 0;
}
- Java
import java.util.*;
class Node {
Node[] children = new Node[2];
int cnt = 0;
}
class Trie {
private static final int HIGH_BIT = 31;
Node root = new Node();
void insert(long val) {
Node cur = root;
for (int i = HIGH_BIT; i >= 0; i--) {
int bit = (int) ((val >> i) & 1);
if (cur.children[bit] == null) {
cur.children[bit] = new Node();
}
cur = cur.children[bit];
cur.cnt++;
}
}
void remove(long val) {
Node cur = root;
for (int i = HIGH_BIT; i >= 0; i--) {
cur = cur.children[(int) ((val >> i) & 1)];
cur.cnt--;
}
}
int maxXor(long val) {
Node cur = root;
int ans = 0;
for (int i = HIGH_BIT; i >= 0; i--) {
int bit = (int) ((val >> i) & 1);
if (cur.children[bit ^ 1] != null && cur.children[bit ^ 1].cnt > 0) {
ans |= 1 << i;
bit ^= 1;
}
cur = cur.children[bit];
}
return ans;
}
int minXor(long val) {
Node cur = root;
int ans = 0;
for (int i = HIGH_BIT; i >= 0; i--) {
int bit = (int) ((val >> i) & 1);
if (cur.children[bit] != null && cur.children[bit].cnt > 0) {
cur = cur.children[bit];
} else {
ans |= 1 << i;
cur = cur.children[bit ^ 1];
}
}
return ans;
}
}
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int[] a = new int[n];
Trie tr = new Trie();
for (int i = 0; i < n; ++i) {
a[i] = sc.nextInt();
tr.insert(a[i]);
}
List<List<Integer>> adj = new ArrayList<>();
for (int i = 0; i < n; ++i) {
adj.add(new ArrayList<>());
}
for (int i = 0; i < n; ++i) {
int f = sc.nextInt();
if (f != -1) {
adj.get(i).add(f);
adj.get(f).add(i);
}
}
int ans = 0;
for (int i = 0; i < n; ++i) {
for (int v : adj.get(i)) {
tr.remove(a[v]);
}
ans = Math.max(ans, tr.maxXor(a[i]));
for (int v : adj.get(i)) {
tr.insert(a[v]);
}
}
System.out.println(ans);
}
}
- Python
class Node:
def __init__(self):
self.children = [None, None]
self.cnt = 0
class Trie:
HIGH_BIT = 31
def __init__(self):
self.root = Node()
def insert(self, val):
cur = self.root
for i in range(self.HIGH_BIT, -1, -1):
bit = (val >> i) & 1
if cur.children[bit] is None:
cur.children[bit] = Node()
cur = cur.children[bit]
cur.cnt += 1
def remove(self, val):
cur = self.root
for i in range(self.HIGH_BIT, -1, -1):
bit = (val >> i) & 1
cur = cur.children[bit]
cur.cnt -= 1
def max_xor(self, val):
cur = self.root
ans = 0
for i in range(self.HIGH_BIT, -1, -1):
bit = (val >> i) & 1
if cur.children[bit ^ 1] and cur.children[bit ^ 1].cnt > 0:
ans |= 1 << i
bit ^= 1
cur = cur.children[bit]
return ans
def min_xor(self, val):
cur = self.root
ans = 0
for i in range(self.HIGH_BIT, -1, -1):
bit = (val >> i) & 1
if cur.children[bit] and cur.children[bit].cnt > 0:
cur = cur.children[bit]
else:
ans |= 1 << i
cur = cur.children[bit ^ 1]
return ans
def solve():
import sys
input = sys.stdin.read
data = input().split()
idx = 0
n = int(data[idx])
idx += 1
a = []
tr = Trie()
for i in range(n):
a.append(int(data[idx]))
tr.insert(a[-1])
idx += 1
adj = [[] for _ in range(n)]
for i in range(n):
f = int(data[idx])
idx += 1
if f != -1:
adj[i].append(f)
adj[f].append(i)
ans = 0
for i in range(n):
for v in adj[i]:
tr.remove(a[v])
ans = max(ans, tr.max_xor(a[i]))
for v in adj[i]:
tr.insert(a[v])
print(ans)
if __name__ == "__main__":
solve()