根据LGV引理可以得到一个全是组合数的行列式,但n太大不能直接高斯消元。考场上尝试了好久直接把行列式化简,想要\(O(n)\)计算,最终无果。

赛后发现每一列提出一个阶乘的分母之后,每一行再提出一个\((a_i+1)\),就可以得到一个范德蒙德行列式。范德蒙德行列式的公式中的每一项\((a_i-a_j)\)可以整到多项式的次数上,化为多项式乘法的形式,用FFT快速计算。最终可以\(O(nlogn)\)求解。

属实是线代拉跨

#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int maxn = 3e6 + 7, md = 998244353;
const double PI = acos(-1.0);
int Bit, Lim;
struct Complex {
    double x, y;
    Complex() {x = y = 0;}
    Complex(double _x, double _y) {x = _x, y = _y;}
}A[maxn], B[maxn];
Complex operator + (Complex a, Complex b) {return Complex(a.x + b.x, a.y + b.y);}
Complex operator - (Complex a, Complex b) {return (Complex){a.x - b.x, a.y - b.y};}
Complex operator * (Complex a, Complex b) {return (Complex){a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x};}
Complex conj(Complex a) {return (Complex){a.x, -a.y};}
void fft(Complex *a, int n, int f) {
    for (int i = 0, j = 0; i < n; i++) {
        if (i > j) swap(a[i], a[j]);
        for (int l = n >> 1; (j ^= l) < l; l >>= 1);
    }
    for (int i = 2; i <= n; i <<= 1) {
        int m = i >> 1;
        Complex wn = Complex(cos(2*PI/i), sin(2*PI*f/i)), t, w, u;
        for (int k = 0; k < n; k += i) {
            w = Complex(1, 0);
            for (int j = 0; j < m; j++) {
                t = w * a[j + k + m];
                u = a[j + k];
                a[j + k] = u + t;
                a[j + k + m] = u - t;
                w = w * wn;
            }
        }
    }
    if (f == -1) {
        for (int i = 0; i < n; i++) a[i].x /= n;
    }
}
int rd() {
    int s = 0, f = 1; char c = getchar();
    while (c < '0' || c > '9') {if (c == '-') f = -1; c = getchar();}
    while (c >= '0' && c <= '9') {s = s * 10 + c - '0'; c = getchar();}
    return s * f;
}
int n;
ll ksm(ll a, int b) {
    //puts("(");
    ll res = 1;
    while (b) {
        if (b & 1) res = res * a % md;
        a = a * a % md;
        b >>= 1;
    }
    //puts(")");
    return res;
}
ll a[maxn], mul = 1;
ll fac[maxn], finv = 1;

int main() {
    n = rd();
    for (Lim = 1; Lim < (1 + 2000000); Lim <<= 1);
    fac[0] = 1;
    for (int i = 0; i < n; i++) a[i] = rd(), mul = mul * (a[i]+1) % md;
    for (int i = 1; i <= n; i++) fac[i] = fac[i-1] * i % md, finv = finv * fac[i] % md;
    mul = mul * ksm(finv, md - 2) % md;
    for (int i = 0; i < n; i++) {
        A[a[i]].x += 1;
        B[1000000-a[i]].x += 1;
    }
    fft(A, Lim, 1);
    fft(B, Lim, 1);
    for (int i = 0; i < Lim; i++) A[i] = A[i] * B[i]; 
    fft(A, Lim, -1);
    for (int i = 1000001; i < Lim; i++) {
        mul = mul * ksm(i-1000000, (int)floor(A[i].x + 0.5)) % md;
    }
    printf("%lld\n", mul);
    return 0;
}