问题:
有一些问题,通常见于二维的DP,另一维记录当前x的信息,但是这一维过大无法开下,O(nm)也无法通过。
但是如果发现,对于x,在第二维的一些区间内,取值都是相同的,并且这样的区间是有限个,就可以批量处理。
思想:
通过动态开点线段树维护第二维,
如果某个节点没有儿子,那么这个节点区间都是同一个权值。
也即,一个节点是空节点,那么这个节点所有的值和父亲的值都一致。(其实它的兄弟也是空节点的)
对于序列的问题,
可以直接扫过去,修改某些位置的点。
或者线段树合并。
对于树上的问题,
线段树合并。
实现:
主要考虑什么时候线段树合并停止。以及pushdown的标记问题。
当x都没有儿子或者y都没有儿子时候,整个x的区间或整个y的区间都是同一个值,可以直接计算贡献转移过来(这个必须支持,否则不能整体DP)。
否则,pushdown,进行递归
pushdown时候建立新的儿子(如果之前没有)。
空间复杂度和时间复杂度基本一致。O(nlogn)
只要满足,在x都没有儿子或者y都没有儿子时候,可以快速合并然后return,那么就可以整体DP了。
例题2:
【PKUSC 2019】D2T1 树染色
$dp[x][c]=\Pi (sumy-dp[y][c])$sumy表示y的所有dp[y][*]的和
在x都没有儿子或者y都没有儿子时候,我们要么知道每个x的值,要么知道每个y的值。
在x都没有儿子时候,把y的节点内每个数乘-1再加sumy,再乘上x区间的值。
y都没有儿子时候,直接用(sumy-val)乘给x即可。
code:
#include<bits/stdc++.h> #define reg register int #define il inline #define fi first #define se second #define mk(a,b) make_pair(a,b) #define numb (ch^'0') #define pb push_back #define solid const auto & #define enter cout<<endl #define pii pair<int,int> using namespace std; typedef long long ll; template<class T>il void rd(T &x){ char ch;x=0;bool fl=false;while(!isdigit(ch=getchar()))(ch=='-')&&(fl=true); for(x=numb;isdigit(ch=getchar());x=x*10+numb);(fl==true)&&(x=-x);} template<class T>il void output(T x){if(x/10)output(x/10);putchar(x%10+'0');} template<class T>il void ot(T x){if(x<0) putchar('-'),x=-x;output(x);putchar(' ');} template<class T>il void prt(T a[],int st,int nd){for(reg i=st;i<=nd;++i) ot(a[i]);putchar('\n');} namespace Modulo{ const int mod=998244353; int ad(int x,int y){return (x+y)>=mod?x+y-mod:x+y;} void inc(int &x,int y){x=ad(x,y);} int mul(int x,int y){return (ll)x*y%mod;} void inc2(int &x,int y){x=mul(x,y);} int qm(int x,int y=mod-2){int ret=1;while(y){if(y&1) ret=mul(x,ret);x=mul(x,x);y>>=1;}return ret;} } using namespace Modulo; namespace Miracle{ const int N=2e5+5; int n,m,k; struct node{ int nxt,to; }e[2*N]; int hd[N],cnt; void add(int x,int y){ e[++cnt].nxt=hd[x]; e[cnt].to=y; hd[x]=cnt; } #define mid ((l+r)>>1) struct tr{ int sum,mul,ad; int ls,rs,val; void op(){ cout<<"SUM "<<sum<<" MUL "<<mul<<" AD "<<ad<<endl; } }t[20000000+3]; int tot,S; vector<int>no[N]; int rt[N]; int nc(){ ++tot; t[tot].sum=0;t[tot].mul=1;t[tot].ad=0; t[tot].ls=t[tot].rs=0;t[tot].val=0; return tot; } void tag(int x,int l,int r,int ml,int aa){ // cout<<" tag "<<x<<" l "<<l<<" r "<<r<<" ml "<<ml<<" ad "<<aa<<endl; // t[x].op(); t[x].sum=mul(t[x].sum,ml); t[x].sum=ad(t[x].sum,mul(r-l+1,aa)); t[x].val=ad(mul(t[x].val,ml),aa); t[x].mul=mul(t[x].mul,ml); t[x].ad=ad(mul(t[x].ad,ml),aa); } void pushup(int x){ t[x].sum=ad(t[t[x].ls].sum,t[t[x].rs].sum); } void pushdown(int x,int l,int r){ if(!t[x].ls) t[x].ls=nc(); if(!t[x].rs) t[x].rs=nc(); tag(t[x].ls,l,mid,t[x].mul,t[x].ad); tag(t[x].rs,mid+1,r,t[x].mul,t[x].ad); t[x].mul=1;t[x].ad=0; } void upda(int &x,int l,int r,int p){ // cout<<" pp "<<p<<" x "<<x<<" l "<<l<<" r "<<r<<" sm "<<t[x].sum<<" mul "<<t[x].mul<<" ad "<<t[x].ad<<endl; // cout<<" ls "<<t[x].ls<<" rs "<<t[x].rs<<endl; if(!x) x=nc(); if(l==r){ // cout<<" ss "<<t[x].sum<<endl; t[x].sum=0; t[x].val=0; return; } pushdown(x,l,r); if(p<=mid) upda(t[x].ls,l,mid,p); else upda(t[x].rs,mid+1,r,p); pushup(x); } int merge(int x,int y,int l,int r){ if(!t[x].ls&&!t[x].rs){ swap(x,y); int v=t[y].val; tag(x,l,r,mod-1,S); tag(x,l,r,v,0); }else if(!t[y].ls&&!t[y].rs){ int v=t[y].val; tag(x,l,r,ad(S,mod-v),0); }else{ pushdown(x,l,r);pushdown(y,l,r); t[x].ls=merge(t[x].ls,t[y].ls,l,mid); t[x].rs=merge(t[x].rs,t[y].rs,mid+1,r); pushup(x); } return x;//warining!! } void dfs(int x,int fa){ rt[x]=nc(); tag(rt[x],1,m,1,1); for(reg i=hd[x];i;i=e[i].nxt){ int y=e[i].to; if(y==fa) continue; dfs(y,x); S=t[rt[y]].sum; rt[x]=merge(rt[x],rt[y],1,m); // cout<<y<<" back "<<x<<" : "<<" sum "<<t[rt[x]].sum<<endl; } for(solid c:no[x]){ upda(rt[x],1,m,c); } // cout<<x<<" : "<<" sum "<<t[rt[x]].sum<<endl; } int main(){ rd(n);rd(m);rd(k); int x,y; for(reg i=1;i<n;++i){ rd(x);rd(y); add(x,y);add(y,x); } for(reg i=1;i<=k;++i){ rd(x);rd(y); no[x].push_back(y); } dfs(1,0); printf("%d",t[rt[1]].sum); return 0; } } signed main(){ Miracle::main(); return 0; } /* Author: *Miracle* */