给定一个长度为 \(n\) 的序列和 \(m\) 次操作,每次操作可以把序列所有值 \(x\) 改成值 \(y\) 。请在每次操作后输出序列中相同的值之间的最短距离。例如序列 \([1, 2, 1, 4, 2]\) ,其相同的值之间的最短距离为下标为 \(1\) 和 \(3\) 的值之间的距离,也就是 \(3\) 。如果序列中没有相同的值,输出 2147483647
。
这道题属于求解最值的问题。根据经验,我们可以猜想这道题可能是用 \(dp\) ,贪心或者二分等算法求解。可以发现,答案一定单调不增的,下面给出不严谨证明:
因为将值 \(x\) 改成值 \(y\) 以后,原本值 \(x\) 贡献的答案并不会改变;此时值 \(x\) 反而可能与值 \(y\) 产生更小的答案,即使从最坏的情况考虑,答案也一定不会增加。
由此,我们可以想出一个时间复杂度为 \(O(n ^ 2)\) 的乱搞做法:每次操作都扫描一遍数组,将值 \(x\) 改成值 \(y\) 。此时值 \(y\) 中包含了原来的值 \(x\) ,直接求出修改后值 \(y\) 两两间隔的最小值,并与原来的答案比较即可。
遗憾的是,此题的数据并没有 \(n, m \leq 1000\) 的情况。根据笔者珍贵的考场经验,直接对数组进行修改会得到 \(0\) 分的好成绩!因此,我们考虑其他做法。显然 \(dp\) 和贪心不可行,我们考虑使用技巧来优化修改值 \(x\) 和求解答案的过程。
每次修改以后查询答案,我们只需要查询 \(y\) 出现过的所有下标。因此,我们可以将所有值为 \(y\) 的下标都存在一个集合中。将值 \(x\) 修改成值 \(y\) ,相当于把 \(x\) 的下标集合合并到 \(y\) 的下标集合。发散思维,得到本题的正解做法:启发式合并 。估算时间复杂度,单次修改 \(O(logn)\) ,显然可以跑进 \(1\) 秒。
每次合并后必须要在 \(O(logn)\) 的时间复杂度内查找出集合 \(y\) 中与集合 \(x\) 中的某个数最接近的值。因此,我们选择用 set
数组存储,直接二分查找即可。因为值域太大,我们还需要进行 离散化 处理。set[i]
表示离散化后排名为 i
的数在数组中出现的所有位置。
设 id[i]
表示排名为 \(i\) 的数对应的集合下标。每次启发式合并两个不同的集合 \(x, y\) 时,不妨设 \(x\) 的元素个数较少,否则直接交换 id[x]
和 id[y]
。此时我们需要更新答案:可能与当前的下标 \(v\) 更新答案的下标只有最接近 \(v\) 的前后两个下标。我们考虑使用 lower_bound
,因为下标不会重复,所以我们可以直接 lower_bound
求出它后面最接近它的下标。lower_bound
的前一个元素就是 \(v\) 前面最接近 \(v\) 的下标,分别更新答案即可。更新完答案,记得将集合 \(x\) 加入集合 \(y\) ,并清空集合 \(x\) 中的元素。
最后,对 set
进行 lower_bound
一定要使用 s.lower_bound(val)
而非 lower_bound(s.begin(), s.end(), val)
。否则会得到光荣的 \(TLE\ 48\) 分。
#include <cstdio>
#include <set>
#include <algorithm>
using namespace std;
const int maxn = 3e5 + 5;
const int inf = 2147483647;
int n, m, ans;
int cnt, tot;
int a[maxn], b[maxn], idx[maxn];
int x[maxn], y[maxn], id[maxn];
set<int> s[maxn];
int read()
{
int res = 0, flag = 1;
char ch = getchar();
while (ch < '0' || ch > '9')
{
if (ch == '-')
flag = -1;
ch = getchar();
}
while (ch >= '0' && ch <= '9')
{
res = res * 10 + ch - '0';
ch = getchar();
}
return res * flag;
}
void write(int x)
{
if (x < 0)
{
putchar('-');
x = -x;
}
if (x > 9)
write(x / 10);
putchar(x % 10 + '0');
}
void merge(int x, int y)
{
if (x == y)
return;
if (s[id[x]].size() > s[id[y]].size())
swap(id[x], id[y]);
set<int>::iterator i, j;
for (i = s[id[x]].begin(); i != s[id[x]].end(); i++)
{
j = s[id[y]].lower_bound(*i);
if (j != s[id[y]].end())
ans = min(ans, *j - *i);
if (j != s[id[y]].begin())
{
j--;
ans = min(ans, *i - *j);
}
s[id[y]].insert(*i);
}
s[id[x]].clear();
}
int main()
{
ans = inf;
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++)
b[++cnt] = a[i] = read();
for (int i = 1; i <= m; i++)
{
b[++cnt] = x[i] = read();
b[++cnt] = y[i] = read();
}
sort(b + 1, b + cnt + 1);
tot = unique(b + 1, b + cnt + 1) - b - 1;
for (int i = 1; i <= cnt; i++)
id[i] = i;
for (int i = 1; i <= n; i++)
{
a[i] = lower_bound(b + 1, b + tot + 1, a[i]) - b;
s[a[i]].insert(i);
if (idx[a[i]])
ans = min(ans, i - idx[a[i]]);
idx[a[i]] = i;
}
for (int i = 1; i <= m; i++)
{
x[i] = lower_bound(b + 1, b + tot + 1, x[i]) - b;
y[i] = lower_bound(b + 1, b + tot + 1, y[i]) - b;
merge(x[i], y[i]);
write(ans), puts("");
}
return 0;
}