tag:树形dp,组合计数


首先根据递归关系建出一个树,然后就变为了树上问题:对树染色,满足任意一个点到根的 \(num_r\le c_r,num_b\le c_b\),求所有染色方案的 \(num_rnum_b^2\)

于是想到一个 dp,设 \(f[i][j][k]\),表示点 \(i\) 的祖先选了 \(j\)\(b\),点 \(i\) 子树选了 \(k\)\(b\)。这样的好处在于对于点 \(i\) 的祖先 \(num_r+num_b=dep_i\),就省掉了一维状态。

然而这个是 \(O(n^3)\),考虑优化到 \(O(n^2)\)

这里就要看答案柿子的组合意义 \(num_rnum_b^2\) 相当于可重复按顺序选出\(1\)\(r\)\(2\)\(b\)

所以实际上 \(k\) 那一维可以改为 \(f[i][j][0/1][0/1][0/1]\),然后就可以 \(O(n^2)\) 了。

#include<bits/stdc++.h>
using namespace std;

template<typename T>
inline void Read(T &n){
	char ch; bool flag=false;
	while(!isdigit(ch=getchar()))if(ch=='-')flag=true;
	for(n=ch^48;isdigit(ch=getchar());n=(n<<1)+(n<<3)+(ch^48));
	if(flag)n=-n;
}

enum{
    MAXN = 2505,
    MOD = 998244353
};

inline int ksm(int base, int k=MOD-2){
    int res=1;
    while(k){
        if(k&1)
            res = 1ll*res*base%MOD;
        base = 1ll*base*base%MOD;
        k >>= 1;
    }
    return res;
}

inline int inc(int a, int b){
    a += b;
    if(a>=MOD) a -= MOD;
    return a;
}

inline int dec(int a, int b){
    a -= b;
    if(a<0) a += MOD;
    return a;
}

inline void iinc(int &a, int b){a = inc(a,b);}
inline void ddec(int &a, int b){a = dec(a,b);}
inline void upd(int &a, long long b){a = (a+b)%MOD;}

struct _{
    int nxt, to;
    _(int nxt=0, int to=0):nxt(nxt),to(to){}
}edge[MAXN];
int fst[MAXN], tot;

inline void Add_Edge(int f, int t){edge[++tot] = _(fst[f], t); fst[f] = tot;}

char opt[MAXN<<1][5];
int q[MAXN], top;

int n, cr, cb;
int ans[MAXN], f[2][2][2][MAXN][MAXN], dep[MAXN], g[MAXN][2][2][2][2], tmp[2][2][2][2];
void dp(int x){
    if(x==1){
        f[0][0][0][1][0] = 1;
        for(int u=fst[x]; u; u=edge[u].nxt){
            int v=edge[u].to;
            dep[v] = 1;
            dp(v);
            for(int a=1; ~a; a--) for(int b=1; ~b; b--) for(int c=1; ~c; c--) if(f[a][b][c][x][0])
                for(int aa=0; aa<2-a; aa++) for(int bb=0; bb<2-b; bb++) for(int cc=0; cc<2-c; cc++) if(f[aa][bb][cc][v][0])
                    upd(tmp[a+aa][b+bb][c+cc][0],1ll*f[a][b][c][x][0]*f[aa][bb][cc][v][0]);
            for(int a=0; a<2; a++) for(int b=0; b<2; b++) for(int c=0; c<2; c++) f[a][b][c][x][0] = tmp[a][b][c][0], tmp[a][b][c][0] = 0;
        }
        return;
    }
    for(int i=0; i<dep[x]; i++) if(i<=cb and dep[x]-i-1<=cr){
        memset(g[x],0,sizeof g[x]);
        if(i<cb) g[x][0][0][0][1] = g[x][0][1][0][1] = g[x][0][0][1][1] = g[x][0][1][1][1] = 1;
        if(dep[x]-i<=cr) g[x][0][0][0][0] = g[x][1][0][0][0] = 1;
        for(int u=fst[x]; u; u=edge[u].nxt){
            int v=edge[u].to;
            if(!dep[v]){
                dep[v] = dep[x]+1;
                dp(v);
            }
            for(int a=1; ~a; a--) for(int b=1; ~b; b--) for(int c=1; ~c; c--) for(int d=0; d<2; d++) if(g[x][a][b][c][d])
                for(int aa=1-a; ~aa; aa--) for(int bb=1-b; ~bb; bb--) for(int cc=1-c; ~cc; cc--) if(f[aa][bb][cc][v][i+d])
                    upd(tmp[a+aa][b+bb][c+cc][d],1ll*g[x][a][b][c][d]*f[aa][bb][cc][v][i+d]);
            for(int a=0; a<2; a++) for(int b=0; b<2; b++) for(int c=0; c<2; c++) for(int d=0; d<2; d++) g[x][a][b][c][d] = tmp[a][b][c][d], tmp[a][b][c][d] = 0;
        }
        for(int a=0; a<2; a++) for(int b=0; b<2; b++) for(int c=0; c<2; c++){
            // if(i<cb) iinc(f[a][b][c][x][i],g[x][a][b][c][1]);
            // if(dep[x]-i<cr) iinc(f[a][b][c][x][i],g[x][a][b][c][0]);
            iinc(f[a][b][c][x][i],g[x][a][b][c][1]);
            iinc(f[a][b][c][x][i],g[x][a][b][c][0]);
        }
    }
    // for(int a=0; a<2; a++) for(int b=0; b<2; b++) for(int c=0; c<2; c++) for(int i=0; i<dep[x]; i++) if(f[a][b][c][x][i])
    //     printf("f[%d][%d][%d][%d][%d] = %d\n",a,b,c,x,i,f[a][b][c][x][i]);
}
~~~~
int main(){
    // freopen("ex_stack5.in","r",stdin);
    // tt = clock();
    Read(n); Read(cr); Read(cb);
    for(int i=1; i<=2*n; i++) scanf("%s",opt[i]);
    int node_cnt=1; q[top=1] = 1;
    for(int i=1; i<=2*n; i++)
        if(opt[i][1]=='u'){
            node_cnt++;
            Add_Edge(q[top],node_cnt);
            q[++top] = node_cnt;
        }
        else top--, assert(top>=0);
    dp(1);
    cout<<f[1][1][1][1][0]<<'\n';
    return 0;
}