血崩多次……
原来 KD 树在高维(二维以上)时的查询写法不能直接照搬二维的模板;我照葫芦画瓢,一直是错的,改过之后才通过。
具体地说:剪枝时只比较查询点与当前节点在"当前划分维度"上的坐标差的平方,若它已经不小于目前第 m 近点的距离平方,就不必再进入另一侧子树;每向下递归一层,划分维度都要轮换到下一维(D+1,到 k 时回到 0)。
文字描述可能不够清楚,具体请直接看下面的代码。
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<queue>
#include<stack>
#include<algorithm>
using namespace std;

// K-d tree solution for "the closest m points": build a k-d tree over n
// points in k dimensions, then for each query point print its m nearest
// points by squared Euclidean distance, closest first.

// N: capacity of tree[] (points are stored 1-based, so n must be < N).
// INF: sentinel distance, assumed larger than any real squared distance in
// the input.
const long long N=500005,INF=210000000000000ll;
const int MAXK=5;  // per-node coordinate capacity; input k must be in [1, MAXK]

long long n,k,m,t,Sort_Tag,root;
// Max-heap of (squared distance, node index) holding the current m best
// candidates; top() is the worst of them, i.e. the m-th nearest so far.
// Entries with index 0 are placeholders that have not been filled yet.
priority_queue<pair<long long,long long> > PQ;
stack<long long> St;  // reverses the heap's farthest-first pop order for output

// One tree node / input point.
//   d      : coordinates of the point
//   mn, mx : bounding box of the subtree rooted at this node
//   l, r   : child indices into tree[] (0 = no child)
//   id     : 1-based position of the point in the input
struct KDTree{
    long long d[MAXK],mx[MAXK],mn[MAXK],l,r,id;
    // Orders nodes by their coordinate on the current split axis (Sort_Tag).
    // Taken by const reference: nth_element compares many times and the
    // struct is large.
    friend bool operator<(const KDTree&A,const KDTree&B){return A.d[Sort_Tag]<B.d[Sort_Tag];}
}tree[N],Temp;

// Split axis used one level below a node that splits on axis D (cycles 0..k-1).
long long NextDim(long long D){return D+1==k?0:D+1;}

// Grows node rt's bounding box so it also covers both children's boxes.
void Pushup(long long rt){
    long long child[2]={tree[rt].l,tree[rt].r};
    for(int c=0;c<2;c++){
        if(!child[c])continue;
        for(long long i=0;i<k;i++){
            tree[rt].mx[i]=max(tree[rt].mx[i],tree[child[c]].mx[i]);
            tree[rt].mn[i]=min(tree[rt].mn[i],tree[child[c]].mn[i]);
        }
    }
}

// Builds a balanced k-d tree over tree[l..r] (inclusive), splitting on axis D
// at this level. Returns the index of the subtree root (the median slot).
long long Build(long long l,long long r,long long D){
    Sort_Tag=D;  // read by operator< during nth_element below
    long long mid=(l+r)>>1;
    nth_element(tree+l,tree+mid,tree+r+1);
    for(long long i=0;i<k;i++)tree[mid].mn[i]=tree[mid].mx[i]=tree[mid].d[i];
    tree[mid].l=tree[mid].r=0;
    // The recursive calls overwrite Sort_Tag; that is fine because this
    // level's nth_element has already finished.
    if(l<mid)tree[mid].l=Build(l,mid-1,NextDim(D));
    if(r>mid)tree[mid].r=Build(mid+1,r,NextDim(D));
    Pushup(mid);
    return mid;
}

long long Sqr(long long x){return x*x;}

// Squared Euclidean distance between the points A.d and B.d.
long long Dist(const KDTree&A,const KDTree&B){
    long long ans=0;
    for(long long i=0;i<k;i++)ans+=Sqr(A.d[i]-B.d[i]);
    return ans;
}

// Squared distance from point A.d to B's bounding box [mn, mx]. Axes on which
// A lies inside the box contribute 0, so this is a valid lower bound on
// Dist(A, p) for every point p in B's subtree. (The previous min-of-both-ends
// formula over-estimated when A was inside the box.) Not used by Solve, which
// prunes on the split axis only; kept as a correct helper for box pruning.
long long MnDist(const KDTree&A,const KDTree&B){
    long long ans=0;
    for(long long i=0;i<k;i++){
        if(A.d[i]<B.mn[i])ans+=Sqr(B.mn[i]-A.d[i]);
        else if(A.d[i]>B.mx[i])ans+=Sqr(A.d[i]-B.mx[i]);
    }
    return ans;
}

// Searches the subtree rooted at rt (which splits on axis D) and replaces the
// worst heap entry whenever a point nearer to Temp is found.
// Requires PQ to be non-empty.
void Solve(long long rt,long long D){
    long long Near=tree[rt].l,Far=tree[rt].r;
    // Descend first into the side of the split plane that contains the query.
    if(Temp.d[D]>=tree[rt].d[D])swap(Near,Far);
    if(Near)Solve(Near,NextDim(D));
    long long Dis=Dist(Temp,tree[rt]);
    if(Dis<PQ.top().first)PQ.pop(),PQ.push(make_pair(Dis,rt));
    // Every point on the far side is at least this far away along axis D, so
    // visit it only if that lower bound still beats the current m-th best.
    if(Far&&Sqr(tree[rt].d[D]-Temp.d[D])<PQ.top().first)Solve(Far,NextDim(D));
}

// Handles one test case whose n and k have already been read: reads the n
// points, builds the tree, then answers t queries of the form
// "k coordinates, then m".
void Init_Solve_Out(){
    while(!PQ.empty())PQ.pop();
    while(!St.empty())St.pop();
    for(long long i=1;i<=n;i++){
        for(long long j=0;j<k;j++)scanf("%lld",&tree[i].d[j]);
        tree[i].id=i;
    }
    root=n>0?Build(1ll,n,0ll):0;  // Build needs a non-empty range
    if(scanf("%lld",&t)!=1)t=0;   // t is -1 after the previous case's loop
    while(t--){
        for(long long i=0;i<k;i++)scanf("%lld",&Temp.d[i]);
        scanf("%lld",&m);
        // Seed the heap with m placeholders so top() is always defined.
        for(long long i=1;i<=m;i++)PQ.push(make_pair(INF,0));
        if(root&&m>0)Solve(root,0);
        printf("the closest %lld points are:\n",m);
        // The heap pops farthest first; stack them to print closest first.
        // Skip placeholders that were never replaced (only happens if m > n).
        while(!PQ.empty()){
            if(PQ.top().second)St.push(PQ.top().second);
            PQ.pop();
        }
        while(!St.empty()){
            const KDTree&P=tree[St.top()];
            for(long long i=0;i<k-1;i++)printf("%lld ",P.d[i]);
            printf("%lld\n",P.d[k-1]);
            St.pop();
        }
    }
}

int main(){
    // Local-testing convenience: redirect to 3053.in / 3053.out only when the
    // input file exists, so the same binary also works when fed via stdin.
    if(FILE*probe=fopen("3053.in","r")){
        fclose(probe);
        freopen("3053.in","r",stdin);
        freopen("3053.out","w",stdout);
    }
    // == 2 (not ~scanf): a matching failure returns 0 and would loop forever.
    while(scanf("%lld %lld",&n,&k)==2)Init_Solve_Out();
    return 0;
}