2021牛客暑期多校训练营2 J Product of GCDs

思路就是每个数拆成质因数,再从每个质数以及它的次幂去扫一遍所有的数,每次取质数的组合数次幂作为贡献,累乘起来即为最后的结果

记录一下这次调试的过程

开始时考虑用vector去存每一个质数所对应的信息,每次用vector的大小与k进行比较得出结果,得到了最开始的程序

#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
int C[40001][31];
int a[40010];
int prime[1000010],tot;
bool bj[10000010];
long long mul(long long a,long long b,long long p)
{
	long long res=0;
	while(b)
	{
		if(b&1)res=(res+a)%p;
		a=(a+a)%p;
		b>>=1;
	}
	return res;
}
long long ksm(long long a,long long b,long long p)
{
	long long res=1;
	while(b)
	{
		if(b&1)res=mul(res,a,p);
		a=mul(a,a,p);
		b>>=1;
	}
	return res;
}
int zhan[1000010],cnt;
void getprime(int x)
{
	for(int i=2;i<=x;i++)
	{
		if(!bj[i])prime[++tot]=i;
		for(int j=1;i*prime[j]<=x&&j<=tot;j++)
		{
			bj[i*prime[j]]=1;
			if(i%prime[j]==0)break;
		}
	}
}
long long phi(long long x)
{
	long long ans=x;
	for(int i=1;i<=tot;i++)
	{
		if(x%prime[i]==0)
		{
			ans=ans/prime[i]*(prime[i]-1);
			while(x%prime[i]==0)x/=prime[i];
		}
		if(x<prime[i])break;
	}
	if(x!=1)
	ans=ans/x*(x-1);
	return ans;
}
void init(int n,int k,int p)
{
	C[0][0]=1;
	for(int i=1;i<=n;i++)
	C[i][0]=1;
	for(int i=1;i<=n;i++)
	for(int j=1;j<=k;j++)
	C[i][j]=(C[i-1][j]+C[i-1][j-1])%p;
}
vector<int>q[80000];
int main()
{
	getprime(1e7);
	int t;
	scanf("%d",&t);
	while(t--)
	{
		int n,k,p;
		long long ans=1;
		scanf("%d%d%d",&n,&k,&p);
// 		cout<<n<<" "<<k<<" "<<p<<"\n";
		init(n,k,phi(p));
		for(int i=1;i<=n;i++)
		{
			scanf("%d",&a[i]);
			int x=a[i];
			for(int j=1;prime[j]*prime[j]<=x;j++)
			{
				if(x%prime[j]==0)
				{
					q[prime[j]].push_back(i);
					while(x%prime[j]==0)x/=prime[j];
			 	}
			}
			if(x!=1)q[x].push_back(i);
		}
		for(int i=1;prime[i]<=n;i++)
		{
			while(q[prime[i]].size()>=k)
			{
				int m=q[prime[i]].size();
//				cout<<m<<" "<<k<<"\n";
//				for(int j=0;j<m;j++)
//				cout<<q[prime[i]][j]<<" ";
				ans=mul(ans,ksm(prime[i],C[m][k],p),p);
				cnt=0;
				for(int j=q[prime[i]][0];j<m;j++)
				zhan[++cnt]=q[prime[i]][j];
				q[prime[i]].clear();
				for(int j=1;j<=cnt;j++)
				{
					a[zhan[j]]/=prime[i];
					if(a[zhan[j]]%prime[i]==0)q[prime[i]].push_back(zhan[j]);
				}
//				cout<<ans<<"\n";
//				cout<<q[prime[i]].size()<<"\n";
			}
			q[prime[i]].clear();
		}
		cout<<ans<<"\n";
	}
	return 0;
}

过了样例,交上去直接WA

经过排查发现

