给定长度分别为 \(n\), \(m\) 的整数序列 \(A\), \(B\),以及整数 \(x\) 和 \(k\).

序列 \(L\) 的构造方式如下:

-- lists are 1-indexed --

procedure generate_list(A, B, x):

    let n = length of A
    let m = length of B
    let L = an empty list

    for i from 1 to min(n, m - x), inclusive:
        for j from (i + x) to m, inclusive:
            Append (A[i]*B[j]) to the end of L

    return L

求出序列 \(L\) 中第 \(k\) 小的数字。

\(1\leq n,m\leq 2\times 10^5\),\(1\leq x\leq m\),\(1\leq |A_i|,|B_i|\leq 2\times 10^5\),\(1\leq k\leq \operatorname{length}(L)\).

关键是想到二分.

二分最后的答案 \(val\) .

从后往前枚举每个 \(a_i\),同时用一个以值为下标的树状数组(BIT)维护所有可以与 \(a_i\) 配对的 \(b_j\)(即满足 \(j\geq i+x\) 的那些 \(b_j\)).

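对于每个 \(a_i\):若 \(a_i>0\),则 \(a_ib_j\leq val \iff b_j\leq \lfloor val/a_i\rfloor\);若 \(a_i<0\),不等号方向翻转,\(a_ib_j\leq val \iff b_j\geq \lceil val/a_i\rceil\). 在 BIT 上查询即可统计出 \(L\) 中有多少个数小于等于 \(val\);若该数量 \(\geq k\),则答案 \(\leq val\).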

每次检验中,\(b\) 中每个数最多被插入一次,\(a\) 中每个数最多查询一次.

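时间复杂度 : \(O\big(((n+m)\log V+V)\log W\big)\),其中 \(V\approx 4\times 10^5\) 为 BIT 的值域大小(每次检验都要清空 BIT),\(W\approx 8\times 10^{10}\) 为二分的取值范围.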

空间复杂度 : \(O(n+m+V)\)

code

#include<bits/stdc++.h>
using namespace std;
// Offset mapping a value b in [-2e5, 2e5] to the positive BIT index b + add.
const int add=2e5+10;
int n,m,X;
long long k;
int a[200010],b[200010];
// Largest BIT index ever inserted: max |B_j| (2e5) + add = 400010.
const int maxpos=200000+add;
// Fenwick tree over values, indices 1..maxpos. Index maxpos itself is used
// (b = 2e5 maps to it), so the array needs maxpos + 1 slots.
int bit[maxpos+1];
// Records one occurrence of index i (expected 1 <= i <= maxpos).
// Updates walk downward and queries walk upward, so qry() yields suffix counts.
inline void upd(int i){
    while(i>0){
        bit[i]++;
        i-=i&-i;
    }
}
// Returns how many recorded indices p satisfy p >= i.
// Indices above maxpos give 0; indices below 1 are treated as 1 (count all),
// since i == 0 would otherwise never advance (0 & -0 == 0).
inline int qry(int i){
    if(i<1)i=1;
    int res=0;
    while(i<=maxpos){
        res+=bit[i];
        i+=i&-i;
    }
    return res;
}
// Floor division: returns floor(x / y) for a positive divisor y.
// check() only calls this with y = a[i] > 0.
inline long long get1(long long x,long long y){
    long long q=x/y;            // C++ integer division truncates toward zero
    if(x<0&&x%y!=0)q--;         // negative, inexact numerator: step down to the floor
    return q;
}
// Ceiling division by a negative divisor: returns ceil(x / y) for y < 0.
// check() reaches this only when a[i] is not positive.
inline long long get2(long long x,long long y){
    const long long d=-y;           // |y| when y < 0
    if(x>=0)return -(x/d);          // x / y <= 0, and ceil(-t) == -floor(t)
    const long long q=(-x)/d;       // x / y > 0: truncated quotient |x| / |y|
    return x%d==0?q:q+1;            // round up when the division is inexact
}
// Returns true iff at least k elements of L are <= tmp.
//
// With 0-indexed arrays, L holds a[i]*b[j] for 0 <= i < min(n, m-X) and
// i+X <= j < m. Rows are processed from the last one backwards, so before row i
// is counted the BIT holds exactly b[i+X..m-1]; each b[j] is inserted once per call.
bool check(long long tmp){
    memset(bit,0,sizeof(bit));
    long long sum=0;   // number of pairs (i, j) with a[i]*b[j] <= tmp
    int cnt=0;         // number of b values currently in the BIT
    int lst=m;         // b[lst..m-1] have already been inserted
    // Rows i >= m-X have no partner j, so start from the last non-empty row.
    for(int i=std::min(n,m-X)-1;i>=0;i--){
        for(int j=i+X;j<lst;j++){
            upd(b[j]+add);
            cnt++;
        }
        lst=i+X;
        if(a[i]>0){
            // a*b <= tmp  <=>  b <= floor(tmp / a)
            long long val=get1(tmp,a[i]);
            val=std::max(std::min(val,200008ll),-200008ll);  // keep the BIT index in range
            sum+=cnt-qry(val+add+1);                         // all minus #{b >= val+1}
        }
        else if(a[i]<0){
            // Dividing by a negative flips the inequality: b >= ceil(tmp / a).
            long long val=get2(tmp,a[i]);
            val=std::max(std::min(val,200008ll),-200008ll);
            sum+=qry(val+add);
        }
        else if(tmp>=0){
            // a[i] == 0 (excluded by the constraints): every product is 0.
            // Handled explicitly to avoid dividing by zero.
            sum+=cnt;
        }
    }
    return sum>=k;
}
int main(){
    // Contest-style file I/O.
    freopen("kth.in","r",stdin);
    freopen("kth.out","w",stdout);
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    cin>>n>>m>>X>>k;
    for(int *p=a;p!=a+n;++p)cin>>*p;
    for(int *p=b;p!=b+m;++p)cin>>*p;

    // |A_i * B_j| <= 2e5 * 2e5 = 4e10, so the answer lies in [-LIMIT, LIMIT].
    // Search for the smallest t such that at least k elements of L are <= t.
    const long long LIMIT=40000000000LL;
    long long lo=-LIMIT,hi=LIMIT+1,best=LIMIT;
    while(lo<hi){
        const long long mid=(lo+hi)>>1;
        if(!check(mid)){
            lo=mid+1;
            continue;
        }
        best=min(best,mid);
        hi=mid;
    }
    cout<<best<<'\n';
    return 0;
}
/*inline? ll or int? size? min max?*/
/*
sample input (expected output: 3)
3 4 1 5
2 -3 1
3 1 -2 -1
*/