根据LGV引理可以得到一个全是组合数的行列式,但n太大不能直接高斯消元。考场上尝试了好久直接把行列式化简,想要\(O(n)\)计算,最终无果。
赛后发现每一列提出一个阶乘的分母之后,每一行再提出一个\((a_i+1)\),就可以得到一个范德蒙德行列式。范德蒙德行列式的公式中的每一项\((a_i-a_j)\)可以整到多项式的次数上,化为多项式乘法的形式,用FFT快速计算。最终可以\(O(nlogn)\)求解。
属实是线代拉跨
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int maxn = 3e6 + 7, md = 998244353;
const double PI = acos(-1.0);
int Bit, Lim;
struct Complex {
double x, y;
Complex() {x = y = 0;}
Complex(double _x, double _y) {x = _x, y = _y;}
}A[maxn], B[maxn];
Complex operator + (Complex a, Complex b) {return Complex(a.x + b.x, a.y + b.y);}
Complex operator - (Complex a, Complex b) {return (Complex){a.x - b.x, a.y - b.y};}
Complex operator * (Complex a, Complex b) {return (Complex){a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x};}
Complex conj(Complex a) {return (Complex){a.x, -a.y};}
void fft(Complex *a, int n, int f) {
for (int i = 0, j = 0; i < n; i++) {
if (i > j) swap(a[i], a[j]);
for (int l = n >> 1; (j ^= l) < l; l >>= 1);
}
for (int i = 2; i <= n; i <<= 1) {
int m = i >> 1;
Complex wn = Complex(cos(2*PI/i), sin(2*PI*f/i)), t, w, u;
for (int k = 0; k < n; k += i) {
w = Complex(1, 0);
for (int j = 0; j < m; j++) {
t = w * a[j + k + m];
u = a[j + k];
a[j + k] = u + t;
a[j + k + m] = u - t;
w = w * wn;
}
}
}
if (f == -1) {
for (int i = 0; i < n; i++) a[i].x /= n;
}
}
int rd() {
int s = 0, f = 1; char c = getchar();
while (c < '0' || c > '9') {if (c == '-') f = -1; c = getchar();}
while (c >= '0' && c <= '9') {s = s * 10 + c - '0'; c = getchar();}
return s * f;
}
int n;
ll ksm(ll a, int b) {
//puts("(");
ll res = 1;
while (b) {
if (b & 1) res = res * a % md;
a = a * a % md;
b >>= 1;
}
//puts(")");
return res;
}
ll a[maxn], mul = 1;
ll fac[maxn], finv = 1;
int main() {
n = rd();
for (Lim = 1; Lim < (1 + 2000000); Lim <<= 1);
fac[0] = 1;
for (int i = 0; i < n; i++) a[i] = rd(), mul = mul * (a[i]+1) % md;
for (int i = 1; i <= n; i++) fac[i] = fac[i-1] * i % md, finv = finv * fac[i] % md;
mul = mul * ksm(finv, md - 2) % md;
for (int i = 0; i < n; i++) {
A[a[i]].x += 1;
B[1000000-a[i]].x += 1;
}
fft(A, Lim, 1);
fft(B, Lim, 1);
for (int i = 0; i < Lim; i++) A[i] = A[i] * B[i];
fft(A, Lim, -1);
for (int i = 1000001; i < Lim; i++) {
mul = mul * ksm(i-1000000, (int)floor(A[i].x + 0.5)) % md;
}
printf("%lld\n", mul);
return 0;
}