题目
·对于一个排列,手上初始有一个编号为0的物品,从前往后或从后往前遍历这个排列,如果当期位置的编号大于手上的物品,就把当前位置的物品换到手上,那么对于从前往后和从后往前都有一个交换的次数。
·问n有多少排列前往后交换次数为a,后往前交换为b。
·n<=1e5,模998244353.
思路
考虑这个排列的数是怎么分布的
如果两个相邻的前缀最大值i,j,中间的数必然<i,并且中间的数不受限制,只有i不能动
所以一个前缀最大值构成了一个圆排列
由于最后手上的一定是n,那么n的左边就有a-1个前缀最大值,右边有b-1个后缀最大值。
·左边就分成了a-1块,右边b-1块,那么显然只需要把剩下n-1个数分到(a+b-2)个环中,再分到前后就好了。
所以可以用第一类斯特林数解决
代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const ll mod=998244353,g=3,_g=332748118,maxn=2e5+9;
ll Pow(ll base,ll b){
ll ret(1);
while(b){
if(b&1) ret=1ll*ret*base%mod; base=1ll*base*base%mod; b>>=1;
}return ret;
}
ll r[maxn],W[maxn];
ll Fir(ll n){
ll limit(1),len(0),up(n<<1);
while(limit<up){
limit<<=1; ++len;
}
for(ll i=0;i<limit;++i) r[i]=(r[i>>1]>>1)|((i&1)<<len-1);
return limit;
}
void NTT(ll *a,ll n,ll type){
for(ll i=0;i<n;++i) if(i<r[i]) std::swap(a[i],a[r[i]]);
for(ll mid=1;mid<n;mid<<=1){
ll wn(Pow(type?g:_g,(mod-1)/(mid<<1)));
W[0]=1; for(ll i=1;i<mid;++i) W[i]=1ll*W[i-1]*wn%mod;
for(ll R=mid<<1,j=0;j<n;j+=R)
for(ll k=0;k<mid;++k){
ll x(a[j+k]),y(1ll*W[k]*a[j+mid+k]%mod);
a[j+k]=1ll*(x+y)%mod; a[j+mid+k]=1ll*(x-y+mod)%mod;
}
}
}
ll T[maxn],F[maxn],H[maxn],fac[maxn],fav[maxn],tmp[maxn],sum[maxn],B[maxn];
ll Mul(ll n,ll *a,ll *b,ll *ans){
ll limit(Fir(n));
NTT(a,limit,1); NTT(b,limit,1);
for(ll i=0;i<limit;++i) ans[i]=1ll*a[i]*b[i]%mod;
NTT(ans,limit,0);
for(ll i=((n-1)<<1)+1;i<limit;++i) a[i]=b[i]=0;
return Pow(limit,mod-2);
}
void Solve(ll n,ll *a){
if(!n){ a[0]=1; return; }
if(n==1){ a[1]=1; return; }
ll len(n/2);
Solve(len,a);
for(ll i=0;i<=len;++i){
F[i]=1ll*Pow(len,i)*fav[i]%mod;
H[i]=1ll*fac[i]*a[i]%mod;
}
std::reverse(H,H+len+1);
ll limit(Fir(len+1));
NTT(F,limit,1); NTT(H,limit,1);
for(ll i=0;i<limit;++i) F[i]=1ll*F[i]*H[i]%mod;
NTT(F,limit,0);
ll ty(Pow(limit,mod-2));
for(ll i=0;i<=len;++i) tmp[i]=1ll*F[len-i]*ty%mod*Pow(fac[i],mod-2)%mod;
for(ll i=(len<<1);i<=limit;++i) F[i]=H[i]=0;
ll val(Mul(len+1,a,tmp,B));
for(ll i=0;i<=(len<<1);++i) a[i]=1ll*B[i]*val%mod;
if(n&1)
for(ll i=n;i>=1;--i) a[i]=1ll*(a[i-1]+1ll*(n-1)*a[i]%mod)%mod;
}
ll n,a,b,m;
ll ans[maxn];
int main(){
scanf("%d%d%d",&n,&a,&b);
ll val;
val=fac[0]=fac[1]=1;
for(ll i=2;i<=n;++i) val=fac[i]=1ll*val*i%mod;
val=fav[n]=Pow(fac[n],mod-2);
for(ll i=n;i>=1;--i) val=fav[i-1]=1ll*val*i%mod;
Solve(n-1,ans);
n=a+b-2; m=a-1;
printf("%d\n",1ll*ans[n]*fac[n]%mod*fav[m]%mod*fav[n-m]%mod%mod);
}