题意
\(sqc(s,i,j,c)\)表示字符串\(s\)中字符\(c\)在区间\([l,r]\)出现次数
求
\[\sum_{c=97}^{122}\sum_{i=1}^n\sum_{j=i}^n sqc(s,i,j,c)^2\]
\[1 \leq T \leq 100\\|s| \leq 10^5 \]
分析
题解的做法太巧妙 没有想到,万能的笨办法:枚举端点
考虑对每个字符分别求和
把字符串中等于当前位置的字符看成1,否则看成0,这样相当于求1的前缀和平方
枚举左端点,我们就可以试图去维护这样的东西,因为1的贡献就是\(\sum i^2\) 可以根据当前1的个数直接求,我们只需要维护0的贡献
把连续的0分段,得到0的贡献就是\(\sum cnt[i]j^2\)
考虑左端点右移对0的贡献的影响就是\(cnt[i]j^2 -> cnt[i](j-1)^2 = cnt[i]j^2 - cnt[i]2j + cnt[i]\)
发现第一项不变,第二项可以继续维护,第三项可以预处理后缀和,第二项维护方法仍然类似,考虑左端点又移可能的影响\(cnt[i]2j -> cnt[i]2(j-1) -> cnt[i]2j - cnt[i]\)
这样就能够\(O(1)\)维护了
代码
调试了很久,这类题一定要想清楚过程再写,否则一定是浪费机时
#include<bits/stdc++.h>
#define pii pair<ll,ll>
#define fi first
#define se second
#define int long long
using namespace std;
typedef long long ll;
inline ll rd(){
ll x;
scanf("%lld",&x);
return x;
}
const int maxn = 1e5 + 5;
const int MOD = 998244353;
inline int mul(int a,int b){
return (ll)a * b % MOD;
}
inline void add(int &a,int b){
a += b;
if(a >= MOD) a -= MOD;
}
inline int ksm(int a,int b = MOD - 2,int m = MOD){
int ans = 1;
int base = a;
while(b){
if(b & 1) ans = mul(ans,base);
base = mul(base,base);
b >>= 1;
}
return ans;
}
char s[maxn];
int Get[maxn];
int suf[maxn];
int Suf[maxn];
int main(){
//freopen("input.txt","r",stdin);
int T = rd();
for(int i = 1;i < maxn;i++){
Get[i] = Get[i - 1] + (ll)i * i % MOD;
if(Get[i] >= MOD) Get[i] -= MOD;
}
while(T--){
scanf("%s",s);
int n = strlen(s);
int ans = 0;
int mulsuf;
for(char ch = 'a';ch <= 'z';ch++){
for(int i = 0;i <= n + 1;i++)
suf[i] = Suf[i] = 0;
mulsuf = 0;
int cnt = 0;
int cnt_1 = 0;
for(int i = 0;i < n;i++)
if(s[i] != ch) suf[cnt]++;
else {
cnt_1++;
if(suf[cnt]) add(mulsuf,(ll)2 * suf[cnt++] * (cnt_1 - 1) % MOD);
}
if(!suf[cnt]) cnt--;
if(s[n - 1] != ch) add(mulsuf,(ll)2 * suf[cnt] * (cnt_1) % MOD);
if(!cnt_1) continue;
Suf[cnt + 1] = 0;
for(int i = cnt;i >= 0;i--){
Suf[i] = Suf[i + 1] + suf[i];
if(Suf[i] >= MOD) Suf[i] -= MOD;
}
/*
if(ch == 'a') {
cout << cnt << '\n';
for(int i = 0;i <= cnt;i++)
cout << suf[i] << ' ';
puts("");
for(int i = 0;i <= cnt;i++)
cout << Suf[i] << ' ';
puts("");
}*/
/*
for(int i = st,j = 1;i <= cnt;i++,j++){
add(mulsuf,2 * j * suf[i]);
}*/
int lst = 0;
int sb = 0;
for(int i = 0;i < n;i++){
if(s[i] == ch) sb++;
add(lst,mul(sb,sb));
}
add(ans,lst);
int fro = !(s[0] == ch);
int Lst = lst - Get[cnt_1];
Lst = (Lst + MOD) % MOD;
for(int i = 1;i < n;i++){
//if(ch == 'a') cout << i << ' ' << mulsuf << ' ' << lst << ' ' << Lst << '\n';
if(s[i - 1] != ch) {
add(ans,lst);
}
else if(s[i - 1] == ch && s[i] == ch) {
cnt_1--;
lst = Get[cnt_1];
Lst -= mulsuf;
add(Lst,MOD);
add(Lst,Suf[fro]);
mulsuf -= (2 * Suf[fro]) % MOD;
add(mulsuf,MOD);
//add(Lst,Suf[fro]);
//fro++;
add(lst,Lst);
add(ans,lst);
}
else if(s[i - 1] == ch && s[i] != ch) {
cnt_1--;
lst = Get[cnt_1];
Lst -= mulsuf;
add(Lst,Suf[fro]);
add(Lst,MOD);
mulsuf -= (2 * Suf[fro]) % MOD;
add(mulsuf,MOD);
//add(Lst,Suf[fro]);
fro++;
add(lst,Lst);
add(ans,lst);
}
}
//cout << ans << '\n';
}
printf("%d\n",ans);
}
}