题目链接
题目大意
给定两个长 \(n\) 的排列 \(\{p_i\}\) 和 \(\{q_i\}\),求有多少个长 \(n\) 的排列 \(\{r_i\}\),满足 \(\forall 1\leq i\leq n,\;r_i\neq p_i,r_i\neq q_i\) 。
\(1\leq n\leq 3000\)
思路
此题是错排问题的升级版,错排 \(f(n)=n!\sum_{i=0}^n (-1)^i\frac{1}{i!}\),是通过分别钦定有 \(i\) 位是 \(p_i=i\) 然后容斥出来的,这里也采用类似的思路,不过此时对于 \(p_i=q_j\),我们不能钦定 \(r_i=p_i,r_j=q_j\),注意到很多对互斥关系,于是考虑建图。
对于两排列中 \(p_i=q_j\) 的地方,我们在 \(i\) 和 \(j\) 之间连一条边,注意到所有点的度数都为 \(2\),所以这些边够成了若干个环和孤立点,而对于要钦定相等的地方,相当于 \(i\) 点要选一条相连的边与其对应,且任意两点不能对应相同边。设 \(f(i,j)\) 表示在长度为 \(i\) 的环上选 \(j\) 个点来钦定相等的方案数,那么钦定 \(k\) 个位置不错排的方案数可以利用 \(f\) 数组做一遍背包得出来。
考虑计算 \(f(i,j)\),这个东西可以用 \(dp\) 相对容易地计算出来,但是在错排两个序列(每个点度数为 \(2\))的时候,\(f\) 是有简洁的组合数表达式的。我们先把这个环定个方向,对于每个被钦定的点,若选择的是逆时针方向的边,则染为蓝色,若选择的是顺时针方向的,则染为红色,另外把未被钦定的点染成黑色。
首先显然 \(f_{i,i}=2\),接下来考虑未把整个环全部钦定的情况。我们先把环黑白染色,白点以后可以是蓝或红点,设目前有 \(x\) 个黑点,\(y\) 个白点,注意到若一个点被涂成蓝色,那么它逆时针方向的相邻点就不能被涂成红色,也就是说,每个白点组成的连续段必然是前面一段蓝色,后面一段红色的形式,染色的方案数是(段的长度 \(+1\)),总方案数即为所有段长\(+1\) 的乘积。
在组合意义下,这个 \(+1\) 很麻烦,于是我们引入绿点来表示红蓝点的交界处(而不是隔板),于是现在的环变成了 黑点-绿点-黑点 交替的形式,为了方便计数,找到环上序号最小的点,然后从它开始断环为链,那么黑绿交替就出现了两种情况:
- \(B-G-B-...-G\):这种没什么忌讳,直接染色即可,方案数为 \(\binom{2x+y}{y}\)
- \(G-B-G-...-B\):在断环为链的时候,选的是序号最小的点,绿点不予考虑,那么这里第一个点就不能是绿点,于是必须是白点,方案数为 \(\binom{2x+y-1}{y-1}\)
\(x+y=i,y=j\),从而 \(f(i,j)=\binom{2i-j}{j}+\binom{2i-j-1}{j-1}\) 。
于是就做完了,时间复杂度 \(O(n^2)\) 。
对于多个排列限制的问题也可以采用这样的建图转化,但想了想好像不大会做,会不会是 \(NPC\) 啊。
Code
#include<iostream>
#define rep(i,a,b) for(int i = (a); i <= (b); i++)
#define per(i,b,a) for(int i = (b); i >= (a); i--)
#define N 3030
#define ll long long
#define mod 1000000007
using namespace std;
int p[N], q[N], loc[N][2];
int fa[N], siz[N];
int n;
ll C[2*N][2*N], fact[N], dp[2][N];
int find(int x){ return fa[x] == x ? x : (fa[x] = find(fa[x])); }
void init(){
C[0][0] = 1;
rep(i,1,2*n){
C[i][0] = 1;
rep(j,1,i) C[i][j] = (C[i-1][j] + C[i-1][j-1]) % mod;
}
fact[0] = 1;
rep(i,1,n) fact[i] = (fact[i-1]*i) % mod;
}
int main(){
cin>>n;
rep(i,1,n) cin>>p[i], loc[p[i]][0] = i;
rep(i,1,n) cin>>q[i], loc[q[i]][1] = i;
rep(i,1,n) fa[i] = i, siz[i] = 1;
rep(i,1,n){
int x = find(loc[i][0]), y = find(loc[i][1]);
if(x == y) continue;
fa[x] = y, siz[y] += siz[x];
}
init();
int par = 0;
dp[0][0] = 1;
rep(x,1,n) if(fa[x] == x){
if(siz[x] == 1){
rep(i,0,n) if(dp[par][i]) rep(j,0,1)
(dp[par^1][i+j] += dp[par][i]) %= mod;
} else{
int sz = siz[x];
rep(i,0,n) if(dp[par][i]){
(dp[par^1][i+sz] += 2*dp[par][i]) %= mod;
(dp[par^1][i] += dp[par][i]) %= mod;
rep(j,1,sz-1) (dp[par^1][i+j] += (C[2*sz-j][j]+C[2*sz-j-1][j-1]) * dp[par][i]) %= mod;
}
}
rep(i,0,n) dp[par][i] = 0;
par ^= 1;
}
ll ans = 0;
rep(i,0,n) (ans += (i&1 ? -1 : 1) * fact[n-i] * dp[par][i]) %= mod;
if(ans < 0) ans += mod;
cout<< ans <<endl;
return 0;
}