脑残了,这题竟然都不会.
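首先,住所和办公室在同一侧的居民不用过桥,直接累加 $|x-y|$;跨河的居民固定多付 1 的过桥代价,也单独累加.剩下的代价是 $\sum_i(|x_i-b|+|y_i-b|)$,所以 $k=1$ 时把所有跨河居民的左右端点放在一起排序,取中位数作为桥的位置显然是最优的.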
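$k=2$ 时,注意到 $|x-b|+|y-b|=\max(y-x,\ 2|b-m|)$,其中 $m=\frac{x+y}{2}$ 是区间中点.这个代价随桥到中点的距离单调不减,所以每个人都会走离自己区间中点更近的那座桥.于是把所有区间按中点排序,然后枚举分割点:分割点左边的人走左桥,右边的人走右桥,两边各自变成一个 $k=1$ 的子问题.两座桥建在同一处的情况等价于 $k=1$,它的代价也要作为候选答案.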
枚举分割点时,需要动态维护两侧端点集合的中位数,以及所有端点到中位数的距离之和.用平衡树(下面的代码用的是 Splay)或权值线段树维护即可.
code:
#include <bits/stdc++.h>
#define ll long long
#define lson s[x].ch[0]
#define rson s[x].ch[1]
#define N 200009
#define setIO(s) freopen(s".in","r",stdin)
using namespace std;

// Globals:
//   tot  - number of splay nodes allocated so far (node 0 is the null sentinel)
//   rt1  - root of the tree holding endpoints of the left group
//   rt2  - root of the tree holding endpoints of the right group
//   n, K - number of residents and number of bridges
//   cnt  - number of stored endpoints (two per cross-river resident)
int tot, rt1, rt2, n, K, cnt;
ll seq[N << 1];  // endpoints of all cross-river trips; sorted before build()

// Splay-tree node pool. The si and sum fields of node 0 are never written,
// so node 0 behaves as the empty subtree in pushup().
struct SplayNode {
    ll sum, w;         // sum of keys in this subtree, key of this node
    int ch[2], f, si;  // children, parent, subtree size
} s[N << 2];

// Returns 1 if x is the right child of its parent, 0 otherwise.
inline int get(int x) { return s[s[x].f].ch[1] == x; }

// Recomputes the subtree size and key sum of x from its children.
inline void pushup(int x) {
    s[x].si = s[lson].si + s[rson].si + 1;
    s[x].sum = s[lson].sum + s[rson].sum + s[x].w;
}

// Rotates x above its parent, preserving the in-order key sequence.
void rotate(int x) {
    int old = s[x].f, fold = s[old].f, which = get(x);
    s[old].ch[which] = s[x].ch[which ^ 1];
    if (s[old].ch[which]) s[s[old].ch[which]].f = old;
    s[x].ch[which ^ 1] = old;
    s[old].f = x;
    s[x].f = fold;
    if (fold) s[fold].ch[s[fold].ch[1] == old] = x;
    pushup(old);
    pushup(x);
}

// Splays x upward until it occupies the position currently held by tar,
// then stores x into tar (tar is a root variable or a child slot).
void splay(int x, int &tar) {
    int u = s[tar].f;
    for (int fa; (fa = s[x].f) != u; rotate(x))
        if (s[fa].f != u) rotate(get(fa) == get(x) ? fa : x);
    tar = x;
}

// Returns the node holding the kth smallest key (1-based) in subtree x.
// Requires 1 <= kth <= s[x].si.
int getkth(int x, int kth) {
    while (true) {
        if (s[lson].si + 1 == kth) return x;
        if (kth <= s[lson].si) {
            x = lson;
        } else {
            kth -= s[lson].si + 1;
            x = rson;
        }
    }
}

// Inserts key v below the root stored in x (fa is that root's parent, 0 for
// a whole tree). Insertion is iterative, and the new node is always splayed
// to the root. The splay also recomputes si/sum of every ancestor on the
// insertion path, and it keeps the depth amortized logarithmic.
void ins(int &x, int fa, ll v) {
    int *slot = &x;
    int parent = fa;
    while (*slot) {
        parent = *slot;
        slot = &s[parent].ch[v > s[parent].w];
    }
    int cur = ++tot;
    *slot = cur;
    s[cur].w = v;
    s[cur].f = parent;
    pushup(cur);
    splay(cur, x);
}

// Returns a node whose key equals v in subtree x. v must be present.
int find(int x, ll v) {
    while (s[x].w != v) x = s[x].ch[v > s[x].w];
    return x;
}

// Removes one occurrence of key v (which must be present) from the tree rt2.
void del(ll v) {
    int x = find(rt2, v);
    splay(x, rt2);
    int l = s[x].ch[0], r = s[x].ch[1];
    if (!l) {
        s[r].f = 0;
        rt2 = r;
    } else if (!r) {
        s[l].f = 0;
        rt2 = l;
    } else {
        while (s[l].ch[1]) l = s[l].ch[1];  // predecessor: max of left subtree
        splay(l, s[x].ch[0]);               // l now has no right child
        s[l].f = 0;
        s[l].ch[1] = r;
        s[r].f = l;
        rt2 = l;
        pushup(rt2);
    }
}

// Builds a balanced tree over the already-sorted range seq[l..r].
void build(int &x, int fa, int l, int r) {
    int mid = (l + r) >> 1;
    x = ++tot;
    s[x].f = fa;
    s[x].w = seq[mid];
    if (mid > l) build(lson, x, l, mid - 1);
    if (r > mid) build(rson, x, mid + 1, r);
    pushup(x);
}

// Returns the sum over all keys k in the tree rooted at x of |k - median|,
// and splays the median to the root. For an even count the lower median is
// used; any point between the two middle keys gives the same sum.
// The tree must be non-empty.
ll query(int &x) {
    int size = s[x].si;
    int p = getkth(x, (size & 1) ? size / 2 + 1 : size / 2);
    splay(p, x);
    // x is now the median node: left keys are <= w, right keys are >= w.
    return (s[x].w * s[lson].si - s[lson].sum) +
           (s[rson].sum - s[x].w * s[rson].si);
}

// One cross-river resident's interval [l, r], ordered by midpoint (l + r).
struct node {
    ll l, r;
    node(ll l = 0, ll r = 0) : l(l), r(r) {}
    bool operator<(const node &b) const { return (l + r) < (b.l + b.r); }
} nd[N];

int main() {
    // setIO("input");
    if (scanf("%d%d", &K, &n) != 2) return 0;
    char a[8], b[8];
    ll ans = 0, x, y;
    for (int i = 1; i <= n; ++i) {
        if (scanf("%7s%lld%7s%lld", a, &x, b, &y) != 4) break;
        if (x > y) swap(x, y);
        if (a[0] == b[0]) {
            ans += y - x;  // same side: no bridge needed
        } else {
            ++ans;         // crossing a bridge costs 1
            seq[++cnt] = x;
            seq[++cnt] = y;
            nd[cnt / 2] = node(x, y);
        }
    }
    if (cnt == 0) {
        printf("%lld\n", ans);
        return 0;
    }
    int m = cnt / 2;  // number of cross-river residents
    sort(nd + 1, nd + 1 + m);
    sort(seq + 1, seq + 1 + cnt);
    build(rt2, 0, 1, cnt);

    // One bridge at the global median. This is the K == 1 answer and also a
    // valid K == 2 layout (both bridges together). Seeding with it covers
    // m == 1, where the split loop below does not run at all.
    ll best = query(rt2);
    if (K == 1) {
        printf("%lld\n", ans + best);
        return 0;
    }
    // Move residents, in midpoint order, from the right group (rt2) to the
    // left group (rt1); each group then independently uses its own median.
    for (int i = 1; i < m; ++i) {
        del(nd[i].l);
        del(nd[i].r);
        ins(rt1, 0, nd[i].l);
        ins(rt1, 0, nd[i].r);
        best = min(best, query(rt1) + query(rt2));
    }
    printf("%lld\n", best + ans);
    return 0;
}