FFT字符串匹配

定义字符串下标从 $0$ 开始,有文本串 $A$ 长度为 $n$,模式串 $B$ 长度为 $m$,我们可以考虑一个函数 $f(x, y) = A(x) - B(y)$。

我们设 $F(x) = \sum\limits_{i = 0}^{m - 1} f(x - m + 1 + i, i)$(其中 $x \ge m - 1$)。由定义,如果 $A[x - m + 1, x] = B$(也就是两个字符串匹配上了),显然有 $F(x) = 0$;我们希望反过来用 $F(x) = 0$ 来判断匹配,

但是考虑 "ab" 和 "ba" 两个字符串,它们各位的差值正负抵消,$F(x)$ 同样为 $0$,会被误判为匹配。我们稍微修改一下 $f(x, y)$ 函数,令其为 $(A(x) - B(y))^2$,这样每一项都非负,只有每一位都相等时和才为 $0$,这个函数就没有问题了。

我们考虑翻转一下 $B$ 串,令其为 $S$,则有 $B(i) = S(m - i - 1)$,于是

$$
\begin{aligned}
F(x) &= \sum_{i = 0}^{m - 1} \left(A(x - m + 1 + i) - S(m - i - 1)\right)^2 \\
     &= \sum_{i = 0}^{m - 1} S(m - i - 1)^2 + \sum_{i = 0}^{m - 1} A(x - m + 1 + i)^2 - 2 \times \sum_{i = 0}^{m - 1} A(x - m + 1 + i)\, S(m - i - 1)
\end{aligned}
$$
第一项是一个定值,第二项可以 $O(n)$ 预处理,然后通过前缀和 $O(1)$ 得到,第三项不难发现是一个卷积的形式(两个下标之和恒为 $x$),所以可以通过 FFT 得到,整体复杂度 $O(n \log n)$。

以上我们已经可以解决单模式串的字符串匹配了,尽管复杂度不如 KMP 优秀,但是我们考虑一个带通配符的缺项字符串匹配:

a*b

aebr*ob

我们考虑重新设计 $f(x, y)$ 函数:令通配符 `*` 对应的值为 $0$,其余字符取正整数,定义 $f(x, y) = (A(x) - B(y))^2 A(x) B(y)$,这样只要有一方是通配符,该项就为 $0$。同样地考虑翻转 $B$ 串,有 $B(i) = S(m - 1 - i)$,于是

$$
\begin{aligned}
F(x) &= \sum_{i = 0}^{m - 1} \left(A(x - m + 1 + i) - S(m - 1 - i)\right)^2 A(x - m + 1 + i)\, S(m - 1 - i) \\
     &= \sum_{i = 0}^{m - 1} A(x - m + 1 + i)\, S(m - 1 - i)^3 + \sum_{i = 0}^{m - 1} A(x - m + 1 + i)^3\, S(m - 1 - i) - 2 \times \sum_{i = 0}^{m - 1} A(x - m + 1 + i)^2\, S(m - 1 - i)^2
\end{aligned}
$$
容易发现这里是三个卷积相加的形式,所以只要用 FFT 做三次卷积即可得到答案,下面放一个模板题的代码:

#include <bits/stdc++.h>

using namespace std;

// Minimal complex number used by the FFT: r is the real part, i the imaginary
// part.  The constructor is deliberately non-explicit so that plain numbers
// convert to a purely real Complex (e.g. `2 * z`).
struct Complex {
  double r;
  double i;

  Complex(double re = 0, double im = 0) : r{re}, i{im} {}
};

// Component-wise sum.
Complex operator + (const Complex &lhs, const Complex &rhs) {
  return {lhs.r + rhs.r, lhs.i + rhs.i};
}

// Component-wise difference.
Complex operator - (const Complex &lhs, const Complex &rhs) {
  return {lhs.r - rhs.r, lhs.i - rhs.i};
}

// Complex product: (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
Complex operator * (const Complex &lhs, const Complex &rhs) {
  return {lhs.r * rhs.r - lhs.i * rhs.i, lhs.r * rhs.i + lhs.i * rhs.r};
}

// Complex quotient: multiply by the conjugate of rhs and divide by |rhs|^2.
Complex operator / (const Complex &lhs, const Complex &rhs) {
  const double norm = rhs.r * rhs.r + rhs.i * rhs.i;
  return {(lhs.r * rhs.r + lhs.i * rhs.i) / norm,
          (lhs.i * rhs.r - lhs.r * rhs.i) / norm};
}

// Scales both components by a real factor.
Complex operator * (const Complex &lhs, const double &scale) {
  return {lhs.r * scale, lhs.i * scale};
}

// Capacity of every global array in this file.
const int N = 2e6 + 10;

// Index permutation consumed by FFT(); filled by get_r().
int r[N];

