小 Q 与树
给定一棵点带权的树（点 $u$ 的权值为 $a_u$），每条边的长度都为 $1$，要求
$$\sum\limits_{u = 1}^{n} \sum\limits_{v = 1}^{n} \min(a_u, a_v)\,dis(u, v)$$
其中
$$\min(a_u, a_v)\,dis(u, v) = \min(a_u, a_v)\left(dep[u] + dep[v] - 2 \times dep[\mathrm{lca}(u, v)]\right)$$
如果考虑 dsu on tree，则枚举 $u$；记 $S$ 为当前已加入数据结构、且与 $u$ 的 LCA 为当前子树根的点集，对 $v \in S$ 分两种情况统计答案：
- 若 $a_u \leq a_v$，贡献为 $\sum\limits_{v \in S} a_u\left(dep[u] + dep[v] - 2 \times dep[\mathrm{lca}(u, v)]\right)$（这里 $S$ 只取满足 $a_v \geq a_u$ 的点），因此只要知道这些点的个数以及 $\sum\limits_{v \in S} dep[v]$ 即可。
  设点的个数为 $total$，$\sum\limits_{v \in S} dep[v] = Sum_{dep}$，则上式等价于 $a_u \times total \times \left(dep[u] - 2 \times dep[\mathrm{lca}(u, v)]\right) + a_u \times Sum_{dep}$。
- 若 $a_u > a_v$，贡献为 $\sum\limits_{v \in S} a_v\left(dep[u] + dep[v] - 2 \times dep[\mathrm{lca}(u, v)]\right)$（这里 $S$ 只取满足 $a_v < a_u$ 的点），因此只要知道 $Sum_{a_v \times dep[v]}$ 以及 $Sum_{a_v}$ 即可求得答案。
  上式等价于 $Sum_{a_v \times dep[v]} + Sum_{a_v} \times \left(dep[u] - 2 \times dep[\mathrm{lca}(u, v)]\right)$。
