LOJ 6485. LJJ 学二项式定理

由于\(a\)的长度很短(只有 4),考虑枚举\(a_i\),然后计算它的贡献。

\(k=|a|=4\)

\[Answer=\sum_{i=0}^{k-1}a_i\sum_{j=0}^{n}[k|j-i]{n\choose j}s^{j} \]

很自然想到单位根反演:\([k\mid m]=\frac{1}{k}\sum_{z=0}^{k-1}w_k^{mz}\),其中 \(w_k\) 为 \(k\) 次本原单位根。代入得:

\[Answer=\frac{1}{k}\sum_{i=0}^{k-1}a_i\sum_{j=0}^{n}\left(\sum_{z=0}^{k-1} w_{k}^{(j-i)z}\right){n\choose j}s^{j} \]

交换求和符号:

\[Answer=\frac{1}{k}\sum_{i=0}^{k-1}a_i\sum_{z=0}^{k-1}w_k^{-zi}\sum_{j=0}^{n}(w_{k}^{z})^j{n\choose j}s^{j} \]

后面那个求和很像二项式定理:

利用 \(s^j=s^ns^{j-n}=s^n(\frac{1}{s})^{n-j}\)(这里需要 \(s\not\equiv 0\)),得到:

\[\begin{aligned}Answer&=s^n\frac{1}{k}\sum_{i=0}^{k-1}a_i\sum_{z=0}^{k-1}w_k^{-zi}\sum_{j=0}^{n}(w_{k}^{z})^j{n\choose j}\left(\frac{1}{s}\right)^{n-j}\\ &=s^n\frac{1}{k}\sum_{i=0}^{k-1}a_i\sum_{z=0}^{k-1}w_k^{-zi}\left(w_{k}^z+\frac{1}{s}\right)^n\end{aligned} \]

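然后就做完了。注意 \(s^n(w_k^z+\frac{1}{s})^n=(1+w_k^zs)^n\),直接计算右边既不用求逆元,也不要求 \(s\not\equiv 0\)。另外,若用费马小定理把指数 \(n\) 对 \(p-1\) 取模(模数 \(p=998244353\)),底数可能恰为 \(0\)(例如 \(s=1,z=2\) 时 \(1+w_4^2s=0\));此时应把指数取在 \([1,p-1]\) 内,否则当 \(n\equiv 0\pmod{p-1}\) 时会把 \(0^n=0\) 误算成 \(0^0=1\)。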

时间复杂度\(O(Tk^2\log n)\),不过\(\log n\)可能可以优化掉。注意 \((w_k^z+\frac{1}{s})^n\) 只与 \(z\) 有关,对每组数据先预处理这 \(k\) 个幂,即可降到 \(O(T(k\log n+k^2))\)。

#include<bits/stdc++.h>
// ---- Contest template shorthands ----
// rb(a,b,c): for a = b .. c, inclusive, ascending.
#define rb(a,b,c) for(int a=b;a<=c;++a)
// rl(a,b,c): for a = b down to c, inclusive, descending.
#define rl(a,b,c) for(int a=b;a>=c;--a)
#define LL long long
#define IT iterator
#define PB push_back
#define II(a,b) make_pair(a,b)
#define FIR first
#define SEC second
// FREO: redirect stdout to the file "check.out".
#define FREO freopen("check.out","w",stdout)
// rep(a,b): for a = 0 .. b-1.
#define rep(a,b) for(int a=0;a<b;++a)
// SRAND declares an mt19937 named `rng` seeded from steady_clock; random(a) needs it in scope.
#define SRAND mt19937 rng(chrono::steady_clock::now().time_since_epoch().count())
// NOTE(review): `a` is not parenthesized, so random(x+1) expands to rng()%x+1.
// This macro also hijacks any later call spelled random(...).
#define random(a) rng()%a
#define ALL(a) a.begin(),a.end()
#define POB pop_back
// ff: flush stdout (useful for interactive problems).
#define ff fflush(stdout)
#define fastio ios::sync_with_stdio(false)
// check_min / check_max: assign the smaller / larger of (a, b) back into a.
#define check_min(a,b) a=min(a,b)
#define check_max(a,b) a=max(a,b)
using namespace std;
//inline int read(){
//    int x=0;
//    char ch=getchar();
//    while(ch<'0'||ch>'9'){
//        ch=getchar();
//    }
//    while(ch>='0'&&ch<='9'){
//        x=(x<<1)+(x<<3)+(ch^48);
//        ch=getchar();
//    }
//    return x;
//}
// "Infinity" sentinel and int-pair alias from the contest template;
// neither is referenced in the visible code.
const int INF=0x3f3f3f3f;
typedef pair<int,int> mp;
/*}
*/
// Arithmetic is done modulo the prime 998244353 = 119 * 2^23 + 1.
// 3 is a primitive root of it, so 3^((MOD-1)/4) is a primitive 4th root of unity.
const int MOD=998244353;
const int G=3;

// Binary exponentiation: returns A^B mod MOD for B >= 0 (B == 0 yields 1).
int quick(int A,int B){
	long long acc=1,base=A;
	for(;B!=0;B>>=1){
		if(B&1) acc=acc*base%MOD;
		base=base*base%MOD;
	}
	return (int)acc;
}

// Modular inverse by Fermat's little theorem (MOD is prime): A^(MOD-2).
// Only meaningful when A is not a multiple of MOD; for A == 0 this returns 0.
int inv(int A){
	const int fermat_exponent=MOD-2;
	return quick(A,fermat_exponent);
}

// w[t] holds the t-th power of a primitive 4th root of unity; it is filled in main().
int w[4];

// In-place modular addition: A = (A + B) mod MOD, assuming 0 <= A, B < MOD.
void add(int & A,int B){
	const int sum=A+B;
	A=(sum>=MOD)?sum-MOD:sum;
}
void solve(){
	LL n;
	int s,a[4];
	scanf("%lld%d",&n,&s);
	n%=MOD-1;
	rep(i,4) scanf("%d",&a[i]);
	int ans=0;
	rep(j,4){
		int tmp=0;
		rep(k,4){
			add(tmp,1ll*quick((w[k]+inv(s))%MOD,n)*inv(w[j*k%4])%MOD);
		}
		tmp=1ll*tmp*quick(s,n)%MOD;
		add(ans,1ll*tmp*a[j]%MOD);
	}
	ans=1ll*ans*inv(4)%MOD;
	printf("%d\n",ans);
}
int main(){
	// Build w[t] = root^t for t = 0..3, where root = G^((MOD-1)/4)
	// is a primitive 4th root of unity modulo MOD.
	const int root=quick(G,(MOD-1)/4);
	w[0]=1;
	for(int t=1;t<4;++t) w[t]=1ll*w[t-1]*root%MOD;
	int T;
	scanf("%d",&T);
	// Same iteration count as `while(T--)`: runs solve() exactly T times.
	for(;T!=0;--T) solve();
	return 0;
}