题目链接

ABC214 G - Three Permutations

题目大意

给定两个长 \(n\) 的排列 \(\{p_i\}\)\(\{q_i\}\),求有多少个长 \(n\) 的排列 \(\{r_i\}\),满足 \(\forall 1\leq i\leq n,\;r_i\neq p_i,r_i\neq q_i\)

\(1\leq n\leq 3000\)

思路

此题是错排问题的升级版,错排 \(f(n)=n!\sum_{i=0}^n (-1)^i\frac{1}{i!}\)​,是通过分别钦定有 \(i\)​ 位是 \(p_i=i\)​​​ 然后容斥出来的​,这里也采用类似的思路,不过此时对于 \(p_i=q_j\),我们不能钦定 \(r_i=p_i,r_j=q_j\),注意到很多对互斥关系,于是考虑建图。

对于两排列中 \(p_i=q_j\) 的地方,我们在 \(i\)\(j\) 之间连一条边,注意到所有点的度数都为 \(2\),所以这些边够成了若干个环和孤立点,而对于要钦定相等的地方,相当于 \(i\) 点要选一条相连的边与其对应,且任意两点不能对应相同边。设 \(f(i,j)\) 表示在长度为 \(i\) 的环上选 \(j\) 个点来钦定相等的方案数,那么钦定 \(k\) 个位置不错排的方案数可以利用 \(f\) 数组做一遍背包得出来。

考虑计算 \(f(i,j)\),这个东西可以用 \(dp\) 相对容易地计算出来,但是在错排两个序列(每个点度数为 \(2\))的时候,\(f\) 是有简洁的组合数表达式的。我们先把这个环定个方向,对于每个被钦定的点,若选择的是逆时针方向的边,则染为蓝色,若选择的是顺时针方向的,则染为红色,另外把未被钦定的点染成黑色。

首先显然 \(f_{i,i}=2\)​​,接下来考虑未把整个环全部钦定的情况。我们先把环黑白染色,白点以后可以是蓝或红点,设目前有 \(x\) 个黑点,\(y\)​ 个白点,注意到若一个点被涂成蓝色,那么它逆时针方向的相邻点就不能被涂成红色,也就是说,每个白点组成的连续段必然是前面一段蓝色,后面一段红色的形式,染色的方案数是(段的长度 \(+1\)​),总方案数即为所有段长\(+1\) 的乘积。

在组合意义下,这个 \(+1\)​​​​ 很麻烦,于是我们引入绿点来表示红蓝点的交界处(而不是隔板),于是现在的环变成了 黑点-绿点-黑点 交替的形式,为了方便计数,找到环上序号最小的点,然后从它开始断环为链,那么黑绿交替就出现了两种情况:

  1. \(B-G-B-...-G\):这种没什么忌讳,直接染色即可,方案数为 \(\binom{2x+y}{y}\)
  2. \(G-B-G-...-B\):在断环为链的时候,选的是序号最小的点,绿点不予考虑,那么这里第一个点就不能是绿点,于是必须是白点,方案数为 \(\binom{2x+y-1}{y-1}\)

\(x+y=i,y=j\),从而 \(f(i,j)=\binom{2i-j}{j}+\binom{2i-j-1}{j-1}\)

于是就做完了,时间复杂度 \(O(n^2)\)​​ 。

对于多个排列限制的问题也可以采用这样的建图转化,但想了想好像不大会做,会不会是 \(NPC\) 啊。

Code

#include<iostream>
#define rep(i,a,b) for(int i = (a); i <= (b); i++)
#define per(i,b,a) for(int i = (b); i >= (a); i--)
#define N 3030
#define ll long long
#define mod 1000000007
using namespace std;

int p[N], q[N], loc[N][2];
int fa[N], siz[N];
int n;

ll C[2*N][2*N], fact[N], dp[2][N];

int find(int x){ return fa[x] == x ? x : (fa[x] = find(fa[x])); }

void init(){
    C[0][0] = 1;
    rep(i,1,2*n){
        C[i][0] = 1;
        rep(j,1,i) C[i][j] = (C[i-1][j] + C[i-1][j-1]) % mod;
    }
    fact[0] = 1;
    rep(i,1,n) fact[i] = (fact[i-1]*i) % mod;
}

int main(){
    cin>>n;
    rep(i,1,n) cin>>p[i], loc[p[i]][0] = i;
    rep(i,1,n) cin>>q[i], loc[q[i]][1] = i;

    rep(i,1,n) fa[i] = i, siz[i] = 1;
    rep(i,1,n){
        int x = find(loc[i][0]), y = find(loc[i][1]);
        if(x == y) continue;
        fa[x] = y, siz[y] += siz[x];
    }

    init();
    int par = 0;
    dp[0][0] = 1;
    rep(x,1,n) if(fa[x] == x){
        if(siz[x] == 1){
            rep(i,0,n) if(dp[par][i]) rep(j,0,1)
                (dp[par^1][i+j] += dp[par][i]) %= mod;
        } else{
            int sz = siz[x];
            rep(i,0,n) if(dp[par][i]){
                (dp[par^1][i+sz] += 2*dp[par][i]) %= mod;
                (dp[par^1][i] += dp[par][i]) %= mod;
                rep(j,1,sz-1) (dp[par^1][i+j] += (C[2*sz-j][j]+C[2*sz-j-1][j-1]) * dp[par][i]) %= mod;
            }
        }
        rep(i,0,n) dp[par][i] = 0;
        par ^= 1;
    }

    ll ans = 0;
    rep(i,0,n) (ans += (i&1 ? -1 : 1) * fact[n-i] * dp[par][i]) %= mod;
    if(ans < 0) ans += mod;
    cout<< ans <<endl;
    return 0;
}