线段树合并,字面意思,可以将两个线段树合并到一起。如果我们的dp需要将两个数组相加或相乘,亦或是一些其他的操作,那么我们可以将这两个数组上建的线段树合并到一起去,可以加速我们dp的速度。
线段树合并的复杂度是 \(nlogn\) 的感觉挺玄学,但确实能证明。大概是因为我们对线段树进行操作后一共只会产生 \(nlogn\) 个点,而每进行一次操作都会将点数减一,所以复杂度可证。
思考一道例题,P4556
考虑这些链上的加操作,我们可以使用树上差分,只修改u,v和lca,fa[lca]来进行修改操作。然后我们的每一个点将它的儿子的值加到它自己身上。
现在得到了优秀的 \(n^2\) 算法,考虑使用线段树合并来进行儿子往父亲的合并。可以将复杂度优化到 \(nlogn\) 。
那么如何线段树合并呢?
int merge(int a,int b) { if(!a || !b) return a|b; ls[a] = merge(ls[a],ls[b]); rs[a] = merge(rs[a],rs[b]); push_up; }
大概如此,更多操作需要根据题意进行。
下面是此题的完整代码:
#include#include#include#define mid (l+r>>1) using namespace std; int read() { int a = 0,x = 1;char ch = getchar(); while(ch > '9' || ch < '0') {if(ch == '-') x = -1;ch = getchar();} while(ch >= '0' && ch <= '9') {a = a*10 + ch-'0';ch = getchar();} return a*x; } const int N=5e6+7,R=1e5; int n,m; int head[N],go[N],nxt[N],cnt,ans[N]; void add(int u,int v) { go[++cnt] = v; nxt[cnt] = head[u]; head[u] = cnt; } int fa[N][31],dep[N],rt[N],tot,ls[N],rs[N],tre[N],siz[N]; void dfs1(int u) { dep[u] = dep[fa[u][0]] + 1; for(int i = 1;i <= 30;i ++) fa[u][i] = fa[fa[u][i-1]][i-1]; for(int e = head[u];e;e = nxt[e]) { int v = go[e]; if(v == fa[u][0]) continue; fa[v][0] = u;dfs1(v); } } int LCA(int a,int b) { if(dep[a] < dep[b]) swap(a,b); for(int i = 30;i >= 0;i --) if(dep[fa[a][i]] >= dep[b]) a = fa[a][i]; if(a == b) return a; for(int i = 30;i >= 0;i --) if(fa[a][i] != fa[b][i]) a = fa[a][i],b = fa[b][i]; return fa[a][0]; } void pushup(int root) { if(siz[ls[root]] >= siz[rs[root]]) tre[root] = tre[ls[root]]; else tre[root] = tre[rs[root]]; siz[root] = max(siz[ls[root]],siz[rs[root]]); } void modify(int &root,int l,int r,int p,int x) { if(!root) root = ++tot; if(l == r && l == p) {siz[root] += x,tre[root] = l;return ;} if(p <= mid) modify(ls[root],l,mid,p,x); else modify(rs[root],mid+1,r,p,x); pushup(root); } int merge(int a,int b,int l,int r) { if(!a || !b) return a|b; if(l == r) {siz[a] += siz[b];return a;} ls[a] = merge(ls[a],ls[b],l,mid); rs[a] = merge(rs[a],rs[b],mid+1,r); pushup(a);return a; } void dfs(int u) { for(int e = head[u];e;e = nxt[e]) { int v = go[e];if(v == fa[u][0]) continue; dfs(v);rt[u] = merge(rt[u],rt[v],1,R); } ans[u] = siz[rt[u]]?tre[rt[u]]:0; } int main() { // freopen("in.in","r",stdin); // freopen("out.out","w",stdout); n = read(),m = read(); for(int i = 1;i < n;i ++) { int u = read(),v = read(); add(u,v);add(v,u); } dfs1(1); for(int i = 1;i <= m;i ++) { int x = read(),y = read(),z = read(),tmp = LCA(x,y); modify(rt[x],1,R,z,1);modify(rt[y],1,R,z,1); modify(rt[tmp],1,R,z,-1);if(fa[tmp][0]) modify(rt[fa[tmp][0]],1,R,z,-1); } dfs(1); for(int i = 1;i <= n;i ++) printf("%d\n",ans[i]); return 0; }