2013-team4/code/kd-tree

从 Trac 迁移的文章

这是从旧校内 Wiki 迁移的文章,可能存在一些样式问题,您可以向 memset0 反馈。

原文章内容如下:

{{{
KD-Tree:O(N log N) 时间建树;查询最近点(最远点见下方注意事项)。数据分布良好时查询很快(二维常引用 O(sqrt(N)+M) 的估计),但最坏情况可能退化为 O(N)。
用静态数组模拟实现,动态内存分配的实现待补。
注意:这是查询 M 个最近点的模板(代码中 K 表示维数,M 表示要查询的点数);查询最远点的做法与此有很大不同,不能直接套用这个模板。
}}}
{{{
#include <queue>
#include <cstdio>
#include <cstring>
#include <algorithm>
// Point capacity. P[] is 1-indexed (P[1..N]), so N must stay below MAXN.
const int MAXN = 50010;
int N;    // number of points in the current data set
int M;    // number of nearest neighbours requested by the current query
int K;    // dimensionality: coordinates per point actually used
int IDX;  // split axis read by Point::operator< (set by build())
// Heap entry for the k-NN search: a squared distance plus the KD[] index of
// the candidate point. Ordering uses the distance only, so a
// std::priority_queue<Node> keeps the farthest candidate on top.
struct Node {
    int val;  // squared distance from the query point
    int idx;  // index into KD[] of the candidate point
    Node(int v = 0, int i = 0) : val(v), idx(i) {}
    bool operator<(const Node& other) const {
        return other.val > val;
    }
};

// A point with storage for five integer coordinates (only the first K are used).
// Comparison looks only at the coordinate on the global split axis IDX,
// which build() sets right before calling std::nth_element.
struct Point {
    int xy[5];
    bool operator<(const Point& t) const {
        const int axis = IDX;
        return t.xy[axis] > xy[axis];
    }
};

// Max-heap of candidates for the current query; top() is the farthest kept point.
std::priority_queue<Node> Q;
Point P[MAXN];         // input points, 1-indexed; reordered in place by build()
Point KD[MAXN << 2];   // KD[rt] = point stored at tree node rt (children 2*rt, 2*rt+1)
Point ans[MAXN];       // per-query output buffer
Point ask;             // current query point
bool son[MAXN << 2];   // son[rt] is true iff tree node rt exists

// Square of an int (no overflow check; callers keep coordinates small).
inline int sqr(int v) {
    const int r = v * v;
    return r;
}

// Builds the KD-tree over P[l..r] (inclusive) rooted at heap-indexed node rt.
// Level dep splits on axis dep % K: the median point on that axis becomes
// KD[rt]; elements left of it (none greater on that axis) go to node 2*rt,
// elements right of it (none smaller) go to node 2*rt+1.
// P is reordered in place; son[rt] marks node rt as present.
void build(int rt, int l, int r, int dep) {
    if (l > r) {
        return;  // empty range: node rt stays absent
    }
    son[rt] = true;
    const int mid = l + (r - l) / 2;
    IDX = dep % K;  // consulted by Point::operator< inside nth_element
    std::nth_element(P + l, P + mid, P + r + 1);
    KD[rt] = P[mid];
    const int left = 2 * rt;
    const int right = left + 1;
    build(left, l, mid - 1, dep + 1);
    build(right, mid + 1, r, dep + 1);
}

// Recursive k-nearest-neighbour search over the subtree rooted at node rt
// (split axis dep % K). Keeps the M closest points seen so far in the
// max-heap Q, whose top is the farthest of the current best. Distances are
// squared Euclidean distances accumulated in int.
void query(int rt, int dep) {
    // Stop on an empty subtree, or when no neighbours are requested. The
    // M <= 0 guard matters: with M == 0 the else-branch below would call
    // Q.top() on an empty heap (undefined behaviour).
    if (!son[rt] || M <= 0) return;
    int now=dep%K, ls=rt*2, rs=rt*2+1, need=0;
    Node tmp=Node(0, rt);
    for (int i=0; i<K; i++) tmp.val+=sqr(ask.xy[i]-KD[rt].xy[i]);
    // Descend first into the side of the splitting plane that contains ask.
    if (ask.xy[now]>=KD[rt].xy[now]) std::swap(ls, rs);
    query(ls, dep+1);
    // Heap not full yet: always keep this point and also search the far side.
    if ((int)Q.size()<M) Q.push(tmp), need=1;
    else {
        if (tmp.val<Q.top().val) Q.pop(), Q.push(tmp);
        // The far side can only contain a better point if the splitting
        // plane is closer than the current worst kept candidate.
        if (sqr(ask.xy[now]-KD[rt].xy[now])<Q.top().val) need=1;
    }
    if (need) query(rs, dep+1);
}

int main() {
    while (scanf("%d%d", &N, &K)!=EOF) {
        memset(son, 0, sizeof(son));
        for (int i=1; i<=N; i++)
            for (int j=0; j<K; j++) scanf("%d", &P[i].xy[j]);
        build(1, 1, N, 0);
        int T; scanf("%d", &T);
        while (T--) {
            for (int i=0; i<K; i++) scanf("%d", &ask.xy[i]);
            scanf("%d", &M); query(1, 0);
            for (int i=1; i<=M; i++) ans[i]=KD[Q.top().idx], Q.pop();
            printf("the closest %d points are:\n", M);
            for (int i=M; i>=1; i--) {
                for (int j=0; j<K-1; j++) 
                    printf("%d ", ans[i].xy[j]);
                printf("%d\n", ans[i].xy[K-1]);
            }
        }
    }
    return 0;
}
}}}
KD-Tree:O(N log N) 时间建树;查询最近点(最远点见下方注意事项)。数据分布良好时查询很快(二维常引用 O(sqrt(N)+M) 的估计),但最坏情况可能退化为 O(N)。
用静态数组模拟实现,动态内存分配的实现待补。
注意:这是查询 M 个最近点的模板(代码中 K 表示维数,M 表示要查询的点数);查询最远点的做法与此有很大不同,不能直接套用这个模板。
#include <queue>
#include <cstdio>
#include <cstring>
#include <algorithm>
// Point capacity. P[] is 1-indexed (P[1..N]), so N must stay below MAXN.
const int MAXN = 50010;
int N;    // number of points in the current data set
int M;    // number of nearest neighbours requested by the current query
int K;    // dimensionality: coordinates per point actually used
int IDX;  // split axis read by Point::operator< (set by build())
// Heap entry for the k-NN search: a squared distance plus the KD[] index of
// the candidate point. Ordering uses the distance only, so a
// std::priority_queue<Node> keeps the farthest candidate on top.
struct Node {
    int val;  // squared distance from the query point
    int idx;  // index into KD[] of the candidate point
    Node(int v = 0, int i = 0) : val(v), idx(i) {}
    bool operator<(const Node& other) const {
        return other.val > val;
    }
};
// A point with storage for five integer coordinates (only the first K are used).
// Comparison looks only at the coordinate on the global split axis IDX,
// which build() sets right before calling std::nth_element.
struct Point {
    int xy[5];
    bool operator<(const Point& t) const {
        const int axis = IDX;
        return t.xy[axis] > xy[axis];
    }
};
// Max-heap of candidates for the current query; top() is the farthest kept point.
std::priority_queue<Node> Q;
Point P[MAXN];         // input points, 1-indexed; reordered in place by build()
Point KD[MAXN << 2];   // KD[rt] = point stored at tree node rt (children 2*rt, 2*rt+1)
Point ans[MAXN];       // per-query output buffer
Point ask;             // current query point
bool son[MAXN << 2];   // son[rt] is true iff tree node rt exists
// Square of an int (no overflow check; callers keep coordinates small).
inline int sqr(int v) {
    const int r = v * v;
    return r;
}
// Builds the KD-tree over P[l..r] (inclusive) rooted at heap-indexed node rt.
// Level dep splits on axis dep % K: the median point on that axis becomes
// KD[rt]; elements left of it (none greater on that axis) go to node 2*rt,
// elements right of it (none smaller) go to node 2*rt+1.
// P is reordered in place; son[rt] marks node rt as present.
void build(int rt, int l, int r, int dep) {
    if (l > r) {
        return;  // empty range: node rt stays absent
    }
    son[rt] = true;
    const int mid = l + (r - l) / 2;
    IDX = dep % K;  // consulted by Point::operator< inside nth_element
    std::nth_element(P + l, P + mid, P + r + 1);
    KD[rt] = P[mid];
    const int left = 2 * rt;
    const int right = left + 1;
    build(left, l, mid - 1, dep + 1);
    build(right, mid + 1, r, dep + 1);
}
// Recursive k-nearest-neighbour search over the subtree rooted at node rt
// (split axis dep % K). Keeps the M closest points seen so far in the
// max-heap Q, whose top is the farthest of the current best. Distances are
// squared Euclidean distances accumulated in int.
void query(int rt, int dep) {
    // Stop on an empty subtree, or when no neighbours are requested. The
    // M <= 0 guard matters: with M == 0 the else-branch below would call
    // Q.top() on an empty heap (undefined behaviour).
    if (!son[rt] || M <= 0) return;
    int now=dep%K, ls=rt*2, rs=rt*2+1, need=0;
    Node tmp=Node(0, rt);
    for (int i=0; i<K; i++) tmp.val+=sqr(ask.xy[i]-KD[rt].xy[i]);
    // Descend first into the side of the splitting plane that contains ask.
    if (ask.xy[now]>=KD[rt].xy[now]) std::swap(ls, rs);
    query(ls, dep+1);
    // Heap not full yet: always keep this point and also search the far side.
    if ((int)Q.size()<M) Q.push(tmp), need=1;
    else {
        if (tmp.val<Q.top().val) Q.pop(), Q.push(tmp);
        // The far side can only contain a better point if the splitting
        // plane is closer than the current worst kept candidate.
        if (sqr(ask.xy[now]-KD[rt].xy[now])<Q.top().val) need=1;
    }
    if (need) query(rs, dep+1);
}
int main() {
    while (scanf("%d%d", &N, &K)!=EOF) {
        memset(son, 0, sizeof(son));
        for (int i=1; i<=N; i++)
            for (int j=0; j<K; j++) scanf("%d", &P[i].xy[j]);
        build(1, 1, N, 0);
        int T; scanf("%d", &T);
        while (T--) {
            for (int i=0; i<K; i++) scanf("%d", &ask.xy[i]);
            scanf("%d", &M); query(1, 0);
            for (int i=1; i<=M; i++) ans[i]=KD[Q.top().idx], Q.pop();
            printf("the closest %d points are:\n", M);
            for (int i=M; i>=1; i--) {
                for (int j=0; j<K-1; j++) 
                    printf("%d ", ans[i].xy[j]);
                printf("%d\n", ans[i].xy[K-1]);
            }
        }
    }
    return 0;
}