所以可以对点权离散化，然后以离散化后的点权为下标建线段树，维护上面需要的四个值（点数、$\sum dep[v]$、$\sum a_v \times dep[v]$、$\sum a_v$），即可进行 dsu on tree，整体复杂度 $O(n \log n \log n)$。
由于上面的统计都是单向计算（每个无序点对 $\{u, v\}$ 只被统计一次，而原式对有序点对求和），所以最后还要将答案乘以 $2$。
#include <bits/stdc++.h>
// Segment-tree navigation helpers. They expand using the caller's local
// variables rt (node id), l and r (segment bounds).
#define ls rt << 1
#define rs rt << 1 | 1
#define mid (l + r >> 1)
#define lson ls, l, mid
#define rson rs, mid + 1, r
using namespace std;
// N bounds the vertex count; all answers are reported modulo `mod`.
const int N = 2e5 + 10, mod = 998244353;
// Adjacency list: head[x] is x's first edge index, nex[] chains edges and
// to[] is the edge target. Edge indices start at 1, so 0 terminates a chain.
int head[N], to[N << 1], nex[N << 1], cnt = 1;
// Filled by dfs(rt, fa):
// - son: heavy child (0 = none); sz: subtree size; dep: depth (root = 1).
// - l[x]..r[x]: the DFS-order indices of x's subtree; rk: DFS index -> vertex.
// - tot: DFS counter.
int son[N], sz[N], l[N], r[N], rk[N], dep[N], tot;
// Segment tree over compressed weights 1..m. Per node:
// sum1 = vertex count, sum2 = sum of depths,
// sum3 = sum of (original weight * depth), sum4 = sum of original weights.
int sum1[N << 2], sum2[N << 2], sum3[N << 2], sum4[N << 2];
// a[i]: weight of vertex i (compressed to 1..m after main's preprocessing).
// b[1..m]: sorted distinct original weights. n: vertex count.
int a[N], b[N], n, m;
// Modular addition; both operands are expected to be in [0, mod).
inline int add(int x, int y) {
    int total = x + y;
    if (total >= mod) {
        total -= mod;
    }
    return total;
}
// Modular subtraction x - y. The result is in [0, mod) when x - y lies in (-mod, mod).
inline int sub(int x, int y) {
    int diff = x - y;
    if (diff < 0) {
        diff += mod;
    }
    return diff;
}
// Modular multiplication, widened to 64 bits so the product cannot overflow.
inline int mul(int x, int y) {
    const long long product = static_cast<long long>(x) * y;
    return static_cast<int>(product % mod);
}
// Appends the directed edge x -> y to the front of x's adjacency chain.
void Add(int x, int y) {
    const int e = cnt;
    to[e] = y;
    nex[e] = head[x];
    head[x] = e;
    ++cnt;
}
// First pass over the tree rooted at rt (parent fa).
// Records depth, subtree size, heavy child and the DFS-order interval
// [l[rt], r[rt]] of every vertex; rk maps a DFS index back to its vertex.
void dfs(int rt, int fa) {
    dep[rt] = dep[fa] + 1;
    sz[rt] = 1;
    l[rt] = ++tot;
    rk[tot] = rt;
    for (int e = head[rt]; e != 0; e = nex[e]) {
        const int child = to[e];
        if (child != fa) {
            dfs(child, rt);
            sz[rt] += sz[child];
            // The first child seen, or any strictly larger one, becomes the heavy child.
            const bool heavier = son[rt] == 0 || sz[child] > sz[son[rt]];
            if (heavier) {
                son[rt] = child;
            }
        }
    }
    r[rt] = tot;
}
// Recomputes all four aggregates of node rt from its two children.
void push_up(int rt) {
    const int left_child = rt << 1;
    const int right_child = rt << 1 | 1;
    int* const trees[] = {sum1, sum2, sum3, sum4};
    for (int* t : trees) {
        t[rt] = add(t[left_child], t[right_child]);
    }
}
void update(int rt, int l, int r, int x, int v, int op) {
if (l == r) {
if (op == 1) {
sum1[rt] += 1, sum2[rt] = add(sum2[rt], v), sum3[rt] = add(sum3[rt], mul(b[x], v)), sum4[rt] = add(sum4[rt], b[x]);
}
else {
sum1[rt] -= 1, sum2[rt] = sub(sum2[rt], v), sum3[rt] = sub(sum3[rt], mul(b[x], v)), sum4[rt] = sub(sum4[rt], b[x]);
}
return ;
}
if (x <= mid) {
update(lson, x, v, op);
}
else {
update(rson, x, v, op);
}
push_up(rt);
}
// ans: running answer (each unordered pair counted once; main doubles it).
// ans1..ans4: scratch outputs of query(). They mirror sum1..sum4
// (count, sum of depths, sum of weight * depth, sum of weights).
// NOTE(review): ans5 is never used in this file.
int ans, ans1, ans2, ans3, ans4, ans5;
// Range query over compressed weights [L, R]. Accumulates (does not reset)
// the four aggregates of every fully covered node into ans1..ans4.
void query(int rt, int l, int r, int L, int R) {
    const bool covered = L <= l && r <= R;
    if (covered) {
        ans1 = add(ans1, sum1[rt]);
        ans2 = add(ans2, sum2[rt]);
        ans3 = add(ans3, sum3[rt]);
        ans4 = add(ans4, sum4[rt]);
        return;
    }
    const int half = (l + r) >> 1;
    if (L <= half) {
        query(rt << 1, l, half, L, R);
    }
    if (half < R) {
        query(rt << 1 | 1, half + 1, r, L, R);
    }
}
// Adds to the global `ans`, for every vertex v currently stored in the segment
// tree, the term
//   min(a_u, a_v) * (dep[u] + dep[v] - 2 * dep[top]).
// The caller must guarantee lca(u, v) == top for every stored v.
// Overwrites the query scratch values ans1..ans4.
void contribute(int u, int top) {
    const int ku = a[u];                        // compressed weight of u
    const int wu = b[ku];                       // original weight of u
    const int off = sub(dep[u], 2 * dep[top]);  // dep[u] - 2 * dep[lca], mod p
    // Stored v with a_v >= a_u: min is a_u, giving
    //   a_u * sum(dep[v]) + a_u * count * (dep[u] - 2 * dep[top]).
    ans1 = ans2 = ans3 = ans4 = 0;
    query(1, 1, m, ku, m);
    ans = add(ans, mul(ans2, wu));
    ans = add(ans, mul(wu, mul(ans1, off)));
    // Stored v with a_v < a_u: min is a_v, giving
    //   sum(a_v * dep[v]) + sum(a_v) * (dep[u] - 2 * dep[top]).
    if (ku != 1) {
        ans1 = ans2 = ans3 = ans4 = 0;
        query(1, 1, m, 1, ku - 1);
        ans = add(ans, ans3);
        ans = add(ans, mul(ans4, off));
    }
}

