[BZOJ5043]密码破译

试题描述

小Q发明了一个新的加密算法,对于一个长度为n的非负整数序列a_1,a_2,...,a_n,他会随机选择一个非负整数k,

将每个数都异或上k得到b_1,b_2,...,b_n,即b_i=a_i xor k。不幸的是,健忘的小Q睡了一觉之后就把密钥k忘得
一干二净了,不过他隐约记得a_1+a_2+...+a_n的值为m,你能帮他找到一个可行的密钥吗

输入

第一行包含两个整数n,m(1<=n<=100000,0<=m<2^{60}),分别表示序列的长度以及加密前所有数的和。
第二行包含n个整数b_1,b_2,...,b_n(0<=b_i<2^{60}),表示加密后的序列。

输出

输出一个非负整数k,若无解输出-1,若有多组解,输出最小的k。

输入示例

3 5
1 2 3

输出示例

1

数据规模及约定

见“输入

题解

首先这道题肯定要按位处理,我们先预处理出数组 cnt[i] 表示对于所有 n 个数,第 i 位二进制中 1 的个数。

然后 dp,这题的 dp 自认为比较妙(可能做题太少了没见过类似的),设 f(i, j) 表示已经确定的 k 的第 i 位及更高位,现在与 m 相差 j * 2i-1 的最小的 k(注意,这里的 m 指的是第 i-1 位及更高位的 m,及 m = m' >> i-1,m' 表示输入给的 m)。

然后我们确定第 i-1 位的 k 到底填 1 还是 0,于是产生了两种转移:

填 1:f(i, j) -> f(i-1, j * 2 - cnt[i-1] + (m >> i-1 & 1))

填 0:f(i, j) -> f(i-1, j * 2 - (n - cnt[i-1]) + (m >> i-1 & 1))

以上两行的 m 都是输入给的 m,(m >> i-1 & 1) 表示 m 的二进制的第 i-1 位的值。

最后答案就是 f(0, 0) 了。

注意以上状态 f(i, j) 中 j 有可能是负数,并且我们不需要关心 |j| > n 的 j(想一想,为什么),所以可以在第二维状态上加一个 n。

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;
#define LL long long

LL read() {
	LL x = 0, f = 1; char c = getchar();
	while(!isdigit(c)){ if(c == '-') f = -1; c = getchar(); }
	while(isdigit(c)){ x = x * 10 + c - '0'; c = getchar(); }
	return x * f;
}

#define maxn 100010
#define maxlog 60
#define ool (1ll << 60)

int n, cnt[maxn];
LL m, f[maxlog+1][maxn<<1];

bool inrange(int x) { return 0 <= x && x < (maxn << 1); }
void up(LL& a, LL b) { a = min(a, b); return ; }

int main() {
	n = read(); m = read();
	for(int i = 1; i <= n; i++) {
		LL x = read();
		for(int j = 0; j < maxlog; j++, x >>= 1) cnt[j] += x & 1;
	}
	
	for(int i = 0; i <= maxlog; i++)
		for(int j = 0; j < (maxn << 1); j++) f[i][j] = ool;
	f[60][maxn] = 0;
	for(int i = 60; i; i--)
		for(int j = 0; j < (maxn << 1); j++) if(f[i][j] < ool) {
			int to = (j - maxn) * 2 - cnt[i-1] + (m >> i - 1 & 1) + maxn, tt = to;
			if(inrange(to)) up(f[i-1][to], f[i][j]);
			to = (j - maxn) * 2 - (n - cnt[i-1]) + (m >> i - 1 & 1) + maxn;
			if(inrange(to)) up(f[i-1][to], f[i][j] | (1ll << i - 1));
//			if(i < 5 && maxn - 100 <= j && j <= maxn + 100) printf("%d %d: (%d %d)%lld\n", i, j, cnt[i-1], tt, f[i][j]);
		}
	
	printf("%lld\n", f[0][maxn] < ool ? f[0][maxn] : -1ll);
	
	return 0;
}