原题链接:http://poj.org/problem?id=2763
题意:n个点,n-1条点的连线,数据保证任意两点可达,无环,接下来q行操作,两种形式,0 u 表示查询该人到u的时间;1 i w 表示第i条路的时间改为w。
分析:只能用RMQ来做,有点细节处理,都在注释里。
#define _CRT_SECURE_NO_DEPRECATE
#include<iostream>
#include<vector>
#include<cstring>
#include<queue>
#include<stack>
#include<algorithm>
#include<cmath>
#define INF 99999999
#define eps 0.0001
#define N ((1<<12)+10)
using namespace std;
struct Edge
{
int v;
int index;
int next;
};
int n, q, s;
int index;
int cnt;
Edge edge[2 * 100005];
int head[100005];
int dis[100005];
int father[100005];//父节点
int time[100005];//第i条路的时间
int node[2 * 100005];//保存第i次访问的节点
int first[100005];//i这个节点是第几次访问到的
int depth[2 * 100005];//第i次访问的节点的深度
int dp[2 * 100005][25];//dp[i][j]表示从第i次访问开始,连续2^j个访问内,哪次访问的节点深度最小
bool vis[100005];
void add(int u, int v, int index)
{
edge[cnt].v = v; edge[cnt].index = index; edge[cnt].next = head[u]; head[u] = cnt++;
edge[cnt].v = u; edge[cnt].index = index; edge[cnt].next = head[v]; head[v] = cnt++;
}
void dfs(int u, int dep, int fa)
{
index++;
vis[u] = 1;
first[u] = index;
node[index] = u;
depth[index] = dep;
dis[u] = dep;
father[u] = fa;
for (int i = head[u]; i !=-1; i=edge[i].next)
{
int v = edge[i].v;
if (!vis[v])
{
dfs(v, dep + time[edge[i].index], u);
index++;
node[index] = u;
depth[index] = dep;
}
}
}
void ST(int n)
{
int k = log((double)n) / log(2.0);
for (int i = 1; i <= n; i++)
dp[i][0] = i;
for (int j = 1; j <= k; j++)
for (int i = 1; i + (1 << j) - 1 <= n; i++)
{
int a = dp[i][j - 1];
int b = dp[i + (1 << (j - 1))][j - 1];
if (depth[a] < depth[b])
dp[i][j] = a;
else
dp[i][j] = b;
}
}
int LCA(int u, int v)
{
int i = first[u];
int j = first[v];
if (i > j)
swap(i, j);
int k = log(j - i + 1.0) / log(2.0);
int a = dp[i][k];
int b = dp[j - (1 << k) + 1][k];
int x = (depth[a] < depth[b]) ? a : b;
return node[x];
}
void solve(int u,int v)
{
int lca = LCA(u, v);
printf("%d\n", dis[u] + dis[v] - 2 * dis[lca]);
}
void update(int u,int t,int fa)
{
for (int i = head[u]; i != -1; i = edge[i].next)
{
int v = edge[i].v;
if (father[v] == fa)
{
dis[v] += t;
update(v, t, v);
}
}
}
int main()
{
int u, v;
int op;
while (~scanf("%d%d%d", &n, &q, &s))
{
index = cnt = 0;
memset(head, -1, sizeof(head));
memset(vis, 0, sizeof(vis));
depth[1] = 0;
dis[1] = 0;
memset(dp, 0, sizeof(dp));
for (int i = 1; i < n; i++)
{
scanf("%d%d%d", &u, &v, &time[i]);
add(u, v, i);
}
dfs(1, 0, 1);
ST(index);
while (q--)
{
scanf("%d", &op);
if (op)
{
int ith, w;
scanf("%d%d", &ith, &w);
int pos = (ith - 1) * 2;
int x = edge[pos].v;
int y = edge[pos + 1].v;
int ww = time[ith];
time[ith] = w;//数据要更新,记住
if (dis[x] > dis[y])
swap(x, y);
int t = w - ww;
dis[y] += t;
update(y, t, y);
}
else
{
int x;
scanf("%d", &x);
solve(s, x);
s = x;
}
}
}
return 0;
}