昊哥从牛客搬的,懒得找原题了
题意就是多组询问,每次询问一条树上路径,将这条路径上的点拿下来做\(0/1\)背包,求使得点权和为\(K\)的倍数的方案有几种
\(n<=200000,K<=50,Q<=500000\)
首先这确实是一个背包,我们可以直接用树剖和线段树来维护这些路径,线段树上每个节点存一个数组\(dp[i][j]\),表示\(i\)这个区间选择出的数\(mod\ K=j\)的方案数
之后发现我们每次合并都是一个卷积,于是复杂度\(O(Qk^2logn)\),可以用\(NTT\)优化到\(O(Qklognlogk)\),但是并没有什么用
正解点分治,我们把询问离线,处理好每一组询问在那一个分治中心被处理到
处理当前分治重心的时候,我们直接求出每一个点到分治重心的\(dp\)数组,之后合并答案,由于这个时候我们只需要求\(dp[0]\),所以合并答案\(O(k)\)时间内就能完成
代码
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
#define LL long long
#define re register
#define inf 999999999
#define maxn 500005
const LL mod=998244353;
inline int read() {
char c=getchar();int x=0;while(c<'0'||c>'9') c=getchar();
while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
struct E{int v,nxt;}e[maxn<<1];
struct Ask{int x,y,l,rk;}q[maxn];
std::vector<int> v[maxn],t[maxn];
int sum[maxn],mx[maxn],vis[maxn],col[maxn];
int head[maxn],dfn[maxn],st[maxn],Ans[maxn],a[maxn];
int n,m,num,S,now,rt,R,__,Top,K;
inline int cmp(Ask A,Ask B) {return dfn[A.l]<dfn[B.l];}
inline void add(int x,int y) {e[++num].v=y;e[num].nxt=head[x];head[x]=num;}
LL dp[2][maxn][50];
void getroot(int x,int fa) {
sum[x]=1,mx[x]=0;
for(re int i=head[x];i;i=e[i].nxt) {
if(vis[e[i].v]||e[i].v==fa) continue;
getroot(e[i].v,x);sum[x]+=sum[e[i].v];
if(sum[e[i].v]>mx[x]) mx[x]=sum[e[i].v];
}
mx[x]=max(mx[x],S-sum[x]);
if(mx[x]<now) now=mx[x],rt=x;
}
void paint(int x,int fa,int c,int now) {
col[x]=c;st[++Top]=x;
for(re int i=0;i<t[x].size();i++) {
if(col[q[t[x][i]].x]&&col[q[t[x][i]].x]!=c) q[t[x][i]].l=now;
if(col[q[t[x][i]].y]&&col[q[t[x][i]].y]!=c) q[t[x][i]].l=now;
}
for(re int i=head[x];i;i=e[i].nxt) {
if(vis[e[i].v]||fa==e[i].v) continue;
paint(e[i].v,x,c,now);
}
}
void rebuild(int x) {
vis[x]=1;dfn[x]=++__;
int cnt=1;Top=0;col[x]=1;
for(re int i=head[x];i;i=e[i].nxt) {
if(vis[e[i].v]) continue;
cnt++;paint(e[i].v,0,cnt,x);
}
while(Top) col[st[Top--]]=0;col[x]=0;
for(re int i=head[x];i;i=e[i].nxt) {
if(vis[e[i].v]) continue;
S=sum[e[i].v],now=inf,getroot(e[i].v,0);
v[x].push_back(rt),rebuild(rt);
}
}
void getdis(int x,int fa,int o) {
if(o) st[++Top]=x;
for(re int i=0;i<K;i++)
dp[o][x][i]=dp[o][fa][i];
for(re int i=0;i<K;i++)
dp[o][x][(i+a[x])%K]+=dp[o][fa][i],dp[o][x][(i+a[x])%K]%=mod;
for(re int i=head[x];i;i=e[i].nxt) {
if(vis[e[i].v]||fa==e[i].v) continue;
getdis(e[i].v,x,o);
}
}
inline void clear(int x) {
memset(dp[0][x],0,sizeof(dp[0][x]));
memset(dp[1][x],0,sizeof(dp[1][x]));
}
void dfs(int x) {
vis[x]=1;
Top=0;st[++Top]=x;dp[1][x][0]=1;dp[1][x][a[x]]++,dp[0][x][0]=1;
for(re int i=head[x];i;i=e[i].nxt) {
if(vis[e[i].v]) continue;
getdis(e[i].v,x,0);getdis(e[i].v,x,1);
}
while(q[now].l==x&&now<=m) {
LL ans=0;int ls=q[now].x,rs=q[now].y;
for(re int i=0;i<K;i++)
ans+=(dp[1][ls][i]*dp[0][rs][(K-i)%K]%mod),ans%=mod;
Ans[q[now].rk]=ans,now++;
}
while(Top) clear(st[Top--]);
for(re int i=0;i<v[x].size();i++) dfs(v[x][i]);
}
signed main() {
n=read(),K=read();
for(re int x,y,i=1;i<n;i++) x=read(),y=read(),add(x,y),add(y,x);
for(re int i=1;i<=n;i++) a[i]=read()%K;m=read();
for(re int i=1;i<=m;i++)
q[i].x=read(),q[i].y=read(),q[i].rk=i,t[q[i].x].push_back(i),t[q[i].y].push_back(i);
for(re int i=1;i<=m;i++) if(!q[i].l) q[i].l=q[i].x;
S=n,now=inf,getroot(1,0);R=rt;rebuild(rt);
std::sort(q+1,q+m+1,cmp);
memset(vis,0,sizeof(vis));now=1;dfs(R);
for(re int i=1;i<=m;i++) printf("%d\n",Ans[i]);
return 0;
}