for(int i=1;prime[i]<=n;i++)
		{
			while(q[prime[i]].size()>=k)
			{
				int m=q[prime[i]].size();

这一块,n代表的是有几个数而不是当前数的最大值,开了一个新变量maxn用来存储a[i]的最大值,又得到一份代码

#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
int C[40001][31];
int a[40010];
int prime[1000010],tot;
bool bj[10000010];
long long mul(long long a,long long b,long long p)
{
	long long res=0;
	while(b)
	{
		if(b&1)res=(res+a)%p;
		a=(a+a)%p;
		b>>=1;
	}
	return res;
}
long long ksm(long long a,long long b,long long p)
{
	long long res=1;
	while(b)
	{
		if(b&1)res=mul(res,a,p);
		a=mul(a,a,p);
		b>>=1;
	}
	return res;
}
int zhan[1000010],cnt;
void getprime(int x)
{
	for(int i=2;i<=x;i++)
	{
		if(!bj[i])prime[++tot]=i;
		for(int j=1;i*prime[j]<=x&&j<=tot;j++)
		{
			bj[i*prime[j]]=1;
			if(i%prime[j]==0)break;
		}
	}
}
long long phi(long long x)
{
	long long ans=x;
	for(int i=1;i<=tot;i++)
	{
		if(x%prime[i]==0)
		{
			ans=ans/prime[i]*(prime[i]-1);
			while(x%prime[i]==0)x/=prime[i];
		}
		if(x<prime[i])break;
	}
	if(x!=1)
	ans=ans/x*(x-1);
	return ans;
}
void init(int n,int k,int p)
{
	C[0][0]=1;
	for(int i=1;i<=n;i++)
	C[i][0]=1;
	for(int i=1;i<=n;i++)
	for(int j=1;j<=k;j++)
	C[i][j]=(C[i-1][j]+C[i-1][j-1])%p;
}
vector<int>q[80000];
int main()
{
//	freopen("1.in","r",stdin);
	getprime(1e7);
	int t;
	scanf("%d",&t);
	while(t--)
	{
		int n,k,p;
		long long ans=1;
		scanf("%d%d%d",&n,&k,&p);
//		cout<<n<<" "<<k<<" "<<p<<"\n";
		init(n,k,phi(p));
		int maxn=0;
		for(int i=1;i<=n;i++)
		{
			scanf("%d",&a[i]);
			int x=a[i];
			for(int j=1;prime[j]*prime[j]<=x;j++)
			{
				if(x%prime[j]==0)
				{
					q[prime[j]].push_back(i);
					maxn=max(maxn,prime[j]);
					while(x%prime[j]==0)x/=prime[j];
			 	}
			}
			if(x!=1)
			{
				q[x].push_back(i);
				maxn=max(maxn,x);
			}
//			cout<<x<<"\n";
		}
		for(int i=1;prime[i]<=maxn;i++)
		{
//			cout<<"cs"<<prime[i]<<" "<<q[prime[i]].size()<<"\n";
			while(q[prime[i]].size()>=k)
			{
				int m=q[prime[i]].size();
//				cout<<m<<" "<<k<<" "<<prime[i]<<"\n";
//				for(int j=0;j<m;j++)
//				cout<<q[prime[i]][j]<<" ";
				ans=mul(ans,ksm(prime[i],C[m][k],p),p);
				cnt=0;
				for(int j=q[prime[i]][0];j<m;j++)
				zhan[++cnt]=q[prime[i]][j];
				q[prime[i]].clear();
				for(int j=1;j<=cnt;j++)
				{
					a[zhan[j]]/=prime[i];
					if(a[zhan[j]]%prime[i]==0)q[prime[i]].push_back(zhan[j]);
				}
//				cout<<ans<<"\n";
//				cout<<q[prime[i]].size()<<"\n";
			}
			q[prime[i]].clear();
		}
		cout<<ans<<"\n";
	}
	return 0;
}

交上去还是WA,继续排查,发现

cnt=0;
for(int j=q[prime[i]][0];j<m;j++)
    zhan[++cnt]=q[prime[i]][j];

这一块本来是想将当前vector中存储的值取出来,放在zhan数组里,在接下来进行除以质数的操作,但我不知道当时自己在想什么,j的起始条件设为了那个奇怪的东西,此处应该设为0

于是这段改为了

 cnt=0;
 for(int j=0;j<m;j++)
 zhan[++cnt]=q[prime[i]][j];

继续提交,这次终于没WA,它T了……

继续找原因

出题人在比赛时发过公告此题卡常,开始怀疑自己被卡常了

加入快读

inline int read()
{
        int res=0;
        char c=getchar();
        while(c<'0'||c>'9')c=getchar();
        while(c>='0'&&c<='9')
        {
                res=(res<<1)+(res<<3)+c-'0';
        c=getchar();
        }
        return res;
}

同时对中间部分逻辑进行修改,又产生了一份代码

#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
int C[40001][31];
int a[40010];
int prime[1000010],tot;
bool bj[10000010];
long long mul(long long a,long long b,long long p)
{
	long long res=0;
	while(b)
	{
		if(b&1)res=(res+a)%p;
		a=(a+a)%p;
		b>>=1;
	}
	return res;
}
long long ksm(long long a,long long b,long long p)
{
	long long res=1;
	while(b)
	{
		if(b&1)res=mul(res,a,p);
		a=mul(a,a,p);
		b>>=1;
	}
	return res;
}
int zhan[1000010],cnt;
void getprime(int x)
{
	for(int i=2;i<=x;i++)
	{
		if(!bj[i])prime[++tot]=i;
		for(int j=1;i*prime[j]<=x&&j<=tot;j++)
		{
			bj[i*prime[j]]=1;
			if(i%prime[j]==0)break;
		}
	}
}
inline int read()
{
	int res=0;
	char c=getchar();
	while(c<'0'||c>'9')c=getchar();
	while(c>='0'&&c<='9')
	{
		res=(res<<1)+(res<<3)+c-'0';
        c=getchar();
	}
	return res;
}
long long phi(long long x)
{
	long long ans=x;
	for(int i=1;i<=tot;i++)
	{
		if(x%prime[i]==0)
		{
			ans=ans/prime[i]*(prime[i]-1);
			while(x%prime[i]==0)x/=prime[i];
		}
		if(x<prime[i])break;
	}
	if(x!=1)
	ans=ans/x*(x-1);
	return ans;
}
void init(int n,int k,int p)
{
	C[0][0]=1;
	for(int i=1;i<=n;i++)
	C[i][0]=1;
	for(int i=1;i<=n;i++)
	for(int j=1;j<=k;j++)
	C[i][j]=(C[i-1][j]+C[i-1][j-1])%p;
}
vector<int>q[80000];
int main()
{
//	freopen("1.in","r",stdin);
	getprime(1e7);
//	int ks=clock();
	int t;
	scanf("%d",&t);
	int timel=0;
	while(t--)
	{
//		cout<<++timel<<" ";
		long long n,k,p;
		long long ans=1;
		scanf("%lld%lld%lld",&n,&k,&p);
//		cout<<phi(p)<<"\n";
//		cout<<n<<" "<<k<<" "<<p<<"\n";
//int ks=clock();
		init(n,k,phi(p));
//		int js=clock();
//		cout<<js-ks<<"\n";
		int maxn=0;
		for(int i=1;i<=n;i++)
		{
			a[i]=read();
//		scanf("%d",&a[i]);
			int x=a[i];
			for(int j=1;prime[j]*prime[j]<=x;j++)
			{
				if(x%prime[j]==0)
				{
					q[prime[j]].push_back(i);
					maxn=max(maxn,prime[j]);
					while(x%prime[j]==0)x/=prime[j];
			 	}
			}
			if(x!=1)
			{
				q[x].push_back(i);
				maxn=max(maxn,x);
			}
//			cout<<x<<"\n";
		}
		for(int i=1;prime[i]<=maxn;i++)
		{
//			cout<<"cs"<<prime[i]<<" "<<q[prime[i]].size()<<"\n";
			int m=q[prime[i]].size();
			while(m>=k)
			{
//				cout<<m<<" "<<k<<" "<<prime[i]<<"\n";
//				for(int j=0;j<m;j++)
//				cout<<q[prime[i]][j]<<" ";
				ans=mul(ans,ksm(prime[i],C[m][k],p),p);
				cnt=0;
				for(int j=0;j<m;j++)
				zhan[++cnt]=q[prime[i]][j];
				q[prime[i]].clear();
				for(int j=1;j<=cnt;j++)
				{
					a[zhan[j]]/=prime[i];
					if(a[zhan[j]]%prime[i]==0)q[prime[i]].push_back(zhan[j]);
				}
//				cout<<ans<<"\n";
//				cout<<q[prime[i]].size()<<"\n";
				m=q[prime[i]].size();
			}
			q[prime[i]].clear();
		}
		cout<<ans<<"\n";
	}
//	int js=clock();
//	cout<<js-ks<<"\n";
	return 0;
}

主要加入了快读,同时在本地加入clock测试在本地时间,用m替代q[prime[i]].size()来减少对其调用(然而这一做法并没有什么用)

交上去,还是T

看了眼题解,出题人貌似说暴力分解质因数过不去(我就是这么做的),参考题解的思路,改变了中间枚举的方法

for(int i=1;prime[i]<=maxn;i++)
		{
			long long now=0;
			for(int j=prime[i];j<=maxn;j+=prime[i])
			{
				int res=0;
				for(int x=j;x<=maxn;x+=j)
				res+=a[x];
				if(res>=k)
				now=(now+C[res][k])%PHI;
			}
			ans=mul(ans,ksm(prime[i],now,p),p);
//			cout<<ans<<"\n";
		}

此处a[x]存储的信息是大小为x的数有多少个,枚举每个质数以及它们的次幂,对出现次数大于k的进行次幂上的组合计数(注意对phi(p)取模,根据拓展欧拉定理)这里maxn直接仿照题解取了8e4+5,然后得到了如下的代码

#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
int C[40001][31];
int a[80010];
int prime[1000010],tot;
bool bj[10000010];
int maxn=8e4+5;
long long mul(long long a,long long b,long long p)
{
	long long res=0;
	while(b)
	{
		if(b&1)res=(res+a)%p;
		a=(a+a)%p;
		b>>=1;
	}
	return res;
}
long long ksm(long long a,long long b,long long p)
{
	long long res=1;
	while(b)
	{
		if(b&1)res=mul(res,a,p);
		a=mul(a,a,p);
		b>>=1;
	}
	return res;
}
int zhan[1000010],cnt;
void getprime(int x)
{
	for(int i=2;i<=x;i++)
	{
		if(!bj[i])prime[++tot]=i;
		for(int j=1;i*prime[j]<=x&&j<=tot;j++)
		{
			bj[i*prime[j]]=1;
			if(i%prime[j]==0)break;
		}
	}
}
inline int read()
{
	int res=0;
	char c=getchar();
	while(c<'0'||c>'9')c=getchar();
	while(c>='0'&&c<='9')
	{
		res=(res<<1)+(res<<3)+c-'0';
		c=getchar();
	}
	return res;
}
long long phi(long long x)
{
	long long ans=x;
	for(int i=1;i<=tot;i++)
	{
		if(x%prime[i]==0)
		{
			ans=ans/prime[i]*(prime[i]-1);
			while(x%prime[i]==0)x/=prime[i];
		}
		if(x<prime[i])break;
	}
	if(x!=1)
	ans=ans/x*(x-1);
	return ans;
}
void init(int n,int k,int p)
{
	C[0][0]=1;
	for(int i=1;i<=n;i++)
	C[i][0]=1;
	for(int i=1;i<=n;i++)
	for(int j=1;j<=k;j++)
	C[i][j]=(C[i-1][j]+C[i-1][j-1])%p;
}
//vector<int>q[80000];
int main()
{
//	freopen("1.in","r",stdin);
	getprime(1e7);
	int t;
	scanf("%d",&t);
	int timel=0;
	while(t--)
	{
		long long n,k,p;
		long long ans=1;
		scanf("%lld%lld%lld",&n,&k,&p);
		int PHI=phi(p);
		init(n,k,PHI);
		for(int i=0;i<=maxn;i++)
		a[i]=0;
		for(int i=1;i<=n;i++)
		{
			a[read()]++;
		}
//		cout<<"vjb";
		for(int i=1;prime[i]<=maxn;i++)
		{
			long long now=0;
			for(int j=prime[i];j<=maxn;j+=prime[i])
			{
				int res=0;
				for(int x=j;x<=maxn;x+=j)
				res+=a[x];
				if(res>=k)
				now=(now+C[res][k])%PHI;
			}
			ans=mul(ans,ksm(prime[i],now,p),p);
//			cout<<ans<<"\n";
		}
		cout<<ans<<"\n";
	}
	return 0;
}

交上去,直接WA了

经过与std的反复对照发现

long long now=0;
                        for(int j=prime[i];j<=maxn;j+=prime[i])
                        {

这层是用来枚举prime[i]的次幂,但我手一滑写成了+=而不是*=,同时两个8e4的数相乘貌似会爆 int

于是这块改成

 long long now=0;
                        for(long long j=prime[i];j<=maxn;j*=prime[i])
                        {

交上去,这次没WA,它又T了

开始毫无意义的优化,如为了减少对prime[i]的访问,用一个变量o来表示它(真的毫无意义)

最离谱的是我在本地试了组极限数据,貌似结果是我比std跑得快???开始怀疑std出错

将std交了一发过了后,证明std没错,感觉还是自己哪块写炸了,继续debug

先粘一下现在的代码

#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
int C[40001][31];
int a[80010];
int prime[1000010],tot;
bool bj[10000010];
int maxn=8e4;
long long mul(long long a,long long b,long long p)
{
	long long res=0;
	while(b)
	{
		if(b&1)res=(res+a)%p;
		a=(a+a)%p;
		b>>=1;
	}
	return res;
}
long long ksm(long long a,long long b,long long p)
{
	long long res=1;
	while(b)
	{
		if(b&1)res=mul(res,a,p);
		a=mul(a,a,p);
		b>>=1;
	}
	return res;
}
int zhan[1000010],cnt;
void getprime(int x)
{
	for(int i=2;i<=x;i++)
	{
		if(!bj[i])prime[++tot]=i;
		for(int j=1;i*prime[j]<=x&&j<=tot;j++)
		{
			bj[i*prime[j]]=1;
			if(i%prime[j]==0)break;
		}
	}
}
inline int read()
{
	int res=0;
	char c=getchar();
	while(c<'0'||c>'9')c=getchar();
	while(c>='0'&&c<='9')
	{
		res=(res<<1)+(res<<3)+c-'0';
		c=getchar();
	}
	return res;
}
long long phi(long long x)
{
	long long ans=x;
	for(int i=1;i<=tot;i++)
	{
		if(x%prime[i]==0)
		{
			ans=ans/prime[i]*(prime[i]-1);
			while(x%prime[i]==0)x/=prime[i];
		}
		if(x<prime[i])break;
	}
	if(x!=1)
	ans=ans/x*(x-1);
	return ans;
}
void init(int n,int k,int p)
{
	C[0][0]=1;
	for(int i=1;i<=n;i++)
	C[i][0]=1;
	for(int i=1;i<=n;i++)
	for(int j=1;j<=k;j++)
	C[i][j]=(C[i-1][j]+C[i-1][j-1])%p;
}
//vector<int>q[80000];
int main()
{
//	freopen("1.in","r",stdin);

	getprime(1e7);int ks=clock();
	int t;
	scanf("%d",&t);
	int timel=0;
	while(t--)
	{
		long long n,k,p;
		long long ans=1;
		scanf("%lld%lld%lld",&n,&k,&p);
		int PHI=phi(p);
		init(n,k,PHI);
		for(int i=0;i<=maxn;i++)
		a[i]=0;
		for(int i=1;i<=n;i++)
		{
			a[read()]++;
		}
//		cout<<"vjb";
		for(int i=1;prime[i]<=maxn;i++)
		{
			long long now=0;
			int o=prime[i];
			for(long long j=o;j<=maxn;j*=o)
			{
				int res=0;
				for(int x=j;x<=maxn;x+=j)
				res+=a[x];
				if(res>=k)
				now=(now+C[res][k])%PHI;
			}
			ans=mul(ans,ksm(o,now,p),p);
//			cout<<ans<<"\n";
		}
		cout<<ans<<"\n";
	}
	int js=clock();
// 	cout<<js-ks<<"\n";
	return 0;
}

通过与std对比发现

                        }
                        ans=mul(ans,ksm(o,now,p),p);

这句话貌似会被调用很多次,如果now为0那么它会多次乘1,修改为

                  		}
                        if(now)
                        ans=mul(ans,ksm(o,now,p),p);

还是T

以为被卡常,开始加小的优化

                               if(res>=k)
                                now=(now+C[res][k])%PHI;
                        }
                        if(now)
                        ans=mul(ans,ksm(o,now,p),p);
                                if(res>=k)
                                now=(now+C[res][k]);
                                if(now>PHI)now-=PHI;
                        }
                        if(now)ans=mul(ans,ksm(o,now,p),p);

