题目链接:https://nanti.jisuanke.com/t/42586

题意:给一棵n个节点的树,编号1-n,每个点有点权w,问有多少组节点(x,y),满足:

1.x != y

2.x点不是y点的祖先,y点不是x点的祖先

3.x点和y点的最短距离<=k (看了题解才知道,up to k 原来是小于等于k的意思???)

4.设x和y点的公共祖先是z,val【i】表示 i 节点的点权,要求val【z】 * 2 = val【x】 + val【y】

 

思路:树上启发式合并(dsu on tree) 

为了深刻记忆,所以趁着刚学会一点,分享下自己理解的树上启发式合并。(正好用这道题来讲讲)


K - Tree  2019icpc南昌K题 (树上启发式合并 dsu on tree)_递归

 

 



 

对于这个题目,先考虑暴力的做法

假设k = 3,所有点的点权为1,看下图

K - Tree  2019icpc南昌K题 (树上启发式合并 dsu on tree)_权值_02

 

 

 

假设要算1号节点为公共节点的贡献,我们观察,2号节点和6号节点显然不能满足条件,因为他们在一条链上,但是2号节点和3,4,5号节点是可以凑出贡献的,也就是2号和红色圈圈住的子树的每一个节点都有可能凑出贡献。

这样就有了一个想法,如果我们知道红色圈住的子树的信息,我就可以直接算出左边每个点贡献(2节点和6节点)

这里用若干个线段树来维护信息,每个权值都开一棵权值线段树,维护某个权值在某个深度出现的次数。

举个例子(上面有假设所有点权值为1),假如1号节点深度为1,那么3号节点深度为2,    4,5号节点深度为3,那么对于权值为1的线段树,维护的信息就是:深度为2的点有1个,深度为3的点有2个。

算2号节点的贡献时,算出另一个匹配节点的权值应该是val【1】 * 2 - val【2】  = 1 * 2 - 1 = 1, 深度最大是: k + 2 * dep【1】 - dep【2】 = 3 + 2 - 2 = 3

所以就应该查红框的子树内,权值为1的线段树,深度区间在【1,3】的点有多少个,这里查到3个(即3,4,5号节点)。(关键)

PS:提一嘴深度最大值怎么算,假设z是x和y的最近公共祖先,则x和y的距离 = dep【x】 + dep【y】 - 2 * dep【z】,题目要求 x和y的距离<=k, 所以换个位置就是:dep【y】<= k + dep【z】 * 2 - dep【x】

算6号节点的贡献时,同理,应该查红框的子树内,权值为1的线段树,深度区间在【1,2】的点有多少个,这里查到1个(即3号节点)。

那么对于3号节点为根的子树来说,统计方法也一样,只要我知道2号节点为根子树信息,可以用一样的方法来算。

有人可能会问:那4号和5号节点怎么统计?他们会在算3号节点为公共节点的时候算,因为算最大深度的时候需要用到最近公共祖先,所以两个点若有贡献,那这两个点都应该在不同的子树上。

 

重点来了,暴力的做法就是对于每个节点,我都维护这个节点为根的子树的所有信息,即n个权值的线段树,显然空间爆炸,直接MLE

那么考虑在全局开n个权值的线段树,每个节点都用全局的线段树来维护信息,但也是空间爆炸。所以考虑动态开点,每棵线段树只开遍历到的点。

空间的问题解决了,然后考虑时间,因为只有全局的线段树,他是所有节点共享的,做答案统计的时候要确保使用的时候数据是对的。

K - Tree  2019icpc南昌K题 (树上启发式合并 dsu on tree)_子树_03

先来说明一下为什么会有数据对不对的问题,递归地往下跑,假如先跑到2号节点,把2号节点信息更新到线段树里面,然后递归回去跑3号节点子树。

当统计3号节点为根的答案时,假设已经跑完了以5号节点为根的子树,现在要计算4号节点的贡献,按照上面的思路,就应该找5号节点为根子树,查某个权值某个区间有多少个点。

关键的地方来了,因为2号节点的数据已经更新到线段树里了,如果不做处理直接查,那就会出问题,查的信息都不对了。

如果暴力的解决这个问题,就是每次使用线段树的时候,都先清空,然后跑对应的子树每个节点,更新到线段树上,最后再查询。

这么做显然n方,时间不允许,重点又来了,这里就正式开始介绍树上启发式合并了,他可以把时间优化到n*logn。(确实啰嗦了点,但为了照顾像我一样的小白,就决定说得仔细一些......)

 

注意到,当统计以3号节点为根节点的答案时,除了他子树包含的点,其他点都毫无用处,即2号节点的信息此时不应该出现在线段树里。

