题目大意:给定一棵树,多次将某个点设为关键点或取消关键点,求虚树中边长总和的二倍
Orz wyfcyx
首先我们考虑树链的并(每个点到根节点的链的并集)怎么求
将虚树中的所有点按照DFS序排序,将每个点的深度统计入答案,将相邻两个点之间的LCA的深度从答案中扣除,就是所有点到根的链的并集的长度
但是我们要求的是虚树中的边长总和,因此我们还要减掉所有点LCA的深度
现在要求动态维护,因此我们用set维护一下虚树的DFS序即可
各种开小数组忘开long long忘改lld我这是怎么了……
#include <set>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define M 100100
using namespace std;
struct abcd{
int to,f,next;
}table[M<<1];
int head[M],tot;
int n,m;
long long ans;
int pos[M],log_2[M<<1],T;
long long dis[M<<1],a[M<<1][18];
bool status[M<<1];
set<int> seq;
void Add(int x,int y,int z)
{
table[++tot].to=y;
table[tot].f=z;
table[tot].next=head[x];
head[x]=tot;
}
void DFS(int x,int from)
{
int i;
a[pos[x]=++T][0]=dis[x];
for(i=head[x];i;i=table[i].next)
if(table[i].to!=from)
{
dis[table[i].to]=dis[x]+table[i].f;
DFS(table[i].to,x);
a[++T][0]=dis[x];
}
}
long long RMQ(int x,int y)
{
int len=log_2[y-x+1];
return min(a[x][len],a[y-(1<<len)+1][len]);
}
void Insert(int x)
{
set<int>::iterator it=seq.insert(x).first;
ans+=a[x][0];
if(seq.size()==1)
return ;
if(it==seq.begin())
{
set<int>::iterator secc=it;++secc;
ans-=RMQ(*it,*secc);
return ;
}
if((++it)--==seq.end())
{
set<int>::iterator pred=it;--pred;
ans-=RMQ(*pred,*it);
return ;
}
set<int>::iterator secc=it;++secc;
set<int>::iterator pred=it;--pred;
ans+=RMQ(*pred,*secc);
ans-=RMQ(*it,*secc);
ans-=RMQ(*pred,*it);
}
void Erase(int x)
{
set<int>::iterator it=seq.find(x);
ans-=a[x][0];
if(seq.size()==1)
{
seq.erase(it);
return ;
}
if(it==seq.begin())
{
set<int>::iterator secc=it;++secc;
ans+=RMQ(*it,*secc);
seq.erase(it);
return ;
}
if((++it)--==seq.end())
{
set<int>::iterator pred=it;--pred;
ans+=RMQ(*pred,*it);
seq.erase(it);
return ;
}
set<int>::iterator secc=it;++secc;
set<int>::iterator pred=it;--pred;
ans-=RMQ(*pred,*secc);
ans+=RMQ(*it,*secc);
ans+=RMQ(*pred,*it);
seq.erase(it);
}
long long LCA_Distance()
{
if(seq.size()==0)
return 0;
set<int>::iterator st=seq.begin();
set<int>::iterator ed=seq.end();ed--;
return RMQ(*st,*ed);
}
int main()
{
int i,j,x,y,z;
cin>>n>>m;
for(i=1;i<n;i++)
{
scanf("%d%d%d",&x,&y,&z);
Add(x,y,z);Add(y,x,z);
}
DFS(1,0);
for(i=2;i<=T;i++)
log_2[i]=log_2[i>>1]+1;
for(j=1;j<=log_2[T];j++)
for(i=1;i+(1<<j)-1<=T;i++)
a[i][j]=min(a[i][j-1],a[i+(1<<j-1)][j-1]);
for(i=1;i<=m;i++)
{
scanf("%d",&x);
if(!status[x])
status[x]=true,Insert(pos[x]);
else
status[x]=false,Erase(pos[x]);
printf("%lld\n",ans-LCA_Distance()<<1);
}
return 0;
}