// DSU-on-tree (small-to-large) pass over rt's subtree.
// Precondition: the segment tree is empty on entry. This holds for main's call
// and for every recursive call below.
// Effect: every unordered pair whose LCA lies in rt's subtree is added to ans.
// keep == true leaves rt's whole subtree in the segment tree; otherwise it is
// removed before returning.
void dfs(int rt, int fa, bool keep) {
    // 1. Light children first; each one cleans up after itself.
    for (int i = head[rt]; i; i = nex[i]) {
        if (to[i] == fa || to[i] == son[rt]) {
            continue;
        }
        dfs(to[i], rt, 0);
    }
    // 2. Heavy child last, keeping its vertices in the tree.
    if (son[rt]) {
        dfs(son[rt], rt, 1);
    }
    // 3. Merge each light subtree. Its vertices have LCA rt with everything
    //    stored so far, so count those pairs first, then insert the subtree.
    for (int i = head[rt]; i; i = nex[i]) {
        if (to[i] == fa || to[i] == son[rt]) {
            continue;
        }
        const int c = to[i];
        for (int j = l[c]; j <= r[c]; j++) {
            contribute(rk[j], rt);
        }
        for (int j = l[c]; j <= r[c]; j++) {
            update(1, 1, m, a[rk[j]], dep[rk[j]], 1);
        }
    }
    // 4. Pairs (rt, v) for every other vertex v of the subtree: the LCA is rt itself.
    contribute(rt, rt);
    update(1, 1, m, a[rt], dep[rt], 1);
    // 5. A light child removes its whole subtree so its parent sees an empty tree.
    if (!keep) {
        for (int i = l[rt]; i <= r[rt]; i++) {
            update(1, 1, m, a[rk[i]], dep[rk[i]], -1);
        }
    }
}
// Reads n, the n vertex weights and n - 1 undirected edges. Prints
// sum over u, v of min(a_u, a_v) * dis(u, v), modulo 998244353.
int main() {
    scanf("%d", &n);
    for (int v = 1; v <= n; ++v) {
        scanf("%d", &a[v]);
        b[v] = a[v];
    }
    // Coordinate compression: b[1..m] becomes the sorted distinct weights and
    // a[v] becomes the 1-based rank of vertex v's weight among them.
    std::sort(b + 1, b + n + 1);
    m = static_cast<int>(std::unique(b + 1, b + n + 1) - b - 1);
    for (int v = 1; v <= n; ++v) {
        a[v] = static_cast<int>(std::lower_bound(b + 1, b + m + 1, a[v]) - b);
    }
    for (int e = 1; e < n; ++e) {
        int x, y;
        scanf("%d %d", &x, &y);
        Add(x, y);
        Add(y, x);
    }
    dfs(1, 0);        // sizes, depths, heavy children, DFS order
    dfs(1, 0, true);  // DSU on tree: accumulate each unordered pair once
    // Each unordered pair was counted once; the double sum counts both orders.
    printf("%d\n", mul(2, ans));
    return 0;
}