HDU 4812 D Tree

思路

点对距离相等并且要求输出字典序最小的点对,距离相等不就是点分治裸题了嘛,

照着这个思路出发我们只要记录下所有点对是满足要求的,然后再去找字典序最小的点对就行了,

接下来就是考虑如何求最小点对了,按照路径相加的原理,这里我们处理出所有路径到当前根节点的乘积出来,然后把这个数与k相除得到我们要找的点,当然,这个除法是模意义下的逆元乘法。

然后这题就变成了点分治裸题了。

代码

/*
  Author : lifehappy
*/
#pragma comment(linker,"/STACK:102400000,102400000")
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#include <bits/stdc++.h>

#define mp make_pair
#define pb push_back
#define endl '\n'
#define mid (l + r >> 1)
#define lson rt << 1, l, mid
#define rson rt << 1 | 1, mid + 1, r
#define ls rt << 1
#define rs rt << 1 | 1

using namespace std;

typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;

const double pi = acos(-1.0);
const double eps = 1e-7;
const int inf = 0x3f3f3f3f;

inline ll read() {
    ll f = 1, x = 0;
    char c = getchar();
    while(c < '0' || c > '9') {
        if(c == '-')    f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9') {
        x = (x << 1) + (x << 3) + (c ^ 48);
        c = getchar();
    }
    return f * x;
}

const int N = 1e5 + 10, mod = 1e6 + 3;

int head[N], to[N << 1], nex[N << 1], cnt;

int sz[N], msz[N], value[N], visit[N], n, m, sum, root, tot;

int flag[mod + 10], inv[mod + 10], now[N], pos[N], ans1, ans2;

void add(int x, int y) {
    to[cnt] = y;
    nex[cnt] = head[x];
    head[x] = cnt++;
}

void init() {
    for(int i = 1; i <= n; i++) {
        visit[i] = head[i] = 0;
    }
    cnt = 1, ans1 = ans2 = inf;
}

void get_root(int rt, int fa) {
    sz[rt] = 1, msz[rt] = 0;
    for(int i = head[rt]; i; i = nex[i]) {
        if(to[i] == fa || visit[to[i]]) continue;
        get_root(to[i], rt);
        sz[rt] += sz[to[i]];
        msz[rt] = max(msz[rt], sz[to[i]]);
    }
    msz[rt] = max(msz[rt], sum - sz[rt]);
    if(msz[rt] < msz[root]) root = rt;
}

void get_mult(int rt, int fa, int mult) {
    now[++tot] = mult, pos[tot] = rt;
    for(int i = head[rt]; i; i = nex[i]) {
        if(to[i] == fa || visit[to[i]]) continue;
        get_mult(to[i], rt, 1ll * mult * value[to[i]] % mod);
    }
}

void calc(int rt) {
    tot = 0;
    for(int i = head[rt]; i; i = nex[i]) {
        if(visit[to[i]]) continue;
        int st = tot + 1;
        get_mult(to[i], rt, value[to[i]]);
        int ed = tot;
        for(int j = st; j <= ed; j++) {
            int temp = 1ll * now[j] * value[rt] % mod;
            temp = 1ll * m * inv[temp] % mod;
            int x1 = flag[temp], x2 = pos[j];
            if(x1 > x2) swap(x1, x2);
            if(x1 == inf || x2 == inf) continue;
            if(x1 < ans1 || (x1 == ans1 && x2 < ans2)) ans1 = x1, ans2 = x2;
        }
        for(int j = st; j <= ed; j++) {
            flag[now[j]] = min(pos[j], flag[now[j]]);//不断记录已有的存在的路径的点的最小编号。
        }
    }
    for(int i = 1; i <= tot; i++) {//最重要的重置操作。
        // cout << now[i] << " ";
        flag[now[i]] = inf;
    }
    // cout << endl;
}

void solve(int rt) {
    // cout << rt << endl;
    visit[rt] = 1;
    flag[1] = rt;//这个flag一定要设置为当前根节点。
    calc(rt);
    for(int i = head[rt]; i; i = nex[i]) {
        if(visit[to[i]]) continue;
        sum = sz[to[i]], root = 0, msz[0] = inf;
        get_root(to[i], rt);
        solve(root);
    }
}

int main() {
    // freopen("in.txt", "r", stdin);
    // freopen("out.txt", "w", stdout);
    // ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
    inv[1] = 1;
    for(int i = 2; i < mod; i++) {
        inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
    }
    memset(flag, 0x3f, sizeof flag);
    while(scanf("%d %d", &n, &m) != EOF) {
        init();
        for(int i = 1; i <= n; i++) {
            scanf("%d", &value[i]);
        }
        for(int i = 1; i < n; i++) {
            int x, y;
            scanf("%d %d", &x, &y);
            add(x, y);
            add(y, x);
        }
        sum = n, root = 0, msz[0] = inf;
        get_root(1, 0);
        solve(root);
        if(ans1 == inf || ans2 == inf) {
            puts("No solution");
            continue;
        }
        printf("%d %d\n", ans1, ans2);
    }
	return 0;
}