题目链接

题意:把n个数字(A1比其他数字都大)的序列分成三段,每段分别反转,问字典序最小的序列。

分析:因为A1比其他数字都大,所以反转后第一段结尾是很大的数,相当是天然的分割线,第一段可以单独考虑,即求整段的字典序最小的后缀。后面两段不能分开考虑,

例子:

9
8 4 -1 5 0 5 0 2 3
第一步:
3 2 0 5 0 5 -1 4 8 对应输出 -1 4 8
第二步
3 2 0 5 0 5(开始的时候我并没有复制一遍) 对应输出:0 5
第三步
3 2 0 5    对应输出: 3 2 0 5
可以看见这样做是不对的。。
必须要将剩下的字符串复制一遍贴在后面,然后再来求后缀数组。。。
正解:
第一步:
3 2 0 5 0 5 -1 4 8 对应输出 -1 4 8
第二步
3 2 0 5 0 5 3 2 0 5 0 5 对应输出: 0 5 0 5;
第三步
3 2 对应输出:3 2;

所以方法是剩下的反转后+剩下的反转后组成新的串,求sa(这里不用求height,只要扫一遍即可),找到符合条件的字典序最小的后缀(应该是长度为剩下的长度的前缀)。

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <string>

typedef long long ll;
const int N = 2e5 + 5;
int a[N];
int rev[N<<1];
int sa[N<<1], rank[N<<1];
int tmp[N<<1];
int n, k;

bool cmp_sa(int i, int j) {
    if (rank[i] != rank[j]) {
        return rank[i] < rank[j];
    } else {
        int ri = i + k <= n ? rank[i+k] : -1;
        int rj = j + k <= n ? rank[j+k] : -1;
        return ri < rj;
    }
}

void get_sa(int *a, int n, int *sa) {
    for (int i=0; i<=n; ++i) {
        sa[i] = i;
        rank[i] = i < n ? a[i] : -1;
    }
    for (k=1; k<=n; k<<=1) {
        std::sort (sa, sa+n+1, cmp_sa);
        tmp[sa[0]] = 0;
        for (int i=1; i<=n; ++i) {
            tmp[sa[i]] = tmp[sa[i-1]] + (cmp_sa (sa[i-1], sa[i]) ? 1 : 0);
        }
        for (int i=0; i<=n; ++i) {
            rank[i] = tmp[i];
        }
    }
}

void run() {
    std::reverse_copy (a, a+n, rev);
    get_sa (rev, n, sa);
    int p1;
    for (int i=0; i<n; ++i) {
        p1 = n - sa[i];
        if (p1 >= 1 && sa[i] >= 2) {
            break;
        }
    }
    int m = n - p1;
    std::reverse_copy (a+p1, a+n, rev);
    std::reverse_copy (a+p1, a+n, rev+m);
    get_sa (rev, m*2, sa);
    int p2;
    for (int i=0; i<=m*2; ++i) {
        p2 = p1 + m - sa[i];
        if (p2 - p1 >= 1 && p2 < n) {
            break;
        }
    }
    std::reverse (a, a+p1);
    std::reverse (a+p1, a+p2);
    std::reverse (a+p2, a+n);
    for (int i=0; i<n; ++i) {
        printf ("%d\n", a[i]);
    }
}

int main() {
    scanf ("%d", &n);
    for (int i=0; i<n; ++i) {
        scanf ("%d", a+i);
    }
    run ();
    return 0;
}

  

 

编译人生,运行世界!