https://www.luogu.org/problemnew/show/P4016

题目描述

GG 公司有 nn 个沿铁路运输线环形排列的仓库,每个仓库存储的货物数量不等。如何用最少搬运量可以使 nn 个仓库的库存数量相同。搬运货物时,只能在相邻的仓库之间搬运。

输入输出格式

输入格式:

文件的第 11 行中有 11 个正整数 nn ,表示有 nn 个仓库。

第 22 行中有 nn 个正整数,表示 nn 个仓库的库存量。

输出格式:

输出最少搬运量。

思路

这题可用最小费用最大流来解决。
显然,移动完后,所有仓库的值都是平均值。

所以,我们把每个仓库的值减去平均值(x)。
然后开始建图

首先建立超级源点s汇点t。
如果x>0,则s向i连一条费用为0,流量为x的边。
如果x<0,则i向t连一条费用为0,流量为x的边。
我们还要保证各个仓库连通,所以每个仓库向两边连流量为∞,费用为1的边

代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
using namespace std;
const int INF=0x3f3f3f3f,maxn=207;
int cnt=1,d,s,t,n;
bool b[maxn];
int list[maxn*12],dis[maxn],pre[maxn],flow[maxn],a[maxn],last[maxn*12];
struct E
{
    int to,next,flow,dis;
}e[maxn*12];
void add(int u,int v,int flow,int dis)
{
    e[++cnt].next=list[u]; e[cnt].to=v; e[cnt].flow=flow; e[cnt].dis=dis;
    list[u]=cnt;
//   printf("u=%d v=%d\n",u,v);
}
bool spfa(int s,int t)
{
    memset(dis,0x3f,sizeof(dis));
    memset(flow,0x3f,sizeof(flow));
    memset(b,0,sizeof(b));
    queue<int> q;
    q.push(s); b[s]=1; dis[s]=0; pre[t]=-1;
    while(!q.empty())
    {
        int u=q.front(); 
        q.pop();
        b[u]=0;
        for(int i=list[u]; i!=-1; i=e[i].next)
        {
            if(e[i].flow>0&&dis[e[i].to]>dis[u]+e[i].dis)
            {
 //             printf("flow=%d d1=%d d2=%d ds=%d\n",e[i].flow,dis[u],e[i].dis,dis[e[i].to]); getchar();
                dis[e[i].to]=dis[u]+e[i].dis;
                pre[e[i].to]=u;
                last[e[i].to]=i;
                flow[e[i].to]=min(flow[u],e[i].flow);
                if (!b[e[i].to])
                {
                    b[e[i].to]=1;
                    q.push(e[i].to);
                }
            }
        }
    }
    return pre[t]!=-1;
}
int f()
{
    int mincost=0;
    while(spfa(s,t))
    {
        int u=t;
        mincost+=flow[t]*dis[t];
        while(u!=s)
        {
            e[last[u]].flow-=flow[t];
            e[last[u]^1].flow+=flow[t];
            u=pre[u];
        }
    }
    return mincost;
}
int main()
{
    memset(list,-1,sizeof(list));
    scanf("%d",&n);
    for(int i=1; i<=n; i++) scanf("%d",&a[i]),d+=a[i];
    d/=n;
    s=n+1;t=n+2;
    for(int i=1; i<=n; i++)
    {
        int x=a[i]-d;
        if(x>0) add(s,i,x,0),add(i,s,0,0);
        if(x<0) add(i,t,-x,0),add(t,i,0,0);
    }
    for(int i=2; i<=n; i++) 
        add(i,i-1,INF,1),add(i-1,i,0,-1),add(i-1,i,INF,1),add(i,i-1,0,-1);
    add(1,n,INF,1),add(n,1,0,-1),add(n,1,INF,1),add(1,n,0,-1);
    printf("%d",f());
}