也是一道保序回归的题,但思路不同于论文中模板题
考虑两个开口向上的二次函数$f(x)$和$g(x)$,求任意实数$x,y$满足$x\le y$且最小化$f(x)+g(y)$,这个最小值可以分类讨论求出:
1.若$f(x)$最小值位置小于等于$g(x)$的最小值位置,显然都取最小值即可;
2.若$f(x)$最小值位置大于$g(x)$的最小值位置,再对$x$的位置分类讨论——
(1)若$x$的位置在$g(x)$最小值的右侧,根据二次函数的性质应取$y=x$;
(2)若$x$的位置在$g(x)$最小值的左侧,显然$y$应取$g(x)$的最小值,那么$x$也应等于$y$
由此,我们可以得到此时最小值必然为$x=y$
(事实上,对于任意在最小值两侧单调的函数$f(x)$和$g(x)$,都具有此性质)
每一个位置就是一个这样的二次函数,对于相邻两个二次函数也就是求上述实数对$(x,y)$
考虑一个暴力的做法:不断找到相邻两个二次函数,使得前者的最小值大于后者的最小值,并根据第2种情况,将两者合并为同一个二次函数$f(x)+g(x)$,最终每一个二次函数最小值之和即为答案
关于合并的顺序是否影响答案,由于每一次合并都保证答案最优,因此最终答案也一定最优,当然也可以证(kou)明(hu)不会影响答案
所谓影响答案,影响的是相邻两点“是否合并”这件事情的不同,同时合并后最小值一定是向中间靠的,因此合并后一定“更可以合并”,由此不难证明合并顺序不影响答案
这一做法的具体实现可以通过单调栈,将前面的二次函数都求出,每次插入一个函数并考虑是否需要与栈顶的二次函数合并即可
对于修改$(x,y)$,先对单调栈作预处理,即用主席树维护出每一个前缀或后缀的单调栈的区间和(对于删除只需要记录栈大小即可)
根据合并顺序的任意,可以先将$[1,x)$和$(x,n]$这两段都合并(合并后的信息在主席树中维护),之后考虑求出左右最终与$x$合并的段数$l$和$r$(注意这里方便表述使用了数量,而代码中用的是下标)
称一对$(l,r)$合法当且仅当将这$l+r+1$个二次函数(包括$x$自己)合并后,与左右两个单调栈内剩下的二次函数最小值单调不下降
合法的$(l,r)$具有单调性,具体来说就是$(l,r)$合法则$(l+1,r)$和$(l,r+1)$都合法
但合法仅仅只是必要条件,问题在于:可以通过与右边单调的部分去合并,使得其值增大,从而与左边原来不单调的部分也单调,这样是不正确的
换言之,我们需要保证最后一次对$l$和$r$的合并都是不单调的合并,即对于$(l,r-1)$,右端点不单调;对于$(l-1,r)$,左端点不单调
具体实现中,为了方便二分,可以先选择$(l,r)$合法以及$(l,r-1)$右端点不单调这两个条件,确定出最小的$r$(对于确定的$l$),之后判定$(l,r)$左端点是否单调(要求单调)
接下来,在外面再套一层对$l$的二分,同样找到最小的合法的$l$即保证了$(l-1,r)$左端点不单调的条件
如果二分后用主席树做复杂度为$o(q\log^{3}n)$,在$r$上的二分可以直接在主席树上二分,因此复杂度为$o(q\log^{2}n)$
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define N 100005 4 #define mod 998244353 5 #define ll long long 6 #define mid (l+r>>1) 7 int n,m,x,y,top,ans,rt[2][N],sz[2][N]; 8 int ksm(int n,int m){ 9 int s=n,ans=1; 10 while (m){ 11 if (m&1)ans=1LL*ans*s%mod; 12 s=1LL*s*s%mod; 13 m>>=1; 14 } 15 return ans; 16 } 17 struct fun{ 18 int a,c; 19 ll b; 20 bool operator < (const fun &k)const{ 21 //-b/(2*a)<-k.b/(2*k.a) 22 return b*k.a>a*k.b; 23 } 24 fun operator + (const fun &k)const{ 25 return fun{a+k.a,(c+k.c)%mod,b+k.b}; 26 } 27 fun operator - (const fun &k)const{ 28 return fun{a-k.a,(c+mod-k.c)%mod,b-k.b}; 29 } 30 int get_min(){ 31 int bb=(b%mod+mod)%mod; 32 return (c+mod-1LL*bb*bb%mod*ksm(4*a,mod-2)%mod)%mod; 33 } 34 }a[N],st[N]; 35 struct Seg{ 36 int V,ls[N*40],rs[N*40],sum[N*40]; 37 fun right[N*40],f[N*40]; 38 int New(int k){ 39 V++; 40 ls[V]=ls[k]; 41 rs[V]=rs[k]; 42 sum[V]=sum[k]; 43 right[V]=right[k]; 44 f[V]=f[k]; 45 return V; 46 } 47 void update(int &k,int l,int r,int x,fun y){ 48 k=New(k); 49 if (l==r){ 50 sum[k]=y.get_min(); 51 right[k]=f[k]=y; 52 return; 53 } 54 if (x<=mid)update(ls[k],l,mid,x,y); 55 else update(rs[k],mid+1,r,x,y); 56 if (right[rs[k]].a)right[k]=right[rs[k]]; 57 else right[k]=right[ls[k]]; 58 sum[k]=(sum[ls[k]]+sum[rs[k]])%mod; 59 f[k]=f[ls[k]]+f[rs[k]]; 60 } 61 fun query_fun(int k,int l,int r,int x,int y){ 62 if ((!k)||(l>y)||(x>r))return {0,0,0}; 63 if ((x<=l)&&(r<=y))return f[k]; 64 return query_fun(ls[k],l,mid,x,y)+query_fun(rs[k],mid+1,r,x,y); 65 } 66 int query_sum(int k,int l,int r,int x,int y){ 67 if ((!k)||(l>y)||(x>r))return 0; 68 if ((x<=l)&&(r<=y))return sum[k]; 69 return (query_sum(ls[k],l,mid,x,y)+query_sum(rs[k],mid+1,r,x,y))%mod; 70 } 71 int find(int k,int l,int r,fun x){ 72 if (l==r)return l; 73 if (x+f[rs[k]]<right[ls[k]])return find(rs[k],mid+1,r,x); 74 return find(ls[k],l,mid,x+f[rs[k]]); 75 } 76 }T[2]; 77 bool check(int l,int x,int y){ 78 fun o={1,(int)(1LL*y*y%mod),-2*y}; 79 o=o+T[0].query_fun(rt[0][x-1],1,n,l,sz[0][x-1]); 80 int r=T[1].find(rt[1][x+1],1,n,o); 81 if (o<T[1].right[rt[1][x+1]])r=sz[1][x+1]+1; 82 o=o+T[1].query_fun(rt[1][x+1],1,n,r,sz[1][x+1]); 83 return ((l==1)||(T[0].query_fun(rt[0][x-1],1,n,l-1,l-1)<o)); 84 } 85 int calc(int l,int x,int y){ 86 fun o={1,(int)(1LL*y*y%mod),-2*y}; 87 o=o+T[0].query_fun(rt[0][x-1],1,n,l,sz[0][x-1]); 88 int r=T[1].find(rt[1][x+1],1,n,o); 89 if (o<T[1].right[rt[1][x+1]])r=sz[1][x+1]+1; 90 o=o+T[1].query_fun(rt[1][x+1],1,n,r,sz[1][x+1]); 91 int ans=o.get_min(); 92 ans=(ans+T[0].query_sum(rt[0][x-1],1,n,1,l-1))%mod; 93 ans=(ans+T[1].query_sum(rt[1][x+1],1,n,1,r-1))%mod; 94 return ans; 95 } 96 int main(){ 97 scanf("%d%d",&n,&m); 98 for(int i=1;i<=n;i++){ 99 scanf("%d",&x); 100 a[i]=fun{1,(int)(1LL*x*x%mod),-2*x}; 101 } 102 for(int i=1;i<=n;i++){ 103 st[++top]=a[i]; 104 rt[0][i]=rt[0][i-1]; 105 while ((top>1)&&(st[top]<st[top-1])){ 106 T[0].update(rt[0][i],1,n,top,{0,0,0}); 107 st[top-1]=st[top-1]+st[top]; 108 top--; 109 } 110 sz[0][i]=top; 111 T[0].update(rt[0][i],1,n,sz[0][i],st[top]); 112 } 113 top=0; 114 for(int i=n;i;i--){ 115 st[++top]=a[i]; 116 rt[1][i]=rt[1][i+1]; 117 while ((top>1)&&(st[top-1]<st[top])){ 118 T[1].update(rt[1][i],1,n,top,{0,0,0}); 119 st[top-1]=st[top-1]+st[top]; 120 top--; 121 } 122 sz[1][i]=top; 123 T[1].update(rt[1][i],1,n,sz[1][i],st[top]); 124 } 125 printf("%d\n",T[0].sum[rt[0][n]]); 126 for(int i=1;i<=m;i++){ 127 scanf("%d%d",&x,&y); 128 int l=1,r=sz[0][x-1]+1; 129 while (l<r){ 130 int mi=(l+r+1>>1); 131 if (check(mi,x,y))l=mi; 132 else r=mi-1; 133 } 134 printf("%d\n",calc(l,x,y)); 135 } 136 }