那么我们把2号节点的信息删掉不就行?当我统计完3号节点的答案后,我再把2号节点的信息加回线段树里面,这不就完美了。

然后再看看时间复杂度,如果先统计2号节点为根的子树,再统计3号节点为根的子树,那么操作就是:

跑2号节点为根的子树,期间把子树每个节点都更新到线段树上,跑完后在线段树上删除子树的每个节点(为了消除都其他子树的影响),然后跑3号节点为根的子树,期间把子树每个节点都更新到线段树上,跑完后发现1号节点的子树都跑完了,结束递归。那么发现,3号节点的子树信息就不需要删除了,也就是说,如果一个节点i有n个子树,可以选择一个子树只跑一次(加信息),其他子树都要跑三次(加信息(统计答案)+删信息(消除对其他子树的影响)+加信息(维护i节点子树的信息,递归出去要给其他节点用))

重点又双叒叕来了!

根据上面的分析,显然要选一棵最大的子树最后跑,会使得时间最优,这就是树上启发式合并的关键思想,并且这样就可以使得时间变为nlogn。

实际上就是先跑轻儿子及其子树,再跑重儿子及其子树。(重儿子:父亲节点的所有儿子中子树结点数目最多(size最大)的结点,轻儿子就是除重儿子之外的其他节点,树链剖分的内容)

 

树上启发式合并就学完了!除了换了一下遍历儿子的顺序,省了一次消影响(重儿子的影响不用减了),好像与暴力没其它区别了!!


K - Tree  2019icpc南昌K题 (树上启发式合并 dsu on tree)_子树_04

 

 代码:



#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn = 1e5 + 7;
int sz[maxn],val[maxn],dep[maxn],son[maxn];
int T[maxn],ls[maxn*200],rs[maxn*200],tr[maxn*200],cnt,n,k;
ll ans;
vector<int>E[maxn];

void dfs1(int u) {//预处理每个节点的大小sz,深度dep和每个节点的重儿子son
sz[u] = 1;
for (auto v:E[u]) {
dep[v] = dep[u] + 1;
dfs1(v);
sz[u] += sz[v];
if(sz[v] > sz[son[u]]) son[u] = v;
}
}
void update(int &rt,int l,int r,int pos,int c) {
if(!rt) rt = ++cnt; // 注意,动态开点
tr[rt] += c;
if(l == r) return ;
int mid = l + r >> 1;
if(pos<=mid) update(ls[rt],l,mid,pos,c);
if(mid<pos) update(rs[rt],mid+1,r,pos,c);
}
ll query(int rt,int l,int r,int L,int R) {
if(!rt) return 0;
if(L<=l && r<=R) return tr[rt];
int mid = l + r >> 1;
ll ans = 0;
if(L<=mid) ans += query(ls[rt],l,mid,L,R);
if(mid<R) ans += query(rs[rt],mid+1,r,L,R);
return ans;
}
void add(int u) {
update(T[val[u]],1,n,dep[u],1);
for (auto v:E[u]) add(v);
}
void del(int u) {//删除u节点及其子树在线段树上的信息
update(T[val[u]],1,n,dep[u],-1);
for (auto v:E[u]) del(v);
}
void gao(int u,int fa) {
int d = k + 2 * dep[fa] - dep[u];//最大深度
int w = 2 * val[fa] - val[u];//另一个点的点权
d = min(d,n);
if(w >= 0 && w <= n) ans += query(T[w],1,n,1,d);
for (auto v:E[u]) gao(v,fa);//子树的每个点都要暴力统计
}
void dfs2(int u) { // 树上启发式合并
for (auto v:E[u]) { // 1.先跑轻儿子及其子树,跑完后暴力删除
if(v == son[u]) continue;
dfs2(v);//跑轻儿子v及其子树
del(v);//删除轻儿子v及其子树
}
if(son[u]) dfs2(son[u]);//2.跑重儿子,不删
for (auto v:E[u]) {//3.把所有轻儿子都加回来
if(v == son[u]) continue;
gao(v,u);//统计答案 (以u为根节点,其中一个点在v及其 子树)
add(v);
}
update(T[val[u]],1,n,dep[u],1);//把自己也加上
}

int main() {
int x;
scanf("%d%d",&n,&k);
for (int i=1; i<=n; ++i) {
scanf("%d",&val[i]); //每个点的点权val
}
for (int i=2; i<=n; ++i) {
scanf("%d",&x);
E[x].push_back(i);
}
dep[1] = 1;
dfs1(1);
dfs2(1);//关键看这里
printf("%lld",ans * 2);
return 0;
}