// Straightforward heavy-light decomposition (树链剖分) problem.
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
// Argument packs for recursing into the left / right child of segment-tree
// node o covering [L, R]; they expect `o`, `L`, `R` and `mid` in scope.
#define lson o << 1, L, mid
#define rson o << 1 | 1, mid+1, R
// Indices of the left / right child of segment-tree node o.
#define ls o << 1
#define rs o << 1 | 1
const int INF = 0x3f3f3f3f * 2;  // sentinel meaning "no pending assignment" in node::lazy
const int maxn = 500005;         // capacity for vertices (1-indexed)
// Each undirected tree edge is stored in both directions, so the edge pool
// must hold 2 * (n - 1) entries -- roughly twice the vertex capacity.
const int maxm = maxn * 2;
// Heavy-light decomposition data, all indexed by vertex.
// NOTE: under C++17, `using namespace std` makes an unqualified `size`
// ambiguous with std::size; refer to this array as `::size`.
int size[maxn];   // subtree size
int son[maxn];    // heavy child (0 if none)
int dep[maxn];    // depth (the root has depth 0)
int top[maxn];    // head of the heavy chain containing the vertex
int fa[maxn];     // parent (0 for the root)
int mpp[maxn];    // inverse of w: segment-tree position -> vertex
int w[maxn];      // segment-tree position of the vertex
int a[maxn];      // vertex weights read from input
int n, m, z;      // vertex count, operation count, next position counter
// Singly linked adjacency-list edge, allocated from a fixed pool.
struct Edge
{
    int v;       // endpoint this edge leads to
    Edge *next;  // next edge in the same vertex's list, or null
};
Edge E[maxm];    // edge pool
Edge *H[maxn];   // head of each vertex's adjacency list
Edge *edges;     // next free slot in E
// Segment-tree node summarising a contiguous range of positions.
// NOTE(review): every field is int; length * value products may overflow for
// long ranges with large weights -- confirm against the problem's bounds.
struct node
{
    int tmax;  // best sum of a non-empty contiguous run inside the range
    int lmax;  // best sum of a non-empty prefix of the range
    int rmax;  // best sum of a non-empty suffix of the range
    int sum;   // sum of the whole range
    int lazy;  // pending "assign every element" value, or INF when none
    node() {}
    node(int best, int pre, int suf, int total, int pend)
        : tmax(best), lmax(pre), rmax(suf), sum(total), lazy(pend) {}
};
node tree[maxn << 2];  // heap layout rooted at 1: children of o are 2o and 2o+1
// Reset per-test-case state: clear every adjacency-list head, rewind the
// edge pool cursor, and restart the position counter.
void init()
{
    memset(H, 0, sizeof H);
    edges = E;
    z = 0;
}
void addedges(int u, int v)
{
edges->v = v;
edges->next = H[u];
H[u] = edges++;
}
void pushup(int o)
{
tree[o].sum = tree[ls].sum + tree[rs].sum;
tree[o].tmax = max(tree[ls].tmax, tree[rs].tmax);
tree[o].tmax = max(tree[o].tmax, tree[ls].rmax + tree[rs].lmax);
tree[o].lmax = max(tree[ls].lmax, tree[ls].sum + tree[rs].lmax);
tree[o].rmax = max(tree[rs].rmax, tree[rs].sum + tree[ls].rmax);
}
// If node o (covering [L, R]) carries a pending assignment, apply it to both
// children and clear o's tag.
void pushdown(int o, int L, int R)
{
    const int v = tree[o].lazy;
    if(v == INF) return;  // nothing pending
    const int mid = (L + R) >> 1;
    const int leftSum = v * (mid - L + 1);
    const int rightSum = v * (R - mid);
    // All elements equal v: the best non-empty run is one element when v < 0,
    // otherwise the whole range.
    const int leftBest = max(v, leftSum);
    const int rightBest = max(v, rightSum);
    tree[ls] = node(leftBest, leftBest, leftBest, leftSum, v);
    tree[rs] = node(rightBest, rightBest, rightBest, rightSum, v);
    tree[o].lazy = INF;
}
// Combine summaries of two adjacent ranges, lhs immediately before rhs.
// The lazy field of the result is left uninitialised; callers only read
// tmax / lmax / rmax / sum from it.
node merge(node lhs, node rhs)
{
    node res;
    res.sum = lhs.sum + rhs.sum;
    res.lmax = max(lhs.lmax, lhs.sum + rhs.lmax);
    res.rmax = max(rhs.rmax, rhs.sum + lhs.rmax);
    res.tmax = max(max(lhs.tmax, rhs.tmax), lhs.rmax + rhs.lmax);
    return res;
}
// Build segment-tree node o over positions [L, R]; position p holds the
// weight of vertex mpp[p]. No node carries a pending assignment afterwards.
void build(int o, int L, int R)
{
    if(L == R) {
        const int val = a[mpp[L]];
        tree[o] = node(val, val, val, val, INF);
        return;
    }
    tree[o] = node(0, 0, 0, 0, INF);
    int mid = (L + R) >> 1;  // named `mid`: the lson / rson macros use it
    build(lson);
    build(rson);
    pushup(o);
}
// Assign val to every position in [ql, qr] within node o's range [L, R].
// Fully covered nodes are overwritten and tagged lazily.
void update(int o, int L, int R, int ql, int qr, int val)
{
    if(L >= ql && R <= qr) {
        const int total = (R - L + 1) * val;
        const int best = max(val, total);
        tree[o] = node(best, best, best, total, val);
        return;
    }
    pushdown(o, L, R);
    int mid = (L + R) >> 1;  // named `mid`: the lson / rson macros use it
    if(ql <= mid) update(lson, ql, qr, val);
    if(mid < qr) update(rson, ql, qr, val);
    pushup(o);
}
// Return the summary of positions [ql, qr] ∩ [L, R] under node o. Pending
// assignments on the visited path are pushed down as a side effect.
node query(int o, int L, int R, int ql, int qr)
{
    if(L >= ql && R <= qr) return tree[o];
    pushdown(o, L, R);
    int mid = (L + R) >> 1;  // named `mid`: the lson / rson macros use it
    node res;
    if(ql > mid) {
        res = query(rson, ql, qr);
    } else if(qr <= mid) {
        res = query(lson, ql, qr);
    } else {
        res = merge(query(lson, ql, qr), query(rson, ql, qr));
    }
    pushup(o);
    return res;
}
void dfs1(int u)
{
size[u] = 1, son[u] = 0;
for(Edge *e = H[u]; e; e = e->next) {
if(e->v != fa[u]) {
dep[e->v] = dep[u] + 1;
fa[e->v] = u;
dfs1(e->v);
size[u] += size[e->v];
if(size[son[u]] < size[e->v]) son[u] = e->v;
}
}
}
// Second heavy-light pass over the subtree rooted at u. Assigns each vertex:
//   - w[]: its segment-tree position, in pre-order with the heavy child first;
//   - top[]: the head of the heavy chain it belongs to.
// u itself gets chain head tp. Reads son[] and fa[] as filled by dfs1 and
// advances the global counter z.
//
// Uses an explicit stack instead of recursion, because a path-shaped tree
// with ~maxn vertices would otherwise recurse ~maxn frames deep. The
// numbering is identical to the recursive formulation.
void dfs2(int u, int tp)
{
    // Pending (vertex, chain head) pairs; the back of the vector is numbered next.
    std::vector<std::pair<int, int> > pending;
    std::vector<int> light;
    pending.push_back(std::make_pair(u, tp));
    while(!pending.empty()) {
        int v = pending.back().first;
        int head = pending.back().second;
        pending.pop_back();
        w[v] = ++z, top[v] = head;
        // Light children start their own chains. Push them in reverse
        // adjacency order so they are popped in adjacency order.
        light.clear();
        for(Edge *e = H[v]; e; e = e->next) {
            if(e->v != fa[v] && e->v != son[v]) light.push_back(e->v);
        }
        for(int i = (int)light.size() - 1; i >= 0; i--) {
            pending.push_back(std::make_pair(light[i], light[i]));
        }
        // The heavy child is pushed last so it is numbered right after v,
        // which keeps each heavy chain contiguous in w[].
        if(son[v]) pending.push_back(std::make_pair(son[v], head));
    }
}
void solve1(int a, int b, int c)
{
int f1 = top[a], f2 = top[b];
while(f1 != f2) {
if(dep[f1] < dep[f2]) {
swap(a, b);
swap(f1, f2);
}
update(1, 1, n, w[f1], w[a], c);
a = fa[f1], f1 = top[a];
}
if(dep[a] > dep[b]) swap(a, b);
update(1, 1, n, w[a], w[b], c);
}
void solve2(int a, int b)
{
int f1 = top[a], f2 = top[b];
node ans1, ans2;
int flag1 = 0, flag2 = 0;
while(f1 != f2) {
if(dep[f1] > dep[f2]) {
if(flag1 == 0) flag1 = 1, ans1 = query(1, 1, n, w[f1], w[a]);
else ans1 = merge(query(1, 1, n, w[f1], w[a]), ans1);
a = fa[f1], f1 = top[a];
}
else {
if(flag2 == 0) flag2 = 1, ans2 = query(1, 1, n, w[f2], w[b]);
else ans2 = merge(query(1, 1, n, w[f2], w[b]), ans2);
b = fa[f2], f2 = top[b];
}
}
if(dep[a] > dep[b]) {
if(flag1 == 0) flag1 = 1, ans1 = query(1, 1, n, w[b], w[a]);
else ans1 = merge(query(1, 1, n, w[b], w[a]), ans1);
}
else {
if(flag2 == 0) flag2 = 1, ans2 = query(1, 1, n, w[a], w[b]);
else ans2 = merge(query(1, 1, n, w[a], w[b]), ans2);
}
int res = 0;
if(flag1 == 0 && flag2 == 1) {
res = ans2.tmax;
}
else if(flag1 == 1 && flag2 == 0) {
res = ans1.tmax;
}
else {
res = max(ans1.tmax, ans2.tmax);
res = max(res, ans1.lmax + ans2.lmax);
}
printf("%d\n", res);
}
// Handle one test case after n and m are read: read weights and edges,
// decompose the tree rooted at 1, build the segment tree, then run m
// operations. Each operation line has four integers: op x y c.
void work()
{
    for(int v = 1; v <= n; v++) scanf("%d", &a[v]);
    for(int k = 1; k < n; k++) {
        int x, y;
        scanf("%d%d", &x, &y);
        addedges(x, y);
        addedges(y, x);
    }
    dfs1(1);
    dfs2(1, 1);
    for(int v = 1; v <= n; v++) mpp[w[v]] = v;
    build(1, 1, n);
    while(m--) {
        int op, x, y, c;
        scanf("%d%d%d%d", &op, &x, &y, &c);
        switch(op) {
        case 1:
            solve1(x, y, c);  // path assignment
            break;
        case 2:
            solve2(x, y);     // path max-run query (c is read but unused)
            break;
        }
    }
}
// Process test cases, each starting with "n m", until input runs out.
int main()
{
    // freopen("data", "r", stdin);  // local debugging input
    for(;;) {
        if(scanf("%d%d", &n, &m) == EOF) break;
        init();
        work();
    }
    return 0;
}