2018-team4-modules-KDTree
从 Trac 迁移的文章
这是从旧校内 Wiki 迁移的文章，可能存在一些样式问题，您可以向 memset0 反馈。
原文章内容如下：
{{{
// Number of coordinates per point; build() cycles its split axis through them.
static const int dem = 2;
// k-d tree node over 2-D points: pos is the point, id its input index, and
// minn/maxx the bounding box of the whole subtree (this point plus both
// children's boxes).
struct Node{
    Node *ch[2];
    int pos[2],id;
    int minn[2],maxx[2];
    // Recomputes this subtree's bounding box from pos and the children's boxes.
    // Nested two-argument std::min/std::max are used because std::min has no
    // (int, int, int) overload -- its three-argument form takes a comparator.
    inline void update(){
        for(int d = 0; d < 2; ++d){
            minn[d] = std::min(std::min(ch[0]->minn[d], ch[1]->minn[d]), pos[d]);
            maxx[d] = std::max(std::max(ch[0]->maxx[d], ch[1]->maxx[d]), pos[d]);
        }
    }
}*null,*root;
// Node pool: main fills T[1..n] and build() rearranges it in place.
Node T[maxn];
// Storage for the shared `null` sentinel (init() points `null` here).
Node asd;
// Sets up the shared `null` sentinel in `asd`: its children point back at
// itself and its box is empty (min = +0x3f3f3f3f, max = -0x3f3f3f3f), so for
// coordinates inside that range it never widens a parent's box in update().
inline void init(){
    null = &asd;
    null->ch[0] = null;
    null->ch[1] = null;
    for(int d = 0; d < 2; ++d){
        null->minn[d] = 0x3f3f3f3f;
        null->maxx[d] = -0x3f3f3f3f;
    }
}
// Split axis read by cmp; build() sets it right before each nth_element.
int nowd;
// Orders nodes by their coordinate on the current split axis `nowd`.
inline bool cmp(const Node &a,const Node &b){
    const int lhs = a.pos[nowd];
    const int rhs = b.pos[nowd];
    return rhs > lhs;
}
// Recursively builds a k-d tree over T[l..r] in place and returns its root,
// or `null` for an empty range. Depth s selects the split axis (s % dem); the
// median along that axis becomes the subtree root.
Node *build(int l,int r,int s){
    if(r < l) return null;
    const int m = l + (r - l) / 2;
    // cmp reads the global nowd, and the recursive calls below overwrite it,
    // so it must be set immediately before nth_element.
    nowd = s % dem;
    nth_element(T + l, T + m, T + r + 1, cmp);
    Node *sub = &T[m];
    sub->ch[0] = build(l, m - 1, s + 1);
    sub->ch[1] = build(m + 1, r, s + 1);
    sub->update();
    return sub;
}
// Running answers of query_max / query_min (Manhattan distance); the caller
// initialises them before each search.
ll ans_max;
ll ans_min;
// Absolute value as long long. The argument is long long (not int), so
// 64-bit differences are not narrowed on the way in and |INT_MIN| no longer
// overflows.
inline long long cat_abs(long long x){return x < 0 ? -x : x;}
// Query point for the Manhattan-distance searches; query_max / query_min skip
// the node whose id equals op.id.
Node op;
// Upper bound on the Manhattan distance from op to any point in p's bounding
// box. Per axis it takes the farther of the two box faces; the sum equals the
// maximum over the box's four corners. Differences are formed in ll so that
// the null sentinel's +-0x3f3f3f3f bounds (or wide coordinates) cannot
// overflow int.
inline ll get_max(Node *p){
    ll ret = 0;
    for(int d = 0; d < dem; ++d){
        const ll x = op.pos[d];
        ll toMin = x - p->minn[d];
        ll toMax = x - p->maxx[d];
        if(toMin < 0) toMin = -toMin;
        if(toMax < 0) toMax = -toMax;
        ret += toMin > toMax ? toMin : toMax;
    }
    return ret;
}
// Lower bound on the Manhattan distance from op to p's bounding box: per axis,
// the gap between op and the box (0 when op lies within the box's range).
// Differences are formed in ll so sentinel bounds cannot overflow int.
inline ll get_min(Node *p){
    ll ret = 0;
    for(int d = 0; d < dem; ++d){
        const ll x = op.pos[d];
        if(x < p->minn[d]) ret += p->minn[d] - x;
        else if(x > p->maxx[d]) ret += x - p->maxx[d];
    }
    return ret;
}
// Farthest point in the plane (Manhattan distance).
// Raises ans_max to the largest distance from op to a node in p's subtree
// whose id differs from op.id. ans_max is only ever raised here, so the
// caller must initialise it before searching from the root.
void query_max(Node *p){
    if(p == null) return;
    ll dis = 0;
    for(int d = 0; d < dem; ++d){
        const ll x = op.pos[d];
        ll diff = p->pos[d] - x;
        dis += diff < 0 ? -diff : diff;
    }
    if(p->id != op.id) ans_max = max(ans_max,dis);
    // Visit the child with the larger bound first. The order is kept in locals
    // instead of swapping p->ch[], so the tree's layout is left untouched, and
    // each child's bound is computed once.
    Node *first = p->ch[0], *second = p->ch[1];
    ll firstBound = get_max(first), secondBound = get_max(second);
    if(firstBound < secondBound){
        swap(first, second);
        swap(firstBound, secondBound);
    }
    if(firstBound > ans_max) query_max(first);
    if(secondBound > ans_max) query_max(second);
}
// Nearest point in the plane (Manhattan distance).
// Lowers ans_min to the smallest distance from op to a node in p's subtree
// whose id differs from op.id. ans_min is only ever lowered here, so the
// caller must initialise it (to a large value) before searching from the root.
void query_min(Node *p){
    if(p == null) return;
    ll dis = 0;
    for(int d = 0; d < dem; ++d){
        const ll x = op.pos[d];
        ll diff = p->pos[d] - x;
        dis += diff < 0 ? -diff : diff;
    }
    if(p->id != op.id) ans_min = min(ans_min,dis);
    // Visit the child whose box is closer first. The order is kept in locals
    // instead of swapping p->ch[], so the tree's layout is left untouched, and
    // each child's bound is computed once.
    Node *first = p->ch[0], *second = p->ch[1];
    ll firstBound = get_min(first), secondBound = get_min(second);
    if(firstBound > secondBound){
        swap(first, second);
        swap(firstBound, secondBound);
    }
    if(firstBound < ans_min) query_min(first);
    if(secondBound < ans_min) query_min(second);
}
//数据用data储存
//求距离一个点的k远点(欧几里得距离)
struct Data{
ll dis;ll id;
bool operator < (const Data &a)const{
if(dis != a.dis) return dis > a.dis;
return id < a.id;
}
Data(){}
Data(ll a,ll b){dis=a;id=b;}
};
priority_queue<Data>q;
// Square of x, computed in ll.
inline ll sqr(ll x){
    const ll squared = x * x;
    return squared;
}
// Query point and k for the k-th farthest search.
// NOTE(review): `op` is also defined in the Manhattan section; when both
// sections are pasted into one program, keep a single definition.
Node op;
ll k;
// Upper bound on the squared Euclidean distance from op to any point in p's
// bounding box. Per axis it takes the larger squared gap to the two box faces;
// the sum equals the maximum over the box's four corners. Differences are
// formed in ll so wide coordinates or the null sentinel's +-0x3f3f3f3f
// bounds cannot overflow int.
inline ll md(Node *p){
    ll ret = 0;
    for(int d = 0; d < dem; ++d){
        const ll x = op.pos[d];
        const ll toMin = sqr(p->minn[d] - x);
        const ll toMax = sqr(p->maxx[d] - x);
        ret += toMin > toMax ? toMin : toMax;
    }
    return ret;
}
// k-th farthest search (squared Euclidean distance).
// Keeps the k best candidates from p's subtree in q: farther is better, and
// on ties the smaller id is better (see Data::operator<). main clears q and
// sets op and k before each call.
void query(Node *p){
    if(p == null) return;
    ll dis = 0;
    for(int d = 0; d < dem; ++d){
        const ll x = op.pos[d];
        dis += sqr(x - p->pos[d]);
    }
    const Data cand(dis, p->id);
    if(q.size() < k){
        q.push(cand);
    }else if(cand < q.top()){
        q.pop();
        q.push(cand);
    }
    // Visit the child with the larger bound first. The order is kept in locals
    // instead of swapping p->ch[], so the tree is not modified. Null children
    // get bound -1 and are never searched.
    Node *first = p->ch[0], *second = p->ch[1];
    ll firstBound = (first == null) ? -1 : md(first);
    ll secondBound = (second == null) ? -1 : md(second);
    if(firstBound < secondBound){
        swap(first, second);
        swap(firstBound, secondBound);
    }
    // `>=` rather than `>`: a box tying the current k-th distance may still
    // hold a point with a smaller id.
    if(first != null && (q.size() < k || firstBound >= q.top().dis)) query(first);
    if(second != null && (q.size() < k || secondBound >= q.top().dis)) query(second);
}
// ---- Sum of values inside an axis-aligned rectangle ----
// The rectangle is spanned by (X1,Y1) and (X2,Y2); assumes X1 <= X2 and
// Y1 <= Y2. Requires Node to also carry `val` and `sum`, with update()
// maintaining p->sum = p->ch[0]->sum + p->ch[1]->sum + p->val.
int query(Node *p){
    if(p == null) return 0;
    const bool boxInside = X1 <= p->minn[0] && Y1 <= p->minn[1]
                        && p->maxx[0] <= X2 && p->maxx[1] <= Y2;
    if(boxInside) return p->sum;
    const bool boxOutside = p->maxx[1] < Y1 || p->maxx[0] < X1
                         || X2 < p->minn[0] || Y2 < p->minn[1];
    if(boxOutside) return 0;
    const bool pointInside = X1 <= p->pos[0] && Y1 <= p->pos[1]
                          && p->pos[0] <= X2 && p->pos[1] <= Y2;
    if(!pointInside) return query(p->ch[0]) + query(p->ch[1]);
    return p->val + query(p->ch[0]) + query(p->ch[1]);
}
// ---- Usage example: k-th farthest point ----
// Reads n points and builds the tree. Then, for each of m queries, reads a
// point and k and prints the id of the k-th farthest input point.
int main(){
    init();
    ll n;read(n);
    for(ll i=1;i<=n;++i){
        for(ll d=0;d<dem;++d){
            read(T[i].pos[d]);
        }
        T[i].ch[0] = T[i].ch[1] = null;
        T[i].id = i;T[i].update();
    }
    // Store into the global `root` instead of a shadowing local, so any code
    // reading the global sees the built tree.
    root = build(1,n,1);
    ll m;read(m);
    while(m--){
        for(ll d=0;d<dem;++d) read(op.pos[d]);
        read(k);
        while(!q.empty()) q.pop();
        query(root);
        printf("%lld\n",q.top().id);
    }
    // Presumably a leftover pause for running in a console window.
    getchar();getchar();
    return 0;
}
}}}
// Number of coordinates per point; build() cycles its split axis through them.
static const int dem = 2;
// k-d tree node over 2-D points: pos is the point, id its input index, and
// minn/maxx the bounding box of the whole subtree (this point plus both
// children's boxes).
struct Node{
    Node *ch[2];
    int pos[2],id;
    int minn[2],maxx[2];
    // Recomputes this subtree's bounding box from pos and the children's boxes.
    // Nested two-argument std::min/std::max are used because std::min has no
    // (int, int, int) overload -- its three-argument form takes a comparator.
    inline void update(){
        for(int d = 0; d < 2; ++d){
            minn[d] = std::min(std::min(ch[0]->minn[d], ch[1]->minn[d]), pos[d]);
            maxx[d] = std::max(std::max(ch[0]->maxx[d], ch[1]->maxx[d]), pos[d]);
        }
    }
}*null,*root;
// Node pool: main fills T[1..n] and build() rearranges it in place.
Node T[maxn];
// Storage for the shared `null` sentinel (init() points `null` here).
Node asd;
// Sets up the shared `null` sentinel in `asd`: its children point back at
// itself and its box is empty (min = +0x3f3f3f3f, max = -0x3f3f3f3f), so for
// coordinates inside that range it never widens a parent's box in update().
inline void init(){
    null = &asd;
    null->ch[0] = null;
    null->ch[1] = null;
    for(int d = 0; d < 2; ++d){
        null->minn[d] = 0x3f3f3f3f;
        null->maxx[d] = -0x3f3f3f3f;
    }
}
// Split axis read by cmp; build() sets it right before each nth_element.
int nowd;
// Orders nodes by their coordinate on the current split axis `nowd`.
inline bool cmp(const Node &a,const Node &b){
    const int lhs = a.pos[nowd];
    const int rhs = b.pos[nowd];
    return rhs > lhs;
}
// Recursively builds a k-d tree over T[l..r] in place and returns its root,
// or `null` for an empty range. Depth s selects the split axis (s % dem); the
// median along that axis becomes the subtree root.
Node *build(int l,int r,int s){
    if(r < l) return null;
    const int m = l + (r - l) / 2;
    // cmp reads the global nowd, and the recursive calls below overwrite it,
    // so it must be set immediately before nth_element.
    nowd = s % dem;
    nth_element(T + l, T + m, T + r + 1, cmp);
    Node *sub = &T[m];
    sub->ch[0] = build(l, m - 1, s + 1);
    sub->ch[1] = build(m + 1, r, s + 1);
    sub->update();
    return sub;
}
// Running answers of query_max / query_min (Manhattan distance); the caller
// initialises them before each search.
ll ans_max;
ll ans_min;
// Absolute value as long long. The argument is long long (not int), so
// 64-bit differences are not narrowed on the way in and |INT_MIN| no longer
// overflows.
inline long long cat_abs(long long x){return x < 0 ? -x : x;}
// Query point for the Manhattan-distance searches; query_max / query_min skip
// the node whose id equals op.id.
Node op;
// Upper bound on the Manhattan distance from op to any point in p's bounding
// box. Per axis it takes the farther of the two box faces; the sum equals the
// maximum over the box's four corners. Differences are formed in ll so that
// the null sentinel's +-0x3f3f3f3f bounds (or wide coordinates) cannot
// overflow int.
inline ll get_max(Node *p){
    ll ret = 0;
    for(int d = 0; d < dem; ++d){
        const ll x = op.pos[d];
        ll toMin = x - p->minn[d];
        ll toMax = x - p->maxx[d];
        if(toMin < 0) toMin = -toMin;
        if(toMax < 0) toMax = -toMax;
        ret += toMin > toMax ? toMin : toMax;
    }
    return ret;
}
// Lower bound on the Manhattan distance from op to p's bounding box: per axis,
// the gap between op and the box (0 when op lies within the box's range).
// Differences are formed in ll so sentinel bounds cannot overflow int.
inline ll get_min(Node *p){
    ll ret = 0;
    for(int d = 0; d < dem; ++d){
        const ll x = op.pos[d];
        if(x < p->minn[d]) ret += p->minn[d] - x;
        else if(x > p->maxx[d]) ret += x - p->maxx[d];
    }
    return ret;
}
// Farthest point in the plane (Manhattan distance).
// Raises ans_max to the largest distance from op to a node in p's subtree
// whose id differs from op.id. ans_max is only ever raised here, so the
// caller must initialise it before searching from the root.
void query_max(Node *p){
    if(p == null) return;
    ll dis = 0;
    for(int d = 0; d < dem; ++d){
        const ll x = op.pos[d];
        ll diff = p->pos[d] - x;
        dis += diff < 0 ? -diff : diff;
    }
    if(p->id != op.id) ans_max = max(ans_max,dis);
    // Visit the child with the larger bound first. The order is kept in locals
    // instead of swapping p->ch[], so the tree's layout is left untouched, and
    // each child's bound is computed once.
    Node *first = p->ch[0], *second = p->ch[1];
    ll firstBound = get_max(first), secondBound = get_max(second);
    if(firstBound < secondBound){
        swap(first, second);
        swap(firstBound, secondBound);
    }
    if(firstBound > ans_max) query_max(first);
    if(secondBound > ans_max) query_max(second);
}
// Nearest point in the plane (Manhattan distance).
// Lowers ans_min to the smallest distance from op to a node in p's subtree
// whose id differs from op.id. ans_min is only ever lowered here, so the
// caller must initialise it (to a large value) before searching from the root.
void query_min(Node *p){
    if(p == null) return;
    ll dis = 0;
    for(int d = 0; d < dem; ++d){
        const ll x = op.pos[d];
        ll diff = p->pos[d] - x;
        dis += diff < 0 ? -diff : diff;
    }
    if(p->id != op.id) ans_min = min(ans_min,dis);
    // Visit the child whose box is closer first. The order is kept in locals
    // instead of swapping p->ch[], so the tree's layout is left untouched, and
    // each child's bound is computed once.
    Node *first = p->ch[0], *second = p->ch[1];
    ll firstBound = get_min(first), secondBound = get_min(second);
    if(firstBound > secondBound){
        swap(first, second);
        swap(firstBound, secondBound);
    }
    if(firstBound < ans_min) query_min(first);
    if(secondBound < ans_min) query_min(second);
}
//数据用data储存
//求距离一个点的k远点(欧几里得距离)
struct Data{
ll dis;ll id;
bool operator < (const Data &a)const{
if(dis != a.dis) return dis > a.dis;
return id < a.id;
}
Data(){}
Data(ll a,ll b){dis=a;id=b;}
};
priority_queue<Data>q;
// Square of x, computed in ll.
inline ll sqr(ll x){
    const ll squared = x * x;
    return squared;
}
// Query point and k for the k-th farthest search.
// NOTE(review): `op` is also defined in the Manhattan section; when both
// sections are pasted into one program, keep a single definition.
Node op;
ll k;
// Upper bound on the squared Euclidean distance from op to any point in p's
// bounding box. Per axis it takes the larger squared gap to the two box faces;
// the sum equals the maximum over the box's four corners. Differences are
// formed in ll so wide coordinates or the null sentinel's +-0x3f3f3f3f
// bounds cannot overflow int.
inline ll md(Node *p){
    ll ret = 0;
    for(int d = 0; d < dem; ++d){
        const ll x = op.pos[d];
        const ll toMin = sqr(p->minn[d] - x);
        const ll toMax = sqr(p->maxx[d] - x);
        ret += toMin > toMax ? toMin : toMax;
    }
    return ret;
}
// k-th farthest search (squared Euclidean distance).
// Keeps the k best candidates from p's subtree in q: farther is better, and
// on ties the smaller id is better (see Data::operator<). main clears q and
// sets op and k before each call.
void query(Node *p){
    if(p == null) return;
    ll dis = 0;
    for(int d = 0; d < dem; ++d){
        const ll x = op.pos[d];
        dis += sqr(x - p->pos[d]);
    }
    const Data cand(dis, p->id);
    if(q.size() < k){
        q.push(cand);
    }else if(cand < q.top()){
        q.pop();
        q.push(cand);
    }
    // Visit the child with the larger bound first. The order is kept in locals
    // instead of swapping p->ch[], so the tree is not modified. Null children
    // get bound -1 and are never searched.
    Node *first = p->ch[0], *second = p->ch[1];
    ll firstBound = (first == null) ? -1 : md(first);
    ll secondBound = (second == null) ? -1 : md(second);
    if(firstBound < secondBound){
        swap(first, second);
        swap(firstBound, secondBound);
    }
    // `>=` rather than `>`: a box tying the current k-th distance may still
    // hold a point with a smaller id.
    if(first != null && (q.size() < k || firstBound >= q.top().dis)) query(first);
    if(second != null && (q.size() < k || secondBound >= q.top().dis)) query(second);
}
// ---- Sum of values inside an axis-aligned rectangle ----
// The rectangle is spanned by (X1,Y1) and (X2,Y2); assumes X1 <= X2 and
// Y1 <= Y2. Requires Node to also carry `val` and `sum`, with update()
// maintaining p->sum = p->ch[0]->sum + p->ch[1]->sum + p->val.
int query(Node *p){
    if(p == null) return 0;
    const bool boxInside = X1 <= p->minn[0] && Y1 <= p->minn[1]
                        && p->maxx[0] <= X2 && p->maxx[1] <= Y2;
    if(boxInside) return p->sum;
    const bool boxOutside = p->maxx[1] < Y1 || p->maxx[0] < X1
                         || X2 < p->minn[0] || Y2 < p->minn[1];
    if(boxOutside) return 0;
    const bool pointInside = X1 <= p->pos[0] && Y1 <= p->pos[1]
                          && p->pos[0] <= X2 && p->pos[1] <= Y2;
    if(!pointInside) return query(p->ch[0]) + query(p->ch[1]);
    return p->val + query(p->ch[0]) + query(p->ch[1]);
}
// ---- Usage example: k-th farthest point ----
// Reads n points and builds the tree. Then, for each of m queries, reads a
// point and k and prints the id of the k-th farthest input point.
int main(){
    init();
    ll n;read(n);
    for(ll i=1;i<=n;++i){
        for(ll d=0;d<dem;++d){
            read(T[i].pos[d]);
        }
        T[i].ch[0] = T[i].ch[1] = null;
        T[i].id = i;T[i].update();
    }
    // Store into the global `root` instead of a shadowing local, so any code
    // reading the global sees the built tree.
    root = build(1,n,1);
    ll m;read(m);
    while(m--){
        for(ll d=0;d<dem;++d) read(op.pos[d]);
        read(k);
        while(!q.empty()) q.pop();
        query(root);
        printf("%lld\n",q.top().id);
    }
    // Presumably a leftover pause for running in a console window.
    getchar();getchar();
    return 0;
}