题目

·对于一个排列,手上初始有一个编号为0的物品,从前往后或从后往前遍历这个排列,如果当期位置的编号大于手上的物品,就把当前位置的物品换到手上,那么对于从前往后和从后往前都有一个交换的次数。

·问n有多少排列前往后交换次数为a,后往前交换为b。

·n<=1e5,模998244353.

思路

考虑这个排列的数是怎么分布的
如果两个相邻的前缀最大值i,j,中间的数必然<i,并且中间的数不受限制,只有i不能动
所以一个前缀最大值构成了一个圆排列
由于最后手上的一定是n,那么n的左边就有a-1个前缀最大值,右边有b-1个后缀最大值。
·左边就分成了a-1块,右边b-1块,那么显然只需要把剩下n-1个数分到(a+b-2)个环中,再分到前后就好了。
所以可以用第一类斯特林数解决

代码

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const ll mod=998244353,g=3,_g=332748118,maxn=2e5+9;
ll Pow(ll base,ll b){
	ll ret(1);
	while(b){
		if(b&1) ret=1ll*ret*base%mod; base=1ll*base*base%mod; b>>=1;
	}return ret;
}
ll r[maxn],W[maxn];
ll Fir(ll n){
	ll limit(1),len(0),up(n<<1);
	while(limit<up){
		limit<<=1; ++len;
	}
	for(ll i=0;i<limit;++i) r[i]=(r[i>>1]>>1)|((i&1)<<len-1);
	return limit;
}
void NTT(ll *a,ll n,ll type){
	for(ll i=0;i<n;++i) if(i<r[i]) std::swap(a[i],a[r[i]]);
	for(ll mid=1;mid<n;mid<<=1){
		ll wn(Pow(type?g:_g,(mod-1)/(mid<<1)));
		W[0]=1; for(ll i=1;i<mid;++i) W[i]=1ll*W[i-1]*wn%mod;
		for(ll R=mid<<1,j=0;j<n;j+=R)
		    for(ll k=0;k<mid;++k){
		    	ll x(a[j+k]),y(1ll*W[k]*a[j+mid+k]%mod);
		    	a[j+k]=1ll*(x+y)%mod; a[j+mid+k]=1ll*(x-y+mod)%mod;
			}
	}
}
ll T[maxn],F[maxn],H[maxn],fac[maxn],fav[maxn],tmp[maxn],sum[maxn],B[maxn];
ll Mul(ll n,ll *a,ll *b,ll *ans){
	ll limit(Fir(n));
	NTT(a,limit,1); NTT(b,limit,1);
	for(ll i=0;i<limit;++i) ans[i]=1ll*a[i]*b[i]%mod;
	NTT(ans,limit,0);
	for(ll i=((n-1)<<1)+1;i<limit;++i) a[i]=b[i]=0;
	return Pow(limit,mod-2);
}
void Solve(ll n,ll *a){
	if(!n){ a[0]=1; return; }
	if(n==1){ a[1]=1; return; }
	ll len(n/2);
	Solve(len,a);
	for(ll i=0;i<=len;++i){
		F[i]=1ll*Pow(len,i)*fav[i]%mod;
		H[i]=1ll*fac[i]*a[i]%mod;
	}
	std::reverse(H,H+len+1);
	
	ll limit(Fir(len+1));
	NTT(F,limit,1); NTT(H,limit,1);
	for(ll i=0;i<limit;++i) F[i]=1ll*F[i]*H[i]%mod;
	NTT(F,limit,0);
	ll ty(Pow(limit,mod-2));
	for(ll i=0;i<=len;++i) tmp[i]=1ll*F[len-i]*ty%mod*Pow(fac[i],mod-2)%mod;
	for(ll i=(len<<1);i<=limit;++i) F[i]=H[i]=0;
	
	ll val(Mul(len+1,a,tmp,B));
	for(ll i=0;i<=(len<<1);++i) a[i]=1ll*B[i]*val%mod;
	
	if(n&1)
		for(ll i=n;i>=1;--i) a[i]=1ll*(a[i-1]+1ll*(n-1)*a[i]%mod)%mod;
}
ll n,a,b,m;
ll ans[maxn];
int main(){
	scanf("%d%d%d",&n,&a,&b);
	ll val;
	val=fac[0]=fac[1]=1;
	for(ll i=2;i<=n;++i) val=fac[i]=1ll*val*i%mod;
	val=fav[n]=Pow(fac[n],mod-2);
	for(ll i=n;i>=1;--i) val=fav[i-1]=1ll*val*i%mod;
	Solve(n-1,ans);
	
    n=a+b-2; m=a-1;
	printf("%d\n",1ll*ans[n]*fac[n]%mod*fav[m]%mod*fav[n-m]%mod%mod);
}