最短路+奇妙solution

Problem

给一个\(n\)个点,\(m\)条边的有向图。
给\(k\)个特殊点\(K_1,K_2,\cdot,K_k\)
求\(k\)个特殊点中两两最短路的最小值。
数据范围:
P5304 [GXOI/GZOI2019]旅行者_i++

Solution

Thinking 1

Floyd...好像暴力都没得打。/kk

Thinking 2

DAG!可以拓扑搞。
然而并不会!

Thinking 3

思路真的很妙。
建超级起点\(S\)和超级终点\(T\)。
考虑枚举每个二进制位:
枚举关键点:

  • 若当前关键点的当前二进制位为1,则加边\(S \to K_i\),边权为0.
  • 否则,加边\(K_i \to T\),边权为0.
    对于当前来说,最小的最短路为\(S\to T\)的最短路,易证。
    然后对于每个二进制位还要反着做一遍。因为是有向边,\(x \to y\)的最短路与\(y \to x\)的最短路可能不同。

这样做为什么是对的呢?
其实就是需要证明一个问题:对于每个特殊点对\(K_i,K_j\),他们都至少一次被分到了不同的集合。
那很显然,因为特殊点互不相同,那么至少有一个二进制位不同,就会被分到不同的集合。

# include <bits/stdc++.h>
using namespace std;
const int N = 100005,inf = 1e9 + 7;
int Test;
int n,m,k;
struct edge
{
    int v,w; 
    edge() {}
    edge(int _v,int _w) : v(_v),w(_w) {}
};
vector <edge> g[N];
int K[N];
int dis[N]; bool vis[N];
int S,T;
void dij(void)
{
    priority_queue <pair<int,int>,vector <pair<int,int> > ,greater<pair<int,int> > > q;
    for(int i = 1; i <= n + 2; i++) dis[i] = inf,vis[i] = 0;
    dis[S] = 0;
    q.push(make_pair(0,S));
    while(!q.empty())
    {
        int x = q.top().second;q.pop();
        if(vis[x]) continue;
        vis[x] = 1;
        for(int i = 0; i < (int)g[x].size(); i++)
        {
            int v = g[x][i].v;
            if(dis[v] > dis[x] + g[x][i].w)
            {
                dis[v] = dis[x] + g[x][i].w;
                q.push(make_pair(dis[v],v));
            }
        }
    }
    // for(int i = 1; i <= n + 2; i++) printf("dis[%d] = %d\n",i,dis[i]);
    return;
}
int main(void)
{
    scanf("%d",&Test);
    while(Test--)
    {
        scanf("%d%d%d",&n,&m,&k);
        for(int i = 1; i <= n; i++) g[i].clear();
        for(int i = 1; i <= m; i++)
        {
            int x,y,z; scanf("%d%d%d",&x,&y,&z);
            g[x].push_back(edge(y,z));
        }
        for(int i = 1; i <= k; i++)
        {
            scanf("%d",&K[i]);
        }
        S = n + 1, T = n + 2;
        int ans = inf;
        for(int i = 0; i == 0 || (1 << (i - 1)) <= n; i++)
        {
            g[S].clear(); vector <int> S1;
            for(int j = 1; j <= k; j++)
            {   
                if(K[j] >> i & 1)
                {
                    g[S].push_back(edge(K[j],0));
                }
                else g[K[j]].push_back(edge(T,0)),S1.push_back(K[j]);
            }
            dij();
            ans = min(ans,dis[T]);
            for(int j = 0; j < (int)S1.size(); j++)
            {
                // printf("del = %d\n",S1[j]);
                g[S1[j]].pop_back();
            }
            S1.clear();
            g[S].clear();
            for(int j = 1; j <= k; j++)
            {
                if(!(K[j] >> i & 1))
                {
                    g[S].push_back(edge(K[j],0));
                }
                else g[K[j]].push_back(edge(T,0)),S1.push_back(K[j]);
            }
            dij();
            ans = min(ans,dis[T]);
            for(int j = 0; j < (int)S1.size(); j++)
            {
                g[S1[j]].pop_back();
            }
            S1.clear();
        }
        printf("%d\n",ans);
    }
    return 0;
}