将%PHI改为if加-理论上会快一点,然而还是T,继续优化

        for(int j=1;j<=k;j++)
        C[i][j]=(C[i-1][j]+C[i-1][j-1])%p;
        for(int j=1;j<=k;j++)
        {
                C[i][j]=(C[i-1][j]+C[i-1][j-1]);
                if(C[i][j]>p)C[i][j]-=p;
        }

组合数那里加同样的优化,还是T,继续优化

if(b&1)res=(res+a)%p;
                a=(a+a)%p;
                b>>=1;
if(b&1){res=res+a;if(res>p)res-=p;}
                a+=a;
        if(a>p)a-=p;
                b>>=1;

快速乘那块也加同样的优化,还是T,继续

		for(int i=1;i<=n;i++)
        for(int j=1;j<=k;j++)
        {
        for(int i=1;i<=n;i++)
        for(int j=1;j<=min(i,k);j++)
        {

组合数的求法那里加了点小的优化,理论上这里的规模会缩小一半,还是T继续

long long ksm(long long a,long long b,long long p)
long long ksm(long long a,long long b)

突发奇想地认为这里多传了一个参数会慢点,修改后还是T

 						int o=prime[i];
                        for(long long j=o;j<=maxn;j*=o)
                        {
 						int o=prime[i];
                        for(int j=o;j<=maxn;j*=o)
                        {
                                if(now>PHI)now-=PHI;
                        }
                                if(now>PHI)now-=PHI;
                                if(1ll*j*o>maxn)break;
                        }

觉得是这里long long运算相比int慢,改为int去试,还是T

                {
                        a[read()]++;
                }
                {
                        int m=read();
                        a[m]++;
                        maxn=max(maxn,m);
                }

改了对上界maxn的求法,原本是固定为8e4,现改为求当前最大数(虽然说理论上数据可以卡死这个优化,但当时只过了40%的点,尝试了一下)还是T

                                res+=a[x];
                                if(res>=k)
                                now=(now+C[res][k]);
                                res+=a[x];
                                if(res<k)break;
                                now=(now+C[res][k]);

看似这里只是>=和<的区别,但其实下面更优,当前res<k时已经不用再去跑后面的数据了,理论上优化幅度很大

然而还是T

                long long ans=1;
                scanf("%lld%lld%lld",&n,&k,&p);
                int PHI=phi(p);
                long long ans=1;
                n=read();k=read();p=lread();
//              scanf("%lld%lld%lld",&n,&k,&p);
                int PHI=phi(p);

将所有的数据都改为用快读去读,虽然并没有什么作用,还是T

    long long ans=x;
    for(int i=1;i<=tot;i++)
    {

    long long ans=x;
    for(int i=1;prime[i]*prime[i]<=x;i++)
    {

求phi时,改变枚举方法,还是T(回来看这一次修改,不但没什么优化,反而引入了一个新的数据溢出)

    long long ans=x;
    for(int i=1;prime[i]*prime[i]<=x;i++)
    {

    long long ans=x;
    for(int i=1;1ll*prime[i]*prime[i]<=x;i++)
    {

本地测试时发现这里会产生死循环,发现了刚才引入的溢出,不过还是T

 		getprime(1e7);
        getprime(11000000);

在phi出了一次问题后,以为自己边界炸了,加大了枚举质数的范围,不过没什么用

想着继续从优化phi的求法那里入手,加上了miller_rabin去分解phi,代码变动较大,贴上新的代码

#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
int C[40001][31];
int a[80010];
int prime[1000010],tot;
bool bj[11001000];
long long n,k,p;
long long mul(long long a,long long b,long long p)
{
	long long res=0;
	while(b)
	{
		if(b&1){res=res+a;if(res>=p)res-=p;}
		a+=a;
        if(a>p)a-=p;
		b>>=1;
	}
	return res;
}
long long ksm(long long a,long long b,long long p)
{
	long long res=1;
	while(b)
	{
		if(b&1)res=mul(res,a,p);
		a=mul(a,a,p);
		b>>=1;
	}
	return res;
}
int zhan[1000010],cnt;
void getprime(int x)
{
	for(int i=2;i<=x;i++)
	{
		if(!bj[i])prime[++tot]=i;
		for(int j=1;i*prime[j]<=x&&j<=tot;j++)
		{
			bj[i*prime[j]]=1;
			if(i%prime[j]==0)break;
		}
	}
}
inline int read()
{
	int res=0;
	char c=getchar();
	while(c<'0'||c>'9')c=getchar();
	while(c>='0'&&c<='9')
	{
		res=(res<<1)+(res<<3)+c-'0';
		c=getchar();
	}
	return res;
}
inline long long lread()
{
	long long res=0;
	char c=getchar();
	while(c<'0'||c>'9')c=getchar();
	while(c>='0'&&c<='9')
	{
		res=(res<<1)+(res<<3)+c-'0';
		c=getchar();
	}
	return res;
}
bool miller_rabin(long long n)
{
	if(n==2)return true;
	if(n<2||!(n&1))return false;
	long long m=n-1;
	int k=0;
	while((m&1)==0)
	{
		k++;
		m>>=1;
	}
	for(int i=0;i<10;i++)
	{
		long long a=rand()%(n-1)+1;
		long long x=ksm(a,m,n);
		long long y=0;
		for(int j=0;j<k;j++)
		{
			y=mul(x,x,n);
			if(y==1&&x!=1&&x!=n-1)return false;
			x=y;
		}
		if(y!=1)return false;
	}
	return true;
}
long long phi(long long x)
{
	long long ans=x;
	if(miller_rabin(x))
	{
		ans=ans/x*(x-1);return ans;
	}
	for(int i=1;1ll*prime[i]*prime[i]<=x;i++)
	{
		if(x%prime[i]==0)
		{
			ans=ans/prime[i]*(prime[i]-1);
			while(x%prime[i]==0)x/=prime[i];
			if(miller_rabin(x))break;
		}
		if(x<prime[i])break;
		
	}
	if(x!=1)
	ans=ans/x*(x-1);
	return ans;
}
void init(int n,int k,int p)
{
	C[0][0]=1;
	for(int i=1;i<=n;i++)
	C[i][0]=1;
	for(int i=1;i<=n;i++)
	for(int j=1;j<=min(i,k);j++)
	{
		C[i][j]=(C[i-1][j]+C[i-1][j-1]);
		if(C[i][j]>p)C[i][j]-=p;
	}
}
//vector<int>q[80000];
int main()
{
//	freopen("1.in","r",stdin);

	getprime(11000000);int ks=clock();
	int t;
	t=read();
	int timel=0;
	while(t--)
	{
		long long ans=1;
		n=read();k=read();p=lread();
//		scanf("%lld%lld%lld",&n,&k,&p);
		int PHI=phi(p);
		init(n,k,PHI);
		int maxn=0;
		for(int i=1;i<=n;i++)
		{
			int m=read();
			a[m]++;
			maxn=max(maxn,m);
		}
//		cout<<"vjb";
		for(int i=1;prime[i]<=maxn;i++)
		{
			long long now=0;
			int o=prime[i];
			for(int j=o;j<=maxn;j*=o)
			{
				int res=0;
				for(int x=j;x<=maxn;x+=j)
				res+=a[x];
				if(res<k)break;
				now=(now+C[res][k]);
				if(now>PHI)now-=PHI;
				if(1ll*j*o>maxn)break;
			}
			if(now)ans=mul(ans,ksm(o,now,p),p);
//			cout<<ans<<"\n";
		}
		cout<<ans<<"\n";
		for(int i=0;i<=maxn;i++)
		a[i]=0;
	}
//	int js=clock();
//	cout<<js-ks<<"\n";
	return 0;
}

不过还是T

int js=clock();
                if(js-ks>=800)return 0;

认为自己没什么问题后,开始炸测评机,卡边界

跑了将近一页的提交后,得到

 if(js-ks>=13950)return 0;

这样会WA

 if(js-ks>=14000)return 0;

这样会T

其实本来时间卡800ms就差不多了,但牛客测评机跑出来的clock()很奇怪,多次二分后得到这么一个边界

而且WA时跑的时间差不多有400ms,稍微开大点限制后就T了,开始思考程序不是被卡常了,而是出现了死循环

int C[40001][31];
int C[40100][33];

本地测试时发现出现的答案很奇怪,发现数组开小了,影响到了边界的条件,程序中加入了大量的调试点用来判断死循环

#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
int C[40100][33];
int a[80010];
int prime[1000010],tot;
bool bj[11001000];
long long n,k,p;
long long mul(long long a,long long b,long long p)
{
	long long res=0;
	while(b)
	{
		if(b&1){res=res+a;if(res>=p)res-=p;}
		a+=a;
        if(a>p)a-=p;
		b>>=1;
	}
	return res;
}
long long ksm(long long a,long long b,long long p)
{
	long long res=1;
	while(b)
	{
		if(b&1)res=mul(res,a,p);
		a=mul(a,a,p);
		b>>=1;
	}
	return res;
}
int zhan[1000010],cnt;
void getprime(int x)
{
	for(int i=2;i<=x;i++)
	{
		if(!bj[i])prime[++tot]=i;
		for(int j=1;i*prime[j]<=x&&j<=tot;j++)
		{
			bj[i*prime[j]]=1;
			if(i%prime[j]==0)break;
		}
	}
}
inline int read()
{
	int res=0;
	char c=getchar();
	while(c<'0'||c>'9')c=getchar();
	while(c>='0'&&c<='9')
	{
		res=(res<<1)+(res<<3)+c-'0';
		c=getchar();
	}
	return res;
}
inline long long lread()
{
	long long res=0;
	char c=getchar();
	while(c<'0'||c>'9')c=getchar();
	while(c>='0'&&c<='9')
	{
		res=(res<<1)+(res<<3)+c-'0';
		c=getchar();
	}
	return res;
}
bool miller_rabin(long long n)
{
	if(n==2)return true;
	if(n<2||!(n&1))return false;
	long long m=n-1;
	int k=0;
	while((m&1)==0)
	{
		k++;
		m>>=1;
	}
	for(int i=0;i<10;i++)
	{
		long long a=rand()%(n-1)+1;
		long long x=ksm(a,m,n);
		long long y=0;
		for(int j=0;j<k;j++)
		{
			y=mul(x,x,n);
			if(y==1&&x!=1&&x!=n-1)return false;
			x=y;
		}
		if(y!=1)return false;
	}
	return true;
}
long long phi(long long x)
{
	long long ans=x;
	if(miller_rabin(x))
	{
		ans=ans/x*(x-1);return ans;
	}
	for(int i=1;1ll*prime[i]*prime[i]<=x;i++)
	{
		if(x%prime[i]==0)
		{
			ans=ans/prime[i]*(prime[i]-1);
			while(x%prime[i]==0)x/=prime[i];
			if(miller_rabin(x))break;
		}
		if(x<prime[i])break;
		
	}
	if(x!=1)
	ans=ans/x*(x-1);
	return ans;
}
void init(int n,int k,int p)
{
	C[0][0]=1;
	for(int i=1;i<=n;i++)
	C[i][0]=1;
	for(int i=1;i<=n;i++)
	for(int j=1;j<=min(i,k);j++)
	{
		C[i][j]=(C[i-1][j]+C[i-1][j-1]);
		if(C[i][j]>p)C[i][j]-=p;
	}
}
//vector<int>q[80000];
int main()
{
// 	freopen("1.in","r",stdin);
	getprime(11000000);
//	int ks=clock();
	int t;
	t=read();
	int timel=0;
	int ks=clock();
	for(int i=0;i<=80000;i++)a[i]=0;
	
//	cout<<"4y38\n";
	while(t--)
	{
		long long ans=1;
		n=read();k=read();p=lread();
//		scanf("%lld%lld%lld",&n,&k,&p);

		int PHI=phi(p);
		
		init(n,k,PHI);
//		for(int i=0;i<=80000;i++)cout<<a[i]<<"\n";
		int maxn=0;
//		for(int i=1;i<=n;i++)
//		cout<<i<<" "<<a[i]<<"\n";
		for(int i=1;i<=n;i++)
		{
			int m=read();
//			cout<<m<<"\n";
			a[m]++;
//			cout<<a[m]<<"\n";
			maxn=max(maxn,m);
		}
//		cout<<"vjb";
		for(int i=1;prime[i]<=maxn;i++)
		{
//		cout<<"4y38\n";
			long long now=0;
			int o=prime[i];
			for(int j=o;j<=maxn;j*=o)
			{
				int res=0;
				for(int x=j;x<=maxn;x+=j)
				{
					res+=a[x];
//					cout<<x<<" "<<a[x]<<"\n";
				}
//				cout<<res<<"\n";
				if(res<k)break;
				now=(now+C[res][k]);
				if(now>PHI)now-=PHI;
				if(1ll*j*o>maxn)break;
			}
			if(now)ans=mul(ans,ksm(o,now,p),p);
//			cout<<ans<<"\n";
		}
		cout<<ans<<"\n";
		for(int i=0;i<=maxn;i++)
		a[i]=0;
//		int js=clock();
//		if(js-ks>=800)return 0;
	}
//	int js=clock();
//	cout<<js-ks<<"\n";
	return 0;
}

改完组合数的问题后,又交了一发,还是T

int C[40100][33];
long long C[40100][33];

继续在组合数那里找问题,突然意识到求得数%phi(p)意义下的组合数,而p的范围是le14爆了int,组合数里会出现负值,紧接着就会影响下面快速幂那里幂次传入一个负数,有可能导致死循环,交上去,还是T

开始以为还是自己组合数那里求慢了,抄了std求组合数的部分

#define rep(i,a,b) for(int i=a,i##end=b;i<=i##end;++i)
void init(int n,int m,long long p)
{
        rep(i,0,n) rep(j,*C[i]=1,min(i,m))
        {
                C[i][j]=C[i-1][j]+C[i-1][j-1];
                if(C[i][j]>=p)C[i][j]-=p;
        }
//      C[0][0]=1;
//      for(int i=1;i<=n;i++)
//      C[i][0]=1;
//      for(int i=1;i<=n;i++)
//      for(int j=1;j<=min(i,k);j++)
//      {
//              C[i][j]=(C[i-1][j]+C[i-1][j-1]);
//              if(C[i][j]>p)C[i][j]-=p;
//      }
}

然而并没有什么用

又在本地加了许多调试信息,跑多次对拍后突然发现

 int PHI=phi(p);

我写了这么一个东西,明明刚刚分析过组合数出现负值会死循环,现在可好,模数都可能是负的,测试点中随便来个大质数就会炸

改完后居然过了,时间 223 ms,粘一下最后的代码(含有大量调试信息)

#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
long long C[40100][33];
int a[80010];
int prime[1000010],tot;
bool bj[11001000];
long long n,k,p;
long long mul(long long a,long long b,long long p)
{
	long long res=0;
	while(b)
	{
		if(b&1){res=res+a;if(res>=p)res-=p;}
		a+=a;
        if(a>p)a-=p;
		b>>=1;
	}
	return res;
}
long long ksm(long long a,long long b,long long p)
{
	long long res=1;
	while(b)
	{
		if(b&1)res=mul(res,a,p);
		a=mul(a,a,p);
		b>>=1;
	}
	return res;
}
int zhan[1000010],cnt;
void getprime(int x)
{
	for(int i=2;i<=x;i++)
	{
		if(!bj[i])prime[++tot]=i;
		for(int j=1;i*prime[j]<=x&&j<=tot;j++)
		{
			bj[i*prime[j]]=1;
			if(i%prime[j]==0)break;
		}
	}
}
inline int read()
{
	int res=0;
	char c=getchar();
	while(c<'0'||c>'9')c=getchar();
	while(c>='0'&&c<='9')
	{
		res=(res<<1)+(res<<3)+c-'0';
		c=getchar();
	}
	return res;
}
inline long long lread()
{
	long long res=0;
	char c=getchar();
	while(c<'0'||c>'9')c=getchar();
	while(c>='0'&&c<='9')
	{
		res=(res<<1)+(res<<3)+c-'0';
		c=getchar();
	}
	return res;
}
bool miller_rabin(long long n)
{
	if(n==2)return true;
	if(n<2||!(n&1))return false;
	long long m=n-1;
	int k=0;
	while((m&1)==0)
	{
		k++;
		m>>=1;
	}
	for(int i=0;i<10;i++)
	{
		long long a=rand()%(n-1)+1;
		long long x=ksm(a,m,n);
		long long y=0;
		for(int j=0;j<k;j++)
		{
			y=mul(x,x,n);
			if(y==1&&x!=1&&x!=n-1)return false;
			x=y;
		}
		if(y!=1)return false;
	}
	return true;
}
long long phi(long long x)
{
	long long ans=x;
	if(miller_rabin(x))
	{
		ans=ans/x*(x-1);return ans;
	}
	for(int i=1;1ll*prime[i]*prime[i]<=x;i++)
	{
		if(x%prime[i]==0)
		{
			ans=ans/prime[i]*(prime[i]-1);
			while(x%prime[i]==0)x/=prime[i];
			if(miller_rabin(x))break;
		}
		if(x<prime[i])break;
		
	}
	if(x!=1)
	ans=ans/x*(x-1);
	return ans;
}
//#define rep(i,a,b) for(int i=a,i##end=b;i<=i##end;++i)
void init(int n,int k,long long p)
{
//	rep(i,0,n) rep(j,*C[i]=1,min(i,m)) 
//	{
//		C[i][j]=C[i-1][j]+C[i-1][j-1];
//		if(C[i][j]>=p)C[i][j]-=p;
//	}
	C[0][0]=1;
	for(int i=1;i<=n;i++)
	C[i][0]=1;
	for(int i=1;i<=n;i++)
	for(int j=1;j<=min(i,k);j++)
	{
		C[i][j]=(C[i-1][j]+C[i-1][j-1]);
		if(C[i][j]>p)C[i][j]-=p;
	}
}
//vector<int>q[80000];
int main()
{
//	freopen("1.in","r",stdin);int ks=clock();
	getprime(11000000);
//int js=clock();
//cout<<js-ks<<"\n";	
	int t;
	t=read();
	int timel=0;
//	int ks=clock();
	for(int i=0;i<=80000;i++)a[i]=0;
	
//	cout<<"4y38\n";
	while(t--)
	{
		long long ans=1;
		n=read();k=read();p=lread();
//		scanf("%lld%lld%lld",&n,&k,&p);
		long long PHI=phi(p);
		init(n,k,PHI);
//		cout<<PHI<<"\n";
//		for(int i=0;i<=80000;i++)cout<<a[i]<<"\n";
		int maxn=0;
//		for(int i=1;i<=n;i++)
//		cout<<i<<" "<<a[i]<<"\n";

		for(int i=1;i<=n;i++)
		{
			int m=read();
//			cout<<m<<"\n";
			a[m]++;
//			cout<<a[m]<<"\n";
			maxn=max(maxn,m);
		}
		
//		cout<<"vjb";
		for(int i=1;prime[i]<=maxn;i++)
		{
//		cout<<"4y38\n";

			long long now=0;
			int o=prime[i];
			for(int j=o;j<=maxn;j*=o)
			{
				
				int res=0;
				for(int x=j;x<=maxn;x+=j)
				{
//					cout<<res<<" "<<x<<" "<<maxn<<" "<<j<<"\n";
					res+=a[x];
//					cout<<x<<" "<<a[x]<<"\n";
				}
//				cout<<res<<"\n";
				if(res<k)break;
				now=(now+C[res][k]);
				if(now>PHI)now-=PHI;
				if(1ll*j*o>maxn)break;
			}
//			cout<<now<<"\n";
			if(now)ans=mul(ans,ksm(o,now,p),p);
//			cout<<ans<<"\n";
		}
		cout<<ans<<"\n";
		for(int i=0;i<=maxn;i++)
		a[i]=0;
//		int js=clock();
//		if(js-ks>=800)return 0;
	}
//	int
//	js=clock();
//	cout<<js-ks<<"\n";
	return 0;
}

随后顺手试了之前的思路,在组合数不炸的情况下 跑过了80%的测试点,这次应该真的是复杂度偏高