题意:
无向连通图 G 有 n 个点,n−1 条边。
点从 1 到 n 依次编号,编号为 i 的点的权值为 $W_i$ ,每条边的长度均为 1。
图上两点 (u,v) 的距离定义为 u 点到 v 点的最短距离。对于图 G 上的点对 (u,v),若它们的距离为 2,则它们之间会产生$W_v*W_u$ 的联合权值。
请问图 G 上所有可产生联合权值的有序点对中,联合权值最大的是多少?所有联合权值之和是多少?
权值和的结果对10007取模
首先,暴力。。。对每个点爆搜,限制dep=2,统计答案
//暴力70分~~~~~~~~~~~~ #include<cstdio> #include<iostream> #include<cstring> #include<cctype> #include<vector> #include<algorithm> using namespace std; #define int long long #define olinr return #define nmr 205050 #define _ 0 #define love_nmr 0 #define mod 10007 #define DB double int w[nmr]; vector<int> G[nmr]; int n; inline int read() { int x=0,f=1; char ch=getchar(); while(!isdigit(ch)) { if(ch=='-') f=-f; ch=getchar(); } while(isdigit(ch)) { x=(x<<1)+(x<<3)+(ch^48); ch=getchar(); } return x*f; } inline void put(int x) { if(x<0) { x=-x; putchar('-'); } if(x>9) put(x/10); putchar(x%10+'0'); } int ans; int tot; inline void dfs(int fir,int x,int fa,int dep) { if(dep==2) { tot+=(w[fir]%mod)*(w[x]%mod)%mod; tot%=mod; ans=max(ans,w[fir]*w[x]); return; } for(int i=0;i<G[x].size();i++) { int go=G[x][i]; if(go!=fa) dfs(fir,go,x,dep+1); } } signed main() { n=read(); int x,y; for(int i=1;i<n;i++) { x=read(); y=read(); G[x].push_back(y); G[y].push_back(x); } for(int i=1;i<=n;i++) w[i]=read(); for(int i=1;i<=n;i++) dfs(i,i,0,0); put(ans); putchar(' '); put(tot); olinr ~~(0^_^0)+love_nmr; }
打完暴力交上去才发现,woc,(u,v)与(v,u)不等同!!
也就是说,统计权值和应该算两次!而上面的暴力正好做到了QAQ
正解。。。。。。666
因为长度为2,也就是3给点,我们枚举中间点,如图,1,2,3,4,5,6任意组合都成立。。。
所以,对于每个点,直接枚举它直接相连的点,两两组合就行了QAQ
对于权值和,比如我们进行到了4
那么它可以与1,2,3,组合
权值和+=($4*1+4*2+4*3=(1+2+3)*4$)
额
这可以用前缀和。。。
所以,
对每个点
维护一个前缀和
维护一个前缀最大值
最后别忘$*2$和取模就行了
#include<cstdio> #include<iostream> #include<cstring> #include<cctype> #include<vector> #include<algorithm> using namespace std; #define int long long #define olinr return #define nmr 205050 #define _ 0 #define love_nmr 0 #define mod 10007 #define DB double int w[nmr]; vector<int> G[nmr]; int n; int s[nmr]; int maxx[nmr]; inline int read() { int x=0,f=1; char ch=getchar(); while(!isdigit(ch)) { if(ch=='-') f=-f; ch=getchar(); } while(isdigit(ch)) { x=(x<<1)+(x<<3)+(ch^48); ch=getchar(); } return x*f; } inline void put(int x) { if(x<0) { x=-x; putchar('-'); } if(x>9) put(x/10); putchar(x%10+'0'); } int ans; int tot; int cnt; signed main() { n=read(); int x,y; for(int i=1;i<n;i++) { x=read(); y=read(); G[x].push_back(y); G[y].push_back(x); } for(int i=1;i<=n;i++) w[i]=read(); for(int i=1;i<=n;i++) { int siz=G[i].size(); for(int j=0;j<siz;j++) { int go=G[i][j]; (s[j+1]=(s[j]+w[go])%mod)%=mod; maxx[j+1]=max(maxx[j],w[go]); (tot+=(s[j]*w[go])%mod)%=mod; ans=max(ans,maxx[j]*w[go]); } } put(ans); putchar(' '); put((tot<<1)%mod); olinr ~~(0^_^0)+love_nmr; }