2013-team4/code/kd-tree
从 Trac 迁移的文章
这是从旧校内 Wiki 迁移的文章,可能存在一些样式问题,您可以向 memset0 反馈。
原文章内容如下:
{{{
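KD-Tree:O(N log N) 时间建树(每层调用 nth_element,平均线性),用于查询距给定点最近的 M 个点;二维随机数据下单次查询通常约为 O(sqrt(N) + M),但最坏情况下可能退化为 O(N)。
用静态数组实现(子树按堆下标 rt*2 / rt*2+1 存储),基于动态内存分配的实现待补。
注意:这是查询 M 个最近点的模板(代码中 K 表示维数,M 表示要查询的点数)。查询最远点的做法与此有很大不同,不能直接套用这个模板。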
}}}
{{{
#include <queue>
#include <cstdio>
#include <cstring>
#include <algorithm>
// Capacity for the input point array (points are stored at indices 1..N).
const int MAXN = 50000 + 10;
int N;    // number of input points
int M;    // number of nearest points requested by the current query
int K;    // dimension of every point
int IDX;  // split axis consulted by Point::operator< while building
// Heap entry used by the nearest-neighbour search: `val` holds a squared
// distance to the query point and `idx` the KD[] slot it belongs to.
// Ordered by `val` alone, so std::priority_queue<Node> keeps the largest
// (farthest) candidate on top.
struct Node {
    int val;
    int idx;
    Node(int v = 0, int i = 0) : val(v), idx(i) {}
    bool operator<(const Node& other) const {
        return other.val > val;
    }
};
// A point with room for up to 5 integer coordinates. operator< compares
// only the coordinate on the axis currently stored in the global IDX, which
// lets std::nth_element partition points along that axis.
struct Point {
    int xy[5];
    bool operator<(const Point& other) const {
        const int axis = IDX;
        return other.xy[axis] > xy[axis];
    }
};
// Max-heap of the best candidates found so far (farthest one on top).
std::priority_queue<Node> Q;
// Input points, indexed 1..N; reordered in place while building the tree.
Point P[MAXN];
// Tree nodes stored heap-style: the children of slot rt are rt*2 and rt*2+1.
Point KD[MAXN << 2];
// Result buffer for a single query.
Point ans[MAXN];
// The current query point.
Point ask;
// son[rt] is true when slot rt of KD holds a tree node.
bool son[MAXN << 2];
// Square of an integer, computed in int (|x| <= 46340 avoids overflow).
inline int sqr(int x) {
    const int squared = x * x;
    return squared;
}
// Builds the k-d subtree for P[l..r] into KD slot `rt`, splitting on axis
// dep % K. The median point (by that axis) becomes the node; the lower half
// goes to slot 2*rt and the upper half to slot 2*rt+1.
void build(int rt, int l, int r, int dep) {
    if (r < l) return;  // empty range: slot rt stays unused
    son[rt] = true;
    const int mid = l + (r - l) / 2;
    IDX = dep % K;  // axis read by Point::operator< inside nth_element
    std::nth_element(P + l, P + mid, P + r + 1);
    KD[rt] = P[mid];
    const int nextDepth = dep + 1;
    build(2 * rt, l, mid - 1, nextDepth);
    build(2 * rt + 1, mid + 1, r, nextDepth);
}
void query(int rt, int dep) {
if (!son[rt]) return;
int now=dep%K, ls=rt*2, rs=rt*2+1, need=0;
Node tmp=Node(0, rt);
for (int i=0; i<K; i++) tmp.val+=sqr(ask.xy[i]-KD[rt].xy[i]);
if (ask.xy[now]>=KD[rt].xy[now]) std::swap(ls, rs);
query(ls, dep+1);
if ((int)Q.size()<M) Q.push(tmp), need=1;
else {
if (tmp.val<Q.top().val) Q.pop(), Q.push(tmp);
if (sqr(ask.xy[now]-KD[rt].xy[now])<Q.top().val) need=1;
}
if (need) query(rs, dep+1);
}
int main() {
while (scanf("%d%d", &N, &K)!=EOF) {
memset(son, 0, sizeof(son));
for (int i=1; i<=N; i++)
for (int j=0; j<K; j++) scanf("%d", &P[i].xy[j]);
build(1, 1, N, 0);
int T; scanf("%d", &T);
while (T--) {
for (int i=0; i<K; i++) scanf("%d", &ask.xy[i]);
scanf("%d", &M); query(1, 0);
for (int i=1; i<=M; i++) ans[i]=KD[Q.top().idx], Q.pop();
printf("the closest %d points are:\n", M);
for (int i=M; i>=1; i--) {
for (int j=0; j<K-1; j++)
printf("%d ", ans[i].xy[j]);
printf("%d\n", ans[i].xy[K-1]);
}
}
}
return 0;
}
}}}
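KD-Tree:O(N log N) 时间建树(每层调用 nth_element,平均线性),用于查询距给定点最近的 M 个点;二维随机数据下单次查询通常约为 O(sqrt(N) + M),但最坏情况下可能退化为 O(N)。
用静态数组实现(子树按堆下标 rt*2 / rt*2+1 存储),基于动态内存分配的实现待补。
注意:这是查询 M 个最近点的模板(代码中 K 表示维数,M 表示要查询的点数)。查询最远点的做法与此有很大不同,不能直接套用这个模板。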
#include <queue>
#include <cstdio>
#include <cstring>
#include <algorithm>
// Capacity for the input point array (points are stored at indices 1..N).
const int MAXN = 50000 + 10;
int N;    // number of input points
int M;    // number of nearest points requested by the current query
int K;    // dimension of every point
int IDX;  // split axis consulted by Point::operator< while building
// Heap entry used by the nearest-neighbour search: `val` holds a squared
// distance to the query point and `idx` the KD[] slot it belongs to.
// Ordered by `val` alone, so std::priority_queue<Node> keeps the largest
// (farthest) candidate on top.
struct Node {
    int val;
    int idx;
    Node(int v = 0, int i = 0) : val(v), idx(i) {}
    bool operator<(const Node& other) const {
        return other.val > val;
    }
};
// A point with room for up to 5 integer coordinates. operator< compares
// only the coordinate on the axis currently stored in the global IDX, which
// lets std::nth_element partition points along that axis.
struct Point {
    int xy[5];
    bool operator<(const Point& other) const {
        const int axis = IDX;
        return other.xy[axis] > xy[axis];
    }
};
// Max-heap of the best candidates found so far (farthest one on top).
std::priority_queue<Node> Q;
// Input points, indexed 1..N; reordered in place while building the tree.
Point P[MAXN];
// Tree nodes stored heap-style: the children of slot rt are rt*2 and rt*2+1.
Point KD[MAXN << 2];
// Result buffer for a single query.
Point ans[MAXN];
// The current query point.
Point ask;
// son[rt] is true when slot rt of KD holds a tree node.
bool son[MAXN << 2];
// Square of an integer, computed in int (|x| <= 46340 avoids overflow).
inline int sqr(int x) {
    const int squared = x * x;
    return squared;
}
// Builds the k-d subtree for P[l..r] into KD slot `rt`, splitting on axis
// dep % K. The median point (by that axis) becomes the node; the lower half
// goes to slot 2*rt and the upper half to slot 2*rt+1.
void build(int rt, int l, int r, int dep) {
    if (r < l) return;  // empty range: slot rt stays unused
    son[rt] = true;
    const int mid = l + (r - l) / 2;
    IDX = dep % K;  // axis read by Point::operator< inside nth_element
    std::nth_element(P + l, P + mid, P + r + 1);
    KD[rt] = P[mid];
    const int nextDepth = dep + 1;
    build(2 * rt, l, mid - 1, nextDepth);
    build(2 * rt + 1, mid + 1, r, nextDepth);
}
void query(int rt, int dep) {
if (!son[rt]) return;
int now=dep%K, ls=rt*2, rs=rt*2+1, need=0;
Node tmp=Node(0, rt);
for (int i=0; i<K; i++) tmp.val+=sqr(ask.xy[i]-KD[rt].xy[i]);
if (ask.xy[now]>=KD[rt].xy[now]) std::swap(ls, rs);
query(ls, dep+1);
if ((int)Q.size()<M) Q.push(tmp), need=1;
else {
if (tmp.val<Q.top().val) Q.pop(), Q.push(tmp);
if (sqr(ask.xy[now]-KD[rt].xy[now])<Q.top().val) need=1;
}
if (need) query(rs, dep+1);
}
int main() {
while (scanf("%d%d", &N, &K)!=EOF) {
memset(son, 0, sizeof(son));
for (int i=1; i<=N; i++)
for (int j=0; j<K; j++) scanf("%d", &P[i].xy[j]);
build(1, 1, N, 0);
int T; scanf("%d", &T);
while (T--) {
for (int i=0; i<K; i++) scanf("%d", &ask.xy[i]);
scanf("%d", &M); query(1, 0);
for (int i=1; i<=M; i++) ans[i]=KD[Q.top().idx], Q.pop();
printf("the closest %d points are:\n", M);
for (int i=M; i>=1; i--) {
for (int j=0; j<K-1; j++)
printf("%d ", ans[i].xy[j]);
printf("%d\n", ans[i].xy[K-1]);
}
}
}
return 0;
}