// Fills r[0..lim) with the permutation FFT() applies before its butterflies.
// For lim a power of two, r[i] is i with its log2(lim) bits reversed: the
// reversal of i is the reversal of i / 2 shifted right by one, with the lowest
// bit of i moved to the top position.  Relies on r[0] being 0 on entry.
void get_r(int lim) {
  const int half = lim >> 1;
  for (int idx = 0; idx < lim; ++idx) {
    const int top_bit = (idx & 1) ? half : 0;
    r[idx] = top_bit + (r[idx >> 1] >> 1);
  }
}

// In-place iterative radix-2 FFT over f[0..lim).
//
//   f    coefficient (or point-value) array, overwritten with the result
//   lim  transform length; must be a power of two, and get_r(lim) must have
//        been called beforehand so that r[] holds the matching permutation
//   rev  +1 for the forward transform, -1 for the inverse transform
//
// The inverse transform divides BOTH components by lim, so that
// FFT(f, lim, 1) followed by FFT(f, lim, -1) restores f (up to rounding).
void FFT(Complex *f, int lim, int rev) {
  // Reorder the input by the permutation in r[]; the `i < r[i]` guard makes
  // sure each pair is swapped only once.
  for (int i = 0; i < lim; i++) {
    if (i < r[i]) {
      std::swap(f[i], f[r[i]]);
    }
  }
  const double pi = std::acos(-1.0);
  // Merge blocks of half-length `mid` into blocks of length 2 * mid.
  for (int mid = 1; mid < lim; mid <<= 1) {
    // Primitive (2 * mid)-th root of unity; the sign of the angle selects
    // forward vs. inverse.
    const Complex wn(std::cos(pi / mid), rev * std::sin(pi / mid));
    for (int len = mid << 1, cur = 0; cur < lim; cur += len) {
      Complex w(1, 0);
      for (int k = 0; k < mid; k++, w = w * wn) {
        // Butterfly: combine the k-th entries of the two half-blocks.
        const Complex x = f[cur + k];
        const Complex y = w * f[cur + mid + k];
        f[cur + k] = x + y;
        f[cur + mid + k] = x - y;
      }
    }
  }
  if (rev == -1) {
    // Normalise the inverse transform (real and imaginary parts alike).
    for (int i = 0; i < lim; i++) {
      f[i].r /= lim;
      f[i].i /= lim;
    }
  }
}

// const int N = 1e6 + 10;

// FFT work buffers: main() accumulates the frequency-domain sum in `a` and
// uses `b` and `c` as the two operands of each convolution.
Complex a[N];
Complex b[N];
Complex c[N];

// Raw input strings: main() reads str2 first (length m) and str1 second
// (length n).
char str1[N];
char str2[N];

// Numeric encodings built by main(): A from str1, S from str2 reversed.
int A[N];
int S[N];

// Input lengths and the FFT length chosen by main().
int n;
int m;
int lim;

int main() {
  // freopen("in.txt", "r", stdin);
  // freopen("out.txt", "w", stdout);
  scanf("%d %d %s %s", &m, &n, str2, str1);
  for (int i = 0; i < n; i++) {
    A[i] = str1[i] == '*' ? 0 : str1[i] - 'a' + 1;
  }
  for (int i = 0; i < m; i++) {
    S[i] = str2[m - i - 1] == '*' ? 0 : str2[m - i - 1] - 'a' + 1;
  }
  lim = 1;
  while (lim < n + m) {
    lim <<= 1;
  }
  get_r(lim);
  for (int i = 0; i < lim; i++) {
    b[i] = Complex(A[i], 0);
    c[i] = Complex(S[i] * S[i] * S[i], 0);
  }
  FFT(b, lim, 1), FFT(c, lim, 1);
  for (int i = 0; i < lim; i++) {
    a[i] = a[i] + b[i] * c[i];
  }
  for (int i = 0; i < lim; i++) {
    b[i] = Complex(A[i] * A[i] * A[i], 0);
    c[i] = Complex(S[i], 0);
  }
  FFT(b, lim, 1), FFT(c, lim, 1);
  for (int i = 0; i < lim; i++) {
    a[i] = a[i] + b[i] * c[i];
  }
  for (int i = 0; i < lim; i++) {
    b[i] = Complex(A[i] * A[i], 0);
    c[i] = Complex(S[i] * S[i], 0);
  }
  FFT(b, lim, 1), FFT(c, lim, 1);
  for (int i = 0; i < lim; i++) {
    a[i] = a[i] - 2 * b[i] * c[i];
  }
  FFT(a, lim, -1);
  vector<int> ans;
  for (int i = m - 1; i < n; i++) {
    if ((long long)(a[i].r + 0.5) == 0) {
      ans.push_back(i - m + 2);
    }
  }
  printf("%d\n", ans.size());
  for (auto it : ans) {
    printf("%d ", it);
  }
  puts("");
  return 0;
}