题目描述

有一棵点数为 N 的树,以点 1 为根,且树点有边权。然后有 M 个操作,分为三种:

操作 1 :把某个节点 x 的点权增加 a 。

操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。

操作 3 :询问某个节点 x 到根的路径中所有点的点权和。

输入格式:

第一行包含两个整数 N, M 。表示点数和操作数。接下来一行 N 个整数,表示树中节点的初始权值。接下来 N-1 行每行两个正整数 from, to , 表示该树中存在一条边 (from, to) 。再接下来 M 行,每行分别表示一次操作。

其中第一个数表示该操作的种类( 1-3 ) ,之后接这个操作的参数( x 或者 x a ) 。

输出格式:

对于每个询问操作,输出该询问的答案。答案之间用换行隔开。

输入样例#1: 


5 5 1 2 3 4 5 1 2 1 4 2 3 2 5 3 3 1 2 1 3 5 2 1 2 3 3

输出样例#1: 


6 9 13

对于 100% 的数据, N,M<=100000 ,且所有输入数据的绝对值都不会超过 10^6 。


DFS序维护

对于单点修改,对整颗子树的贡献为val 于是将整个子树+val (存入第二棵线段树)

对于子树修改,对子树的贡献为(dep[y]-dep[x]+1)*val (y为子树中的一点) 

于是我们将将整颗子树+(1-dep[x])*val (存入第二颗线段树) 然后在将val存入第一棵线段树

为什么呢,因为单点查询时,只需将第二棵线段树的值+第一棵的*dep[要查的点]就可以了

查询即为单点查询,具体看代码就懂了


#include<bits/stdc++.h>
#define LL long long
#define N 100005
using namespace std;
int first[N],next[M],to[M],tot;
int a[N],st[N],ed[N],sign;
int n,m,dep[N]; LL w[N];
struct Node{LL val,tag;}t1[N<<2];
struct Node1{LL val,tag;}t2[N<<2];
int read(){
int cnt=0,f=1;char ch=0;
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch))cnt=cnt*10+(ch-'0'),ch=getchar();
return cnt*f;
}
void add(int x,int y){
next[++tot]=first[x],first[x]=tot,to[tot]=y;
}
void dfs(int u,int f){
st[u]=++sign;
for(int i=first[u];i;i=next[i]){
int t=to[i]; if(t==f) continue;
dep[t]=dep[u]+1,w[t]=(LL)w[u]+a[t];
dfs(t,u);
}ed[u]=sign;
}
void Pushdown(int o,int l,int r){
if(t1[o].tag){
int mid=(l+r)>>1;
int lLen=mid-l+1,rLen=r-mid;
t1[o<<1].val+=(LL)t1[o].tag*lLen;
t1[o<<1|1].val+=(LL)t1[o].tag*rLen;
t1[o<<1].tag+=t1[o].tag;
t1[o<<1|1].tag+=t1[o].tag;
t1[o].tag=0;
}
if(t2[o].tag){
int mid=(l+r)>>1;
int lLen=mid-l+1,rLen=r-mid;
t2[o<<1].val+=(LL)t2[o].tag*lLen;
t2[o<<1|1].val+=(LL)t2[o].tag*rLen;
t2[o<<1].tag+=t2[o].tag;
t2[o<<1|1].tag+=t2[o].tag;
t2[o].tag=0;
}
}
void Pushup(int o){
t1[o].val=t1[o<<1].val+t1[o<<1|1].val;
t2[o].val=t2[o<<1].val+t2[o<<1|1].val;
}
void update(int o,int l,int r,int L,int R,LL x,LL y){
if(L<=l&&r<=R){
t1[o].val+=(LL)x*(r-l+1);
t2[o].val+=(LL)y*(r-l+1);
t1[o].tag+=x;
t2[o].tag+=y;
return;
}
Pushdown(o,l,r);
int mid=(l+r)>>1;
if(L<=mid) update(o<<1,l,mid,L,R,x,y);
if(R>mid) update(o<<1|1,mid+1,r,L,R,x,y);
Pushup(o);
}
LL quary(int o,int l,int r,int x,int id){
if(l==r){
return (LL)t1[o].val*dep[x]+t2[o].val;
}
Pushdown(o,l,r);
int mid=(l+r)>>1;
if(id<=mid) return quary(o<<1,l,mid,x,id);
else return quary(o<<1|1,mid+1,r,x,id);
}
int main(){
n=read(),m=read();
for(int i=1;i<=n;i++) a[i]=read();
for(int i=1;i<n;i++){
int x=read(),y=read();
add(x,y),add(y,x);
}
dep[1]=1,w[1]=a[1],dfs(1,0);
for(int i=1;i<=m;i++){
int op=read(),x=read();
if(op==1){
int val=read();
update(1,1,n,st[x],ed[x],0,val);
}
if(op==2){
int val=read();
update(1,1,n,st[x],ed[x],val,(LL)val*(1-dep[x]));
}
if(op==3){
printf("%lld\n",quary(1,1,n,x,st[x])+w[x]);
}
}
}