线段树合并,字面意思,可以将两个线段树合并到一起。如果我们的dp需要将两个数组相加或相乘,亦或是一些其他的操作,那么我们可以将这两个数组上建的线段树合并到一起去,可以加速我们dp的速度。

线段树合并的复杂度是 \(nlogn\) 的感觉挺玄学,但确实能证明。大概是因为我们对线段树进行操作后一共只会产生 \(nlogn\) 个点,而每进行一次操作都会将点数减一,所以复杂度可证。

思考一道例题,P4556

考虑这些链上的加操作,我们可以使用树上差分,只修改u,v和lca,fa[lca]来进行修改操作。然后我们的每一个点将它的儿子的值加到它自己身上。

现在得到了优秀的 \(n^2\) 算法,考虑使用线段树合并来进行儿子往父亲的合并。可以将复杂度优化到 \(nlogn\) 。

那么如何线段树合并呢?

int merge(int a,int b)
{
  if(!a || !b) return a|b;
  ls[a] = merge(ls[a],ls[b]);
  rs[a] = merge(rs[a],rs[b]);
  push_up;
}

大概如此,更多操作需要根据题意进行。

下面是此题的完整代码:

#include#include#include#define mid (l+r>>1)

using namespace std;

int read()
{
	int a = 0,x = 1;char ch = getchar();
	while(ch > '9' || ch < '0') {if(ch == '-') x = -1;ch = getchar();}
	while(ch >= '0' && ch <= '9') {a = a*10 + ch-'0';ch = getchar();}
	return a*x;
}
const int N=5e6+7,R=1e5;
int n,m;
int head[N],go[N],nxt[N],cnt,ans[N];
void add(int u,int v)
{
	go[++cnt] = v;
	nxt[cnt] = head[u];
	head[u] = cnt;
}
int fa[N][31],dep[N],rt[N],tot,ls[N],rs[N],tre[N],siz[N];
void dfs1(int u)
{
	dep[u] = dep[fa[u][0]] + 1;
	for(int i = 1;i <= 30;i ++) fa[u][i] = fa[fa[u][i-1]][i-1];
	for(int e = head[u];e;e = nxt[e]) {
		int v = go[e];
		if(v == fa[u][0]) continue;
		fa[v][0] = u;dfs1(v);
	}
}
int LCA(int a,int b)
{
	if(dep[a] < dep[b]) swap(a,b);
	for(int i = 30;i >= 0;i --) if(dep[fa[a][i]] >= dep[b]) a = fa[a][i];
	if(a == b) return a;
	for(int i = 30;i >= 0;i --) if(fa[a][i] != fa[b][i]) a = fa[a][i],b = fa[b][i];
	return fa[a][0];
}

void pushup(int root) {
	if(siz[ls[root]] >= siz[rs[root]]) tre[root] = tre[ls[root]];
	else tre[root] = tre[rs[root]];
	siz[root] = max(siz[ls[root]],siz[rs[root]]);
}

void modify(int &root,int l,int r,int p,int x)
{
	if(!root) root = ++tot;
	if(l == r && l == p) {siz[root] += x,tre[root] = l;return ;}
	if(p <= mid) modify(ls[root],l,mid,p,x);
	else modify(rs[root],mid+1,r,p,x);
	pushup(root);
}

int merge(int a,int b,int l,int r)
{
	if(!a || !b) return a|b;
	if(l == r) {siz[a] += siz[b];return a;}
	ls[a] = merge(ls[a],ls[b],l,mid);
	rs[a] = merge(rs[a],rs[b],mid+1,r);
	pushup(a);return a;
}

void dfs(int u)
{
	for(int e = head[u];e;e = nxt[e]) {
		int v = go[e];if(v == fa[u][0]) continue;
		dfs(v);rt[u] = merge(rt[u],rt[v],1,R);
	}
	ans[u] = siz[rt[u]]?tre[rt[u]]:0;
}

int main()
{
	// freopen("in.in","r",stdin);
	// freopen("out.out","w",stdout);
	n = read(),m = read();
	for(int i = 1;i < n;i ++) {
		int u = read(),v = read();
		add(u,v);add(v,u);
	}
	dfs1(1);
	for(int i = 1;i <= m;i ++) {
		int x = read(),y = read(),z = read(),tmp = LCA(x,y);
		modify(rt[x],1,R,z,1);modify(rt[y],1,R,z,1);
		modify(rt[tmp],1,R,z,-1);if(fa[tmp][0]) modify(rt[fa[tmp][0]],1,R,z,-1);
	}
	dfs(1);
	for(int i = 1;i <= n;i ++) printf("%d\n",ans[i]);
	return 0;
}