关于众数,我们通常是对于每个数求出它作为众数的区间数。
考虑某个数 \(x\)。我们可以发现,如果令 \(a_i=x\) 的所有位置有 \(b_i=1\),其余位置有 \(b_i=-1\),则如果某个区间 \([l,r]\) 关于 \(b\) 的区间和 \(>0\) 的话,则有 \(x\) 是 \([l,r]\) 的绝对众数。
考虑区间和可以被拆成前缀和之差。于是我们考虑求出 \(b\) 的前缀和 \(s\) 数组。则,对于位置 \(i\),我们需要找出有多少个 \(j\),满足 \(j<i\land s_j<s_i\)。
这明显是二维数点问题,常规做法是线段树。但是,我们不能对所有颜色全部从头到尾做一遍二维数点,不然复杂度就是 \(O(n^2\log n)\) 的。考虑 \(s_i\) 数组有何性质。
稍加观察可以发现,\(s\) 数组在大部分位置都是有 \(s_i=s_{i-1}-1\) 的,只有在 \(a_i=x\) 处才会出现例外;于是我们就可以考虑相邻的两个 \(a_j=a_i=x\)。它实际在线段树上的贡献,是 \([s_j,s_i)\) 区间 \(+1\),可以直接完成。
然后考虑一个单独的 \(a_i\),我们要做的是在线段树上求前缀和;现在既然是 \([j,i)\) 的区间了,那么我们就对 \([s_j,s_i)\) 中的所有东西都求前缀和。实际上就是对线段树上每个位置乘上了一个系数,直接在线段树上同时维护系数即可。
时间复杂度 \(O(n\log n)\)。
代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
int n,a[500100];
vector<int>v[500100];
ll res;
#define lson x<<1
#define rson x<<1|1
#define mid ((l+r)>=0?(l+r)>>1:(l+r-1)>>1)
struct Segtree{
int tag,sum;
ll mus;
}seg[4001000];
void ADD(int x,int l,int r,int y=1){seg[x].tag+=y,seg[x].sum+=(r-l+1)*y,seg[x].mus+=1ll*((n-l)+(n-r))*(r-l+1)*y/2;}
void pushdown(int x,int l,int r){ADD(lson,l,mid,seg[x].tag),ADD(rson,mid+1,r,seg[x].tag),seg[x].tag=0;}
void pushup(int x){seg[x].sum=seg[lson].sum+seg[rson].sum,seg[x].mus=seg[lson].mus+seg[rson].mus;}
void reset(int x,int l,int r){
seg[x].tag=seg[x].sum=seg[x].mus=0;
if(l==r)return;
if(seg[lson].sum)reset(lson,l,mid);
if(seg[rson].sum)reset(rson,mid+1,r);
}
void modify(int x,int l,int r,int L,int R){
if(l>R||r<L)return;
if(L<=l&&r<=R)ADD(x,l,r);
else pushdown(x,l,r),modify(lson,l,mid,L,R),modify(rson,mid+1,r,L,R),pushup(x);
}
ll querytri(int x,int l,int r,int L,int R){
if(l>R||r<L)return 0;
if(L<=l&&r<=R)return seg[x].mus-1ll*seg[x].sum*(n-R-1);
pushdown(x,l,r);
return querytri(lson,l,mid,L,R)+querytri(rson,mid+1,r,L,R);
}
ll querypla(int x,int l,int r,int L,int R){
if(l>=L)return 0;
if(r<L)return 1ll*seg[x].sum*(R-L+1);
pushdown(x,l,r);
return querypla(lson,l,mid,L,R)+querypla(rson,mid+1,r,L,R);
}
int main(){
scanf("%d%d",&n,&a[0]);
for(int i=0;i<n;i++)v[i].push_back(0);
for(int i=1;i<=n;i++)scanf("%d",&a[i]),v[a[i]].push_back(i);
for(int i=0;i<n;i++)v[i].push_back(n+1);
for(int i=0;i<n;i++){
if(v[i].size()==2)continue;
int las=0;
for(int j=1;j+1<v[i].size();j++){
// printf("ADD:[%d,%d]\n",las-(v[i][j]-v[i][j-1]-1),las);
modify(1,-n,n,las-(v[i][j]-v[i][j-1]-1),las);
las-=(v[i][j]-v[i][j-1]-1)-1;
res+=querytri(1,-n,n,las-(v[i][j+1]-v[i][j]),las-1)+querypla(1,-n,n,las-(v[i][j+1]-v[i][j]),las-1);
// printf("SUM:[%d,%d]:%lld\n",las-(v[i][j+1]-v[i][j]),las-1,query(1,-n,n,las-(v[i][j+1]-v[i][j]),las-1));
}
reset(1,-n,n);
}
printf("%lld\n",res);
return 0;
}