001/* 002 AUTOMATICALLY GENERATED BY jTemp FROM 003 /Users/jsh2/Work/openimaj/target/checkout/core/core/src/main/jtemp/org/openimaj/util/tree/#T#KDTree.jtemp 004*/ 005/** 006 * Copyright (c) 2011, The University of Southampton and the individual contributors. 007 * All rights reserved. 008 * 009 * Redistribution and use in source and binary forms, with or without modification, 010 * are permitted provided that the following conditions are met: 011 * 012 * * Redistributions of source code must retain the above copyright notice, 013 * this list of conditions and the following disclaimer. 014 * 015 * * Redistributions in binary form must reproduce the above copyright notice, 016 * this list of conditions and the following disclaimer in the documentation 017 * and/or other materials provided with the distribution. 018 * 019 * * Neither the name of the University of Southampton nor the names of its 020 * contributors may be used to endorse or promote products derived from this 021 * software without specific prior written permission. 022 * 023 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 024 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 025 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 026 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 027 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 028 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 029 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 030 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 031 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 032 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 033 */ 034 package org.openimaj.util.tree; 035 036import gnu.trove.list.array.TIntArrayList; 037import gnu.trove.procedure.TIntObjectProcedure; 038import gnu.trove.procedure.TObjectFloatProcedure; 039import jal.objects.BinaryPredicate; 040import jal.objects.Sorting; 041 042import java.util.ArrayDeque; 043import java.util.ArrayList; 044import java.util.Arrays; 045import java.util.Deque; 046import java.util.List; 047 048import org.openimaj.util.array.ArrayUtils; 049import org.openimaj.util.array.IntArrayView; 050import org.openimaj.util.pair.*; 051import org.openimaj.util.queue.BoundedPriorityQueue; 052 053import cern.jet.random.Uniform; 054import cern.jet.random.engine.MersenneTwister; 055 056/** 057 * Immutable KD-Tree implementation for short[] data. Allows various 058 * tree-construction strategies to be applied through the 059 * {@link SplitChooser}. Supports efficient range, radius and 060 * nearest-neighbour search for relatively low dimensional spaces. 061 * 062 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 063 */ 064public class ShortKDTree { 065 /** 066 * Interface for describing how a branch in the KD-Tree should be created 067 * 068 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 069 * 070 */ 071 public static interface SplitChooser { 072 /** 073 * Choose the dimension and discriminant on which to split the data. 074 * 075 * @param pnts 076 * the raw data 077 * @param inds 078 * the indices of the data under consideration 079 * @param depth 080 * the depth of the current data in the tree 081 * @param minBounds 082 * the minimum bounds 083 * @param maxBounds 084 * the maximum bounds 085 * @return the dimension and discriminant, or null iff this is a leaf 086 * (containing all the points in inds). 087 */ 088 public IntShortPair chooseSplit(final short[][] pnts, final IntArrayView inds, int depth, short[] minBounds, 089 short[] maxBounds); 090 } 091 092 /** 093 * Basic median split. Each dimension will be split at it's median value in 094 * turn. 095 * 096 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 097 * 098 */ 099 public static class BasicMedianSplit implements SplitChooser { 100 int maxBucketSize = 24; 101 102 /** 103 * Construct with the default maximum number of items per bucket 104 */ 105 public BasicMedianSplit() { 106 } 107 108 /** 109 * Construct with the given maximum number of items per bucket 110 * 111 * @param maxBucketSize 112 * maximum number of items per bucket 113 */ 114 public BasicMedianSplit(int maxBucketSize) { 115 this.maxBucketSize = maxBucketSize; 116 } 117 118 @Override 119 public IntShortPair chooseSplit(short[][] pnts, IntArrayView inds, int depth, short[] minBounds, 120 short[] maxBounds) 121 { 122 if (inds.size() < maxBucketSize) 123 return null; 124 125 final int dim = depth % pnts[0].length; 126 127 final short[] data = new short[inds.size()]; 128 for (int i = 0; i < data.length; i++) 129 data[i] = pnts[inds.getFast(i)][dim]; 130 final short median = ArrayUtils.quickSelect(data, data.length / 2); 131 132 return IntShortPair.pair(dim, median); 133 } 134 } 135 136 /** 137 * Best-bin-first median splitting. Best bin is chosen from the dimension 138 * with the largest variance (computed from all the data at the node). 139 * 140 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 141 * 142 */ 143 public static class BBFMedianSplit implements SplitChooser { 144 int maxBucketSize = 24; 145 146 /** 147 * Construct with the default maximum number of items per bucket 148 */ 149 public BBFMedianSplit() { 150 } 151 152 /** 153 * Construct with the given maximum number of items per bucket 154 * 155 * @param maxBucketSize 156 * maximum number of items per bucket 157 */ 158 public BBFMedianSplit(int maxBucketSize) { 159 this.maxBucketSize = maxBucketSize; 160 } 161 162 @Override 163 public IntShortPair chooseSplit(short[][] pnts, IntArrayView inds, int depth, short[] minBounds, 164 short[] maxBounds) 165 { 166 if (inds.size() < maxBucketSize) 167 return null; 168 169 // Find mean & variance of each dimension. 170 final int D = pnts[0].length; 171 final float[] sumX = new float[D]; 172 final float[] sumXX = new float[D]; 173 final int count = inds.size(); 174 175 for (int n = 0; n < count; ++n) { 176 for (int d = 0; d < D; ++d) { 177 final int i = inds.getFast(n); 178 179 sumX[d] += pnts[i][d]; 180 sumXX[d] += (pnts[i][d] * pnts[i][d]); 181 } 182 } 183 184 int dim = 0; 185 float maxVar = (sumXX[0] - ((float) 1 / count) * sumX[0] * sumX[0]) / (count - 1); 186 187 for (int d = 1; d < D; ++d) { 188 final float var = (sumXX[d] - ((float) 1 / count) * sumX[d] * sumX[d]) / (count - 1); 189 if (var > maxVar) { 190 maxVar = var; 191 dim = d; 192 } 193 } 194 195 if (maxVar == 0) 196 return null; 197 198 final short[] data = new short[inds.size()]; 199 for (int i = 0; i < data.length; i++) 200 data[i] = pnts[inds.getFast(i)][dim]; 201 final short median = ArrayUtils.quickSelect(data, data.length / 2); 202 203 return IntShortPair.pair(dim, median); 204 } 205 } 206 207 /** 208 * Approximate best-bin-first median splitting. Best bin is chosen from the 209 * dimension with the largest range. 210 * 211 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 212 */ 213 public static class ApproximateBBFMedianSplit implements SplitChooser { 214 int maxBucketSize = 24; 215 216 /** 217 * Construct with the default maximum number of items per bucket 218 */ 219 public ApproximateBBFMedianSplit() { 220 } 221 222 /** 223 * Construct with the given maximum number of items per bucket 224 * 225 * @param maxBucketSize 226 * maximum number of items per bucket 227 */ 228 public ApproximateBBFMedianSplit(int maxBucketSize) { 229 this.maxBucketSize = maxBucketSize; 230 } 231 232 @Override 233 public IntShortPair chooseSplit(short[][] pnts, IntArrayView inds, int depth, short[] minBounds, 234 short[] maxBounds) 235 { 236 if (inds.size() < maxBucketSize) 237 return null; 238 239 // find biggest range of each dimension 240 int dim = 0; 241 float maxVar = maxBounds[0] - minBounds[0]; 242 for (int d = 1; d < pnts[0].length; ++d) { 243 final float var = maxBounds[d] - minBounds[d]; 244 if (var > maxVar) { 245 maxVar = var; 246 dim = d; 247 } 248 } 249 250 if (maxVar == 0) 251 return null; 252 253 final short[] data = new short[inds.size()]; 254 for (int i = 0; i < data.length; i++) 255 data[i] = pnts[inds.getFast(i)][dim]; 256 final short median = ArrayUtils.quickSelect(data, data.length / 2); 257 258 return IntShortPair.pair(dim, median); 259 } 260 } 261 262 /** 263 * Randomised best-bin-first splitting strategy: 264 * <ul> 265 * <li>Nodes with less than a set number of items become leafs. 266 * <li>Otherwise: 267 * <ul> 268 * <li>a sample of the data is taken and the variance across each dimension 269 * is computed. 270 * <li>a dimension is chosen randomly from the dimensions with the higest 271 * variance. 272 * <li>the mean (computed from the variance sample) is taken as the split 273 * point. 274 * </ul> 275 * </ul> 276 * 277 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 278 * 279 */ 280 public static class RandomisedBBFMeanSplit implements SplitChooser { 281 /** 282 * Maximum number of items in a leaf. 283 */ 284 private static final int maxLeafSize = 14; 285 286 /** 287 * Maximum number of points of variance estimation; all points used if 288 * <=0. 289 */ 290 private static final int varianceMaxPoints = 128; 291 292 /** 293 * Number of dimensions to consider when randomly selecting one with a 294 * big variance. 295 */ 296 private static final int randomMaxDims = 5; 297 298 /** 299 * The random source 300 */ 301 private Uniform rng; 302 303 /** 304 * Construct with the default values of 14 points per leaf (max), 128 305 * samples for computing variance, and the 5 most varying dimensions 306 * randomly selected. A new {@link MersenneTwister} is created as the 307 * source for random numbers. 308 */ 309 public RandomisedBBFMeanSplit() { 310 this.rng = new Uniform(new MersenneTwister()); 311 } 312 313 /** 314 * Construct with the default values of 14 points per leaf (max), 128 315 * samples for computing variance, and the 5 most varying dimensions 316 * randomly selected. A new {@link MersenneTwister} is created as the 317 * source for random numbers. 318 * 319 * @param uniform 320 * the random number source 321 */ 322 public RandomisedBBFMeanSplit(Uniform uniform) { 323 this.rng = uniform; 324 } 325 326 /** 327 * Construct with the given values. 328 * 329 * @param maxLeafSize 330 * Maximum number of items in a leaf. 331 * @param varianceMaxPoints 332 * Maximum number of points of variance estimation; all 333 * points used if <=0. 334 * @param randomMaxDims 335 * Number of dimensions to consider when randomly selecting 336 * one with a big variance. 337 * @param uniform 338 * the random number source 339 */ 340 public RandomisedBBFMeanSplit(int maxLeafSize, int varianceMaxPoints, int randomMaxDims, Uniform uniform) 341 { 342 this.rng = uniform; 343 } 344 345 @Override 346 public IntShortPair chooseSplit(final short[][] pnts, final IntArrayView inds, int depth, short[] minBounds, 347 short[] maxBounds) 348 { 349 if (inds.size() < maxLeafSize) 350 return null; 351 352 final int D = pnts[0].length; 353 354 // Find mean & variance of each dimension. 355 final float[] sumX = new float[D]; 356 final float[] sumXX = new float[D]; 357 358 final int count = Math.min(inds.size(), varianceMaxPoints); 359 for (int n = 0; n < count; ++n) { 360 for (int d = 0; d < D; ++d) { 361 final int i = inds.getFast(n); 362 363 sumX[d] += pnts[i][d]; 364 sumXX[d] += (pnts[i][d] * pnts[i][d]); 365 } 366 } 367 368 final FloatIntPair[] varPerDim = new FloatIntPair[D]; 369 for (int d = 0; d < D; ++d) { 370 varPerDim[d] = new FloatIntPair(); 371 varPerDim[d].second = d; 372 373 if (count <= 1) 374 varPerDim[d].first = 0; 375 else 376 varPerDim[d].first = (sumXX[d] - ((float) 1 / count) * sumX[d] * sumX[d]) / (count - 1); 377 } 378 379 // Partial sort makes a BIG difference to the build time. 380 final int nrand = Math.min(randomMaxDims, D); 381 Sorting.partial_sort(varPerDim, 0, nrand, varPerDim.length, new BinaryPredicate() { 382 @Override 383 public boolean apply(Object arg0, Object arg1) { 384 final FloatIntPair p1 = (FloatIntPair) arg0; 385 final FloatIntPair p2 = (FloatIntPair) arg1; 386 387 if (p1.first > p2.first) 388 return true; 389 if (p2.first > p1.first) 390 return false; 391 return (p1.second > p2.second); 392 } 393 }); 394 395 final int randd = varPerDim[rng.nextIntFromTo(0, nrand - 1)].second; 396 397 return new IntShortPair(randd, (short)(sumX[randd] / count)); 398 } 399 } 400 401 /** 402 * An internal node of the KDTree 403 */ 404 public static class KDTreeNode { 405 /** 406 * Node to the left 407 */ 408 public KDTreeNode left; 409 410 /** 411 * Node to the right 412 */ 413 public KDTreeNode right; 414 415 /** 416 * Splitting value 417 */ 418 public short discriminant; 419 420 /** 421 * Splitting dimension 422 */ 423 public int discriminantDimension; 424 425 /** 426 * The minimum bounds of this node 427 */ 428 public short[] minBounds; 429 430 /** 431 * The maximum bounds of this node 432 */ 433 public short[] maxBounds; 434 435 /** 436 * The leaf only holds the indices of the original data 437 */ 438 public int[] indices; 439 440 /** 441 * Construct a new node with the given data 442 * 443 * @param pnts 444 * the data for the node and its children 445 * @param inds 446 * a list of indices that point to the relevant parts of the 447 * pnts array that should be used 448 * @param split 449 * the {@link SplitChooser} to use when constructing 450 * the tree 451 */ 452 public KDTreeNode(final short[][] pnts, IntArrayView inds, SplitChooser split) { 453 this(pnts, inds, split, 0, null, true); 454 } 455 456 private KDTreeNode(final short[][] pnts, IntArrayView inds, SplitChooser split, int depth, 457 KDTreeNode parent, boolean isLeft) 458 { 459 // set the bounds of this node 460 if (parent == null) { 461 this.minBounds = new short[pnts[0].length]; 462 this.maxBounds = new short[pnts[0].length]; 463 464 Arrays.fill(minBounds, Short.MAX_VALUE); 465 Arrays.fill(maxBounds, (short)(-Short.MAX_VALUE)); 466 467 for (int y = 0; y < pnts.length; y++) { 468 for (int x = 0; x < pnts[0].length; x++) { 469 if (minBounds[x] > pnts[y][x]) 470 minBounds[x] = pnts[y][x]; 471 if (maxBounds[x] < pnts[y][x]) 472 maxBounds[x] = pnts[y][x]; 473 } 474 } 475 Arrays.fill(minBounds, (short)(-Short.MAX_VALUE)); 476 Arrays.fill(maxBounds, Short.MAX_VALUE); 477 } else { 478 this.minBounds = parent.minBounds.clone(); 479 this.maxBounds = parent.maxBounds.clone(); 480 481 if (isLeft) { 482 maxBounds[parent.discriminantDimension] = parent.discriminant; 483 } else { 484 minBounds[parent.discriminantDimension] = parent.discriminant; 485 } 486 } 487 488 // test to see where/if we should split 489 final IntShortPair spl = split.chooseSplit(pnts, inds, depth, minBounds, maxBounds); 490 491 if (spl == null) { 492 // this will be a leaf node 493 indices = inds.toArray(); 494 } else { 495 discriminantDimension = spl.first; 496 discriminant = spl.second; 497 498 // partially sort the inds so that all the data with 499 // data[discriminantDimension] < discriminant is on one side 500 final int N = inds.size(); 501 int l = 0; 502 int r = N; 503 while (l != r) { 504 if (pnts[inds.getFast(l)][discriminantDimension] < discriminant) 505 l++; 506 else { 507 r--; 508 final int t = inds.getFast(l); 509 inds.setFast(l, inds.getFast(r)); 510 inds.setFast(r, t); 511 } 512 } 513 514 // If either partition is empty then the are vectors identical. 515 // Choose the midpoint to keep the O(nlog(n)) performance. 516 if (l == 0 || l == N) { 517 // l = N / 2; 518 this.discriminant = 0; 519 this.discriminantDimension = 0; 520 this.indices = inds.toArray(); 521 } else { 522 // create the child nodes 523 left = new KDTreeNode(pnts, inds.subView(0, l), split, depth + 1, this, true); 524 right = new KDTreeNode(pnts, inds.subView(l, N), split, depth + 1, this, false); 525 } 526 } 527 } 528 529 /** 530 * Test to see if this node is a leaf node (i.e. 531 * <code>{@link #indices} != null</code>) 532 * 533 * @return true if this is a leaf node; false otherwise 534 */ 535 public boolean isLeaf() { 536 return indices != null; 537 } 538 539 private final boolean inRange(short value, short min, short max) { 540 return (value >= min) && (value <= max); 541 } 542 543 /** 544 * Test whether the bounds of this node are disjoint from the 545 * hyperrectangle described by the given bounds. 546 * 547 * @param lowerExtreme 548 * the lower bounds of the hyperrectangle 549 * @param upperExtreme 550 * the upper bounds of the hyperrectangle 551 * @return true if disjoint; false otherwise 552 */ 553 public boolean isDisjointFrom(short[] lowerExtreme, short[] upperExtreme) { 554 for (int i = 0; i < lowerExtreme.length; i++) { 555 if (!(inRange(minBounds[i], lowerExtreme[i], upperExtreme[i]) || inRange(lowerExtreme[i], minBounds[i], 556 maxBounds[i]))) 557 return true; 558 } 559 560 return false; 561 } 562 563 /** 564 * Test whether the bounds of this node are fully contained by the 565 * hyperrectangle described by the given bounds. 566 * 567 * @param lowerExtreme 568 * the lower bounds of the hyperrectangle 569 * @param upperExtreme 570 * the upper bounds of the hyperrectangle 571 * @return true if fully contained; false otherwise 572 */ 573 public boolean isContainedBy(short[] lowerExtreme, short[] upperExtreme) { 574 for (int i = 0; i < lowerExtreme.length; i++) { 575 if (minBounds[i] < lowerExtreme[i] || maxBounds[i] > upperExtreme[i]) 576 return false; 577 } 578 return true; 579 } 580 } 581 582 /** The tree roots */ 583 public final KDTreeNode root; 584 585 /** The underlying data array */ 586 public final short[][] data; 587 588 /** 589 * Construct with the given data and default splitting strategy ({@link BBFMedianSplit}) 590 * 591 * @param data 592 * the data 593 */ 594 public ShortKDTree(short[][] data) { 595 this.data = data; 596 this.root = new KDTreeNode(data, new IntArrayView(ArrayUtils.range(0, data.length - 1)), new BBFMedianSplit()); 597 } 598 599 /** 600 * Construct with the given data and splitting strategy 601 * 602 * @param data 603 * the data 604 * @param split 605 * the splitting strategy 606 */ 607 public ShortKDTree(short[][] data, SplitChooser split) { 608 this.data = data; 609 this.root = new KDTreeNode(data, new IntArrayView(ArrayUtils.range(0, data.length - 1)), split); 610 } 611 612 /** 613 * Search the tree for all points contained within the hyperrectangle 614 * defined by the given upper and lower extremes. 615 * 616 * @param lowerExtreme 617 * the lower extreme of the hyperrectangle 618 * @param upperExtreme 619 * the upper extreme of the hyperrectangle 620 * @return the points within the given bounds 621 */ 622 public short[][] coordinateRangeSearch(short[] lowerExtreme, short[] upperExtreme) { 623 final List<short[]> results = new ArrayList<short[]>(); 624 625 rangeSearch(lowerExtreme, upperExtreme, new TIntObjectProcedure<short[]>() { 626 @Override 627 public boolean execute(int a, short[] b) { 628 results.add(b); 629 630 return true; 631 } 632 }); 633 634 return results.toArray(new short[results.size()][]); 635 } 636 637 /** 638 * Search the tree for all points contained within the hyperrectangle 639 * defined by the given upper and lower extremes. 640 * 641 * @param lowerExtreme 642 * the lower extreme of the hyperrectangle 643 * @param upperExtreme 644 * the upper extreme of the hyperrectangle 645 * @return the points within the given bounds 646 */ 647 public int[] indexRangeSearch(short[] lowerExtreme, short[] upperExtreme) { 648 final TIntArrayList results = new TIntArrayList(); 649 650 rangeSearch(lowerExtreme, upperExtreme, new TIntObjectProcedure<short[]>() { 651 @Override 652 public boolean execute(int a, short[] b) { 653 results.add(a); 654 655 return true; 656 } 657 }); 658 659 return results.toArray(); 660 } 661 662 /** 663 * Search the tree for the indexes of all points contained within the 664 * hypersphere defined by the given centre and radius. 665 * 666 * @param centre 667 * the centre point 668 * @param radius 669 * the radius 670 * @return the points within the given bounds 671 */ 672 public int[] indexRadiusSearch(short[] centre, short radius) { 673 final TIntArrayList results = new TIntArrayList(); 674 675 this.radiusSearch(centre, radius, new TIntObjectProcedure<short[]>() { 676 @Override 677 public boolean execute(int a, short[] b) { 678 results.add(a); 679 680 return true; 681 } 682 }); 683 684 return results.toArray(); 685 } 686 687 /** 688 * Find all the points within the given radius of the given point. 689 * Internally this works by finding the points in the hyper-square 690 * encompassing the hyper-circle and then filtering. Each valid point that 691 * is found is reported to the given processor together with its index. 692 * <p> 693 * The search can be stopped early by returning false from the 694 * {@link TIntObjectProcedure#execute(int, Object)} method. 695 * 696 * @param centre 697 * the centre point 698 * @param radius 699 * the radius 700 * @param proc 701 * the process 702 */ 703 public void radiusSearch(final short[] centre, short radius, final TIntObjectProcedure<short[]> proc) 704 { 705 final short[] lower = centre.clone(); 706 final short[] upper = centre.clone(); 707 708 for (int i = 0; i < centre.length; i++) { 709 lower[i] -= radius; 710 upper[i] += radius; 711 } 712 713 final float radSq = radius * radius; 714 rangeSearch(lower, upper, new TIntObjectProcedure<short[]>() { 715 @Override 716 public boolean execute(int idx, short[] point) { 717 final float d = distance(centre, point); 718 if (d <= radSq) 719 return proc.execute(idx, point); 720 721 return true; 722 } 723 }); 724 } 725 726 /** 727 * Search the tree for all points contained within the hyperrectangle 728 * defined by the given upper and lower extremes. Each valid point that is 729 * found is reported to the given processor together with its index in the 730 * original data. 731 * <p> 732 * The search can be stopped early by returning false from the 733 * {@link TIntObjectProcedure#execute(int, Object)} method. 734 * 735 * @param lowerExtreme 736 * the lower extreme of the hyperrectangle 737 * @param upperExtreme 738 * the upper extreme of the hyperrectangle 739 * @param proc 740 * the processor 741 */ 742 public void rangeSearch(short[] lowerExtreme, short[] upperExtreme, TIntObjectProcedure<short[]> proc) { 743 final Deque<KDTreeNode> stack = new ArrayDeque<KDTreeNode>(); 744 745 if (root == null) 746 return; 747 748 stack.push(root); 749 750 while (!stack.isEmpty()) { 751 final KDTreeNode tmpNode = stack.pop(); 752 753 if (tmpNode.isLeaf()) { 754 for (int i = 0; i < tmpNode.indices.length; i++) { 755 final int idx = tmpNode.indices[i]; 756 final short[] vec = data[idx]; 757 if (isContained(vec, lowerExtreme, upperExtreme)) 758 if (!proc.execute(idx, vec)) 759 return; 760 } 761 } else { 762 if (tmpNode.isDisjointFrom(lowerExtreme, upperExtreme)) { 763 continue; 764 } 765 766 if (tmpNode.isContainedBy(lowerExtreme, upperExtreme)) { 767 reportSubtree(tmpNode, proc); 768 } else { 769 if (tmpNode.left != null) 770 stack.push(tmpNode.left); 771 if (tmpNode.right != null) 772 stack.push(tmpNode.right); 773 } 774 } 775 } 776 } 777 778 /** 779 * Determines if a point is contained within a given k-dimensional bounding 780 * box. 781 */ 782 private final boolean isContained(short[] point, short[] lower, short[] upper) 783 { 784 for (int i = 0; i < point.length; i++) { 785 if (point[i] < lower[i] || point[i] > upper[i]) 786 return false; 787 } 788 789 return true; 790 } 791 792 /** 793 * Report all the child items of the given subtree to the process 794 * 795 * @param root 796 * the root of the subtree 797 * @param proc 798 * the process to apply 799 */ 800 private void reportSubtree(KDTreeNode root, TIntObjectProcedure<short[]> proc) { 801 final Deque<KDTreeNode> stack = new ArrayDeque<KDTreeNode>(); 802 stack.push(root); 803 804 while (!stack.isEmpty()) { 805 final KDTreeNode tmpNode = stack.pop(); 806 807 if (tmpNode.isLeaf()) { 808 for (int i = 0; i < tmpNode.indices.length; i++) { 809 final int idx = tmpNode.indices[i]; 810 if (!proc.execute(idx, data[idx])) 811 return; 812 } 813 } else { 814 if (tmpNode.left != null) 815 stack.push(tmpNode.left); 816 if (tmpNode.right != null) 817 stack.push(tmpNode.right); 818 } 819 } 820 } 821 822 /** 823 * Nearest-neighbour search 824 * 825 * @param qu 826 * the query point 827 * @param n 828 * the number of neighbours to find 829 * @return the indices and distances 830 */ 831 public List<IntFloatPair> nearestNeighbours(final short[] qu, int n) { 832 final BoundedPriorityQueue<IntFloatPair> queue = new BoundedPriorityQueue<IntFloatPair>(n, 833 IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR); 834 835 searchSubTree(qu, root, queue); 836 837 return queue.toOrderedListDestructive(); 838 } 839 840 /** 841 * Nearest-neighbour search 842 * 843 * @param qu 844 * the query point 845 * @return the indices and distances 846 */ 847 public IntFloatPair nearestNeighbour(final short[] qu) { 848 final BoundedPriorityQueue<IntFloatPair> queue = new BoundedPriorityQueue<IntFloatPair>(1, 849 IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR); 850 851 searchSubTree(qu, root, queue); 852 853 return queue.peek(); 854 } 855 856 /** 857 * Find all the points within the given radius of the given point. 858 * Internally this works by finding the points in the hyper-square 859 * encompassing the hyper-circle and then filtering. 860 * 861 * @param centre 862 * the centre point 863 * @param radius 864 * the radius 865 * @return the points 866 */ 867 public short[][] coordinateRadiusSearch(short[] centre, short radius) { 868 final List<short[]> radiusList = new ArrayList<short[]>(); 869 870 coordinateRadiusSearch(centre, radius, new TObjectFloatProcedure<short[]>() { 871 @Override 872 public boolean execute(short[] a, float b) { 873 radiusList.add(a); 874 return true; 875 } 876 }); 877 878 return radiusList.toArray(new short[radiusList.size()][]); 879 } 880 881 /** 882 * Find all the points within the given radius of the given point. 883 * Internally this works by finding the points in the hyper-square 884 * encompassing the hyper-circle and then filtering. Each valid point that 885 * is found is reported to the given processor together with its distance 886 * from the centre. 887 * <p> 888 * The search can be stopped early by returning false from the 889 * {@link TIntObjectProcedure#execute(int, Object)} method. 890 * 891 * @param centre 892 * the centre point 893 * @param radius 894 * the radius 895 * @param proc 896 * the process 897 */ 898 public void coordinateRadiusSearch(final short[] centre, short radius, final TObjectFloatProcedure<short[]> proc) 899 { 900 final short[] lower = centre.clone(); 901 final short[] upper = centre.clone(); 902 903 for (int i = 0; i < centre.length; i++) { 904 lower[i] -= radius; 905 upper[i] += radius; 906 } 907 908 final float radSq = radius * radius; 909 rangeSearch(lower, upper, new TIntObjectProcedure<short[]>() { 910 @Override 911 public boolean execute(int idx, short[] point) { 912 final float d = distance(centre, point); 913 if (d <= radSq) 914 return proc.execute(point, d); 915 916 return true; 917 } 918 }); 919 } 920 921 private void searchSubTree(final short[] qu, KDTreeNode cur, BoundedPriorityQueue<IntFloatPair> queue) { 922 final Deque<KDTreeNode> stack = new ArrayDeque<KDTreeNode>(); 923 while (!cur.isLeaf()) { 924 stack.push(cur); 925 926 final float diff = qu[cur.discriminantDimension] - cur.discriminant; 927 928 if (diff < 0) { 929 cur = cur.left; 930 } else { 931 cur = cur.right; 932 } 933 } 934 935 for (int i = 0; i < cur.indices.length; i++) { 936 final int idx = cur.indices[i]; 937 final short[] vec = data[idx]; 938 final float dist = distance(qu, vec); 939 queue.add(new IntFloatPair(idx, dist)); 940 } 941 942 while (!stack.isEmpty()) { 943 cur = stack.pop(); 944 final float diff = qu[cur.discriminantDimension] - cur.discriminant; 945 946 final float worstDist = queue.peekTail().second; 947 948 if (diff * diff <= worstDist || !queue.isFull()) { 949 // need to search subtree 950 if (diff < 0) { 951 searchSubTree(qu, cur.right, queue); 952 } else { 953 searchSubTree(qu, cur.left, queue); 954 } 955 } 956 } 957 } 958 959 private float distance(short[] qu, short[] vec) { 960 float d = 0; 961 for (int i = 0; i < qu.length; i++) 962 d += (qu[i] - vec[i]) * (qu[i] - vec[i]); 963 return d; 964 } 965 966 /** 967 * Find all the indices seperated by leaves 968 * @return all the leaves 969 */ 970 public List<int[]> leafIndices() { 971 List<int[]> leafInds = new ArrayList<int[]>(); 972 Deque<KDTreeNode> nodes = new ArrayDeque<KDTreeNode>(); 973 nodes.push(root); 974 while(nodes.size()!=0){ 975 KDTreeNode node = nodes.pop(); 976 if(node.isLeaf()){ 977 leafInds.add(node.indices); 978 } 979 else{ 980 nodes.push(node.left); 981 nodes.push(node.right); 982 } 983 } 984 985 return leafInds; 986 } 987}