Description

给定一棵n个节点的树,边有边权(可能为负)。

你需要删掉恰好K条边,再连上恰好K条边权为0的边,并保证连完边后这还是一棵树,求这棵树的最大的最长路长度。
[JZOJ5641] 林克卡特树【树形DP】【凸优化】_#include

Solution

转化模型

删K条边再加K条边,那么对于新树上的一条路径,一定可以用原树上K+1条点不相交的链来表示它(注意一个单点也可以看做一条链,因此路径长不够的时候可以用单点来补)

那么问题就转化为在原树上选恰好K+1条链,使总长最大。

注意到如果我们直接设[JZOJ5641] 林克卡特树【树形DP】【凸优化】_#define_02为当前做完以[JZOJ5641] 林克卡特树【树形DP】【凸优化】_#include_03为根的子树,选了j条链,i这个点接的链的情况(没有/被子树中不超过一条链接上/被子树中不超过两条链接上(即它不能再接出父亲))

显然这样状态数是[JZOJ5641] 林克卡特树【树形DP】【凸优化】_#include_04的,不能满足要求

考虑优化:

感受一下,如果把K作为横轴,x=K下的最优答案作为纵轴,那平面上就有了n个点,这n个点构成了一个凸包。

也就是说,这是单峰的,并且相邻点连线斜率不增。

可以反证,对于选链的情况进行讨论,再分析一下增量,发现更优的选择一定会在更早选,具体不再赘述。

我们发现,如果没有链数限制,我们可以在[JZOJ5641] 林克卡特树【树形DP】【凸优化】_凸优化_05的时间内求出整体最优解(即凸包最高点的值,也可以求出它用了多少条链)

假如我们将整个凸包整体旋转某个角度(横坐标不变),使我们需要的x=K的点成为整体最高点,那我们就可以快速算出来了。

整体旋转某个角度,等同于用一条过原点的直线去切这个凸包
如下图

[JZOJ5641] 林克卡特树【树形DP】【凸优化】_树形DP_06


实际上,就是对于每一个点,将纵坐标减去横坐标*这条的直线的斜率

考虑这样做在原题目中的体现,相当于每选一条链还需要另外支付一个代价(斜率),求最优解。

此时我们可以二分这个代价(斜率),求出最高点,看最高点的横坐标(选的链数)<K还是>K,来调整二分的斜率。

直到最后,我们二分出了一个恰当的斜率,使得最高点的横坐标为K,那么输出纵坐标+K*斜率即可。

注意到对于这一题,相邻两个点的横坐标差一定为1,且最优值都是整数,那么我们也只需要在整数中二分斜率。

有一种特殊情况,就是连着很多个点斜率相同,K又不在两端,那么此时算出的最优解横坐标不一定是=k的,但我们发现,这些点纵坐标相等,加上横坐标*斜率以后,仍然能得到横坐标=k时的解,因此是没有问题的。

这样总复杂度就是[JZOJ5641] 林克卡特树【树形DP】【凸优化】_斜率_07

Code

#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <iostream>
#include <cstring>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fod(i,a,b) for(int i=a;i>=b;i--)
#define N 300005
#define LL long long
using namespace std;
LL f[N][2],g[N][2],h[N][2],c[2],d[2],e[2],pr[2*N],mid;
int dt[2*N],nt[2*N],fs[N],n,m,l;
void dp(int k,int fa)
{
h[k][0]=f[k][0]=-mid,h[k][1]=f[k][1]=1;
g[k][0]=0,g[k][1]=0;
for(int i=fs[k];i;i=nt[i])
{
int p=dt[i];
if(p!=fa)
{
dp(p,k);
fo(j,0,1) c[j]=f[k][j],d[j]=g[k][j],e[j]=h[k][j];
if(d[0]<f[k][0]+f[p][0]+pr[i]+mid)
{
d[0]=f[k][0]+f[p][0]+pr[i]+mid;
d[1]=f[k][1]+f[p][1]-1;
}
if(c[0]<h[k][0]+f[p][0]+pr[i]+mid)
{
c[0]=h[k][0]+f[p][0]+pr[i]+mid;
c[1]=h[k][1]+f[p][1]-1;
}
if(c[0]<f[k][0]+g[p][0]) c[0]=f[k][0]+g[p][0],c[1]=f[k][1]+g[p][1];
if(d[0]<g[k][0]+g[p][0]) d[0]=g[k][0]+g[p][0],d[1]=g[k][1]+g[p][1];
if(e[0]<h[k][0]+g[p][0]) e[0]=h[k][0]+g[p][0],e[1]=h[k][1]+g[p][1];
fo(j,0,1) f[k][j]=c[j],g[k][j]=d[j],h[k][j]=e[j];
if(f[k][0]<h[k][0]) f[k][0]=h[k][0],f[k][1]=h[k][1];
if(g[k][0]<f[k][0]) g[k][0]=f[k][0],g[k][1]=f[k][1];
}
}
if(f[k][0]<h[k][0]) f[k][0]=h[k][0],f[k][1]=h[k][1];
if(g[k][0]<f[k][0]) g[k][0]=f[k][0],g[k][1]=f[k][1];
}
void link(int x,int y,int z)
{
nt[++m]=fs[x];
dt[fs[x]=m]=y;
pr[m]=z;
}
int main()
{
cin>>n>>l;
fo(i,1,n-1)
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
link(x,y,z),link(y,x,z);
}
l++;
LL x=-1e6,y=1e8;
while(x<y)
{
mid=(x+y)/2;
fo(i,1,n) f[i][0]=g[i][0]=h[i][0]=-1e9;
dp(1,0);
if(g[1][1]==l)
{
printf("%lld\n",g[1][0]+mid*(LL)l);
return 0;
}
if(g[1][1]<l) y=mid-1;
else x=mid+1;
}
printf("%lld\n",g[1][0]+mid*(LL)l);
}