001/*
002        AUTOMATICALLY GENERATED BY jTemp FROM
003        /Users/jsh2/Work/openimaj/target/checkout/core/core/src/main/jtemp/org/openimaj/util/tree/#T#KDTree.jtemp
004*/
005/**
006 * Copyright (c) 2011, The University of Southampton and the individual contributors.
007 * All rights reserved.
008 *
009 * Redistribution and use in source and binary forms, with or without modification,
010 * are permitted provided that the following conditions are met:
011 *
012 *   *  Redistributions of source code must retain the above copyright notice,
013 *      this list of conditions and the following disclaimer.
014 *
015 *   *  Redistributions in binary form must reproduce the above copyright notice,
016 *      this list of conditions and the following disclaimer in the documentation
017 *      and/or other materials provided with the distribution.
018 *
019 *   *  Neither the name of the University of Southampton nor the names of its
020 *      contributors may be used to endorse or promote products derived from this
021 *      software without specific prior written permission.
022 *
023 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
024 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
025 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
026 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
027 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
028 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
029 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
030 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
031 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
032 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
033 */
034 package org.openimaj.util.tree;
035
036import gnu.trove.list.array.TIntArrayList;
037import gnu.trove.procedure.TIntObjectProcedure;
038import gnu.trove.procedure.TObjectFloatProcedure;
039import jal.objects.BinaryPredicate;
040import jal.objects.Sorting;
041
042import java.util.ArrayDeque;
043import java.util.ArrayList;
044import java.util.Arrays;
045import java.util.Deque;
046import java.util.List;
047
048import org.openimaj.util.array.ArrayUtils;
049import org.openimaj.util.array.IntArrayView;
050import org.openimaj.util.pair.*;
051import org.openimaj.util.queue.BoundedPriorityQueue;
052
053import cern.jet.random.Uniform;
054import cern.jet.random.engine.MersenneTwister;
055
056/**
057 * Immutable KD-Tree implementation for short[] data. Allows various
058 * tree-construction strategies to be applied through the
059 * {@link SplitChooser}. Supports efficient range, radius and
060 * nearest-neighbour search for relatively low dimensional spaces.
061 * 
062 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
063 */
064public class ShortKDTree {
065        /**
066         * Interface for describing how a branch in the KD-Tree should be created
067         * 
068         * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
069         * 
070         */
071        public static interface SplitChooser {
072                /**
073                 * Choose the dimension and discriminant on which to split the data.
074                 * 
075                 * @param pnts
076                 *            the raw data
077                 * @param inds
078                 *            the indices of the data under consideration
079                 * @param depth
080                 *            the depth of the current data in the tree
081                 * @param minBounds
082                 *            the minimum bounds
083                 * @param maxBounds
084                 *            the maximum bounds
085                 * @return the dimension and discriminant, or null iff this is a leaf
086                 *         (containing all the points in inds).
087                 */
088                public IntShortPair chooseSplit(final short[][] pnts, final IntArrayView inds, int depth, short[] minBounds,
089                                short[] maxBounds);
090        }
091
092        /**
093         * Basic median split. Each dimension will be split at it's median value in
094         * turn.
095         * 
096         * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
097         * 
098         */
099        public static class BasicMedianSplit implements SplitChooser {
100                int maxBucketSize = 24;
101
102                /**
103                 * Construct with the default maximum number of items per bucket
104                 */
105                public BasicMedianSplit() {
106                }
107
108                /**
109                 * Construct with the given maximum number of items per bucket
110                 * 
111                 * @param maxBucketSize
112                 *            maximum number of items per bucket
113                 */
114                public BasicMedianSplit(int maxBucketSize) {
115                        this.maxBucketSize = maxBucketSize;
116                }
117
118                @Override
119                public IntShortPair chooseSplit(short[][] pnts, IntArrayView inds, int depth, short[] minBounds,
120                                short[] maxBounds)
121                {
122                        if (inds.size() < maxBucketSize)
123                                return null;
124
125                        final int dim = depth % pnts[0].length;
126
127                        final short[] data = new short[inds.size()];
128                        for (int i = 0; i < data.length; i++)
129                                data[i] = pnts[inds.getFast(i)][dim];
130                        final short median = ArrayUtils.quickSelect(data, data.length / 2);
131
132                        return IntShortPair.pair(dim, median);
133                }
134        }
135
136        /**
137         * Best-bin-first median splitting. Best bin is chosen from the dimension
138         * with the largest variance (computed from all the data at the node).
139         * 
140         * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
141         * 
142         */
143        public static class BBFMedianSplit implements SplitChooser {
144                int maxBucketSize = 24;
145
146                /**
147                 * Construct with the default maximum number of items per bucket
148                 */
149                public BBFMedianSplit() {
150                }
151
152                /**
153                 * Construct with the given maximum number of items per bucket
154                 * 
155                 * @param maxBucketSize
156                 *            maximum number of items per bucket
157                 */
158                public BBFMedianSplit(int maxBucketSize) {
159                        this.maxBucketSize = maxBucketSize;
160                }
161
162                @Override
163                public IntShortPair chooseSplit(short[][] pnts, IntArrayView inds, int depth, short[] minBounds,
164                                short[] maxBounds)
165                {
166                        if (inds.size() < maxBucketSize)
167                                return null;
168
169                        // Find mean & variance of each dimension.
170                        final int D = pnts[0].length;
171                        final float[] sumX = new float[D];
172                        final float[] sumXX = new float[D];
173                        final int count = inds.size();
174
175                        for (int n = 0; n < count; ++n) {
176                                for (int d = 0; d < D; ++d) {
177                                        final int i = inds.getFast(n);
178
179                                        sumX[d] += pnts[i][d];
180                                        sumXX[d] += (pnts[i][d] * pnts[i][d]);
181                                }
182                        }
183
184                        int dim = 0;
185                        float maxVar = (sumXX[0] - ((float) 1 / count) * sumX[0] * sumX[0]) / (count - 1);
186
187                        for (int d = 1; d < D; ++d) {
188                                final float var = (sumXX[d] - ((float) 1 / count) * sumX[d] * sumX[d]) / (count - 1);
189                                if (var > maxVar) {
190                                        maxVar = var;
191                                        dim = d;
192                                }
193                        }
194
195                        if (maxVar == 0)
196                                return null;
197
198                        final short[] data = new short[inds.size()];
199                        for (int i = 0; i < data.length; i++)
200                                data[i] = pnts[inds.getFast(i)][dim];
201                        final short median = ArrayUtils.quickSelect(data, data.length / 2);
202
203                        return IntShortPair.pair(dim, median);
204                }
205        }
206
207        /**
208         * Approximate best-bin-first median splitting. Best bin is chosen from the
209         * dimension with the largest range.
210         * 
211         * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
212         */
213        public static class ApproximateBBFMedianSplit implements SplitChooser {
214                int maxBucketSize = 24;
215
216                /**
217                 * Construct with the default maximum number of items per bucket
218                 */
219                public ApproximateBBFMedianSplit() {
220                }
221
222                /**
223                 * Construct with the given maximum number of items per bucket
224                 * 
225                 * @param maxBucketSize
226                 *            maximum number of items per bucket
227                 */
228                public ApproximateBBFMedianSplit(int maxBucketSize) {
229                        this.maxBucketSize = maxBucketSize;
230                }
231
232                @Override
233                public IntShortPair chooseSplit(short[][] pnts, IntArrayView inds, int depth, short[] minBounds,
234                                short[] maxBounds)
235                {
236                        if (inds.size() < maxBucketSize)
237                                return null;
238
239                        // find biggest range of each dimension
240                        int dim = 0;
241                        float maxVar = maxBounds[0] - minBounds[0];
242                        for (int d = 1; d < pnts[0].length; ++d) {
243                                final float var = maxBounds[d] - minBounds[d];
244                                if (var > maxVar) {
245                                        maxVar = var;
246                                        dim = d;
247                                }
248                        }
249
250                        if (maxVar == 0)
251                                return null;
252
253                        final short[] data = new short[inds.size()];
254                        for (int i = 0; i < data.length; i++)
255                                data[i] = pnts[inds.getFast(i)][dim];
256                        final short median = ArrayUtils.quickSelect(data, data.length / 2);
257
258                        return IntShortPair.pair(dim, median);
259                }
260        }
261
262        /**
263         * Randomised best-bin-first splitting strategy:
264         * <ul>
265         * <li>Nodes with less than a set number of items become leafs.
266         * <li>Otherwise:
267         * <ul>
268         * <li>a sample of the data is taken and the variance across each dimension
269         * is computed.
270         * <li>a dimension is chosen randomly from the dimensions with the higest
271         * variance.
272         * <li>the mean (computed from the variance sample) is taken as the split
273         * point.
274         * </ul>
275         * </ul>
276         * 
277         * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
278         * 
279         */
280        public static class RandomisedBBFMeanSplit implements SplitChooser {
281                /**
282                 * Maximum number of items in a leaf.
283                 */
284                private static final int maxLeafSize = 14;
285
286                /**
287                 * Maximum number of points of variance estimation; all points used if
288                 * <=0.
289                 */
290                private static final int varianceMaxPoints = 128;
291
292                /**
293                 * Number of dimensions to consider when randomly selecting one with a
294                 * big variance.
295                 */
296                private static final int randomMaxDims = 5;
297
298                /**
299                 * The random source
300                 */
301                private Uniform rng;
302
303                /**
304                 * Construct with the default values of 14 points per leaf (max), 128
305                 * samples for computing variance, and the 5 most varying dimensions
306                 * randomly selected. A new {@link MersenneTwister} is created as the
307                 * source for random numbers.
308                 */
309                public RandomisedBBFMeanSplit() {
310                        this.rng = new Uniform(new MersenneTwister());
311                }
312
313                /**
314                 * Construct with the default values of 14 points per leaf (max), 128
315                 * samples for computing variance, and the 5 most varying dimensions
316                 * randomly selected. A new {@link MersenneTwister} is created as the
317                 * source for random numbers.
318                 * 
319                 * @param uniform
320                 *            the random number source
321                 */
322                public RandomisedBBFMeanSplit(Uniform uniform) {
323                        this.rng = uniform;
324                }
325
326                /**
327                 * Construct with the given values.
328                 * 
329                 * @param maxLeafSize
330                 *            Maximum number of items in a leaf.
331                 * @param varianceMaxPoints
332                 *            Maximum number of points of variance estimation; all
333                 *            points used if <=0.
334                 * @param randomMaxDims
335                 *            Number of dimensions to consider when randomly selecting
336                 *            one with a big variance.
337                 * @param uniform
338                 *            the random number source
339                 */
340                public RandomisedBBFMeanSplit(int maxLeafSize, int varianceMaxPoints, int randomMaxDims, Uniform uniform)
341                {
342                        this.rng = uniform;
343                }
344
345                @Override
346                public IntShortPair chooseSplit(final short[][] pnts, final IntArrayView inds, int depth, short[] minBounds,
347                                short[] maxBounds)
348                {
349                        if (inds.size() < maxLeafSize)
350                                return null;
351
352                        final int D = pnts[0].length;
353
354                        // Find mean & variance of each dimension.
355                        final float[] sumX = new float[D];
356                        final float[] sumXX = new float[D];
357
358                        final int count = Math.min(inds.size(), varianceMaxPoints);
359                        for (int n = 0; n < count; ++n) {
360                                for (int d = 0; d < D; ++d) {
361                                        final int i = inds.getFast(n);
362
363                                        sumX[d] += pnts[i][d];
364                                        sumXX[d] += (pnts[i][d] * pnts[i][d]);
365                                }
366                        }
367
368                        final FloatIntPair[] varPerDim = new FloatIntPair[D];
369                        for (int d = 0; d < D; ++d) {
370                                varPerDim[d] = new FloatIntPair();
371                                varPerDim[d].second = d;
372
373                                if (count <= 1)
374                                        varPerDim[d].first = 0;
375                                else
376                                        varPerDim[d].first = (sumXX[d] - ((float) 1 / count) * sumX[d] * sumX[d]) / (count - 1);
377                        }
378
379                        // Partial sort makes a BIG difference to the build time.
380                        final int nrand = Math.min(randomMaxDims, D);
381                        Sorting.partial_sort(varPerDim, 0, nrand, varPerDim.length, new BinaryPredicate() {
382                                @Override
383                                public boolean apply(Object arg0, Object arg1) {
384                                        final FloatIntPair p1 = (FloatIntPair) arg0;
385                                        final FloatIntPair p2 = (FloatIntPair) arg1;
386
387                                        if (p1.first > p2.first)
388                                                return true;
389                                        if (p2.first > p1.first)
390                                                return false;
391                                        return (p1.second > p2.second);
392                                }
393                        });
394
395                        final int randd = varPerDim[rng.nextIntFromTo(0, nrand - 1)].second;
396
397                        return new IntShortPair(randd, (short)(sumX[randd] / count));
398                }
399        }
400
401        /**
402         * An internal node of the KDTree
403         */
404        public static class KDTreeNode {
405                /**
406                 * Node to the left
407                 */
408                public KDTreeNode left;
409
410                /**
411                 * Node to the right
412                 */
413                public KDTreeNode right;
414
415                /**
416                 * Splitting value
417                 */
418                public short discriminant;
419
420                /**
421                 * Splitting dimension
422                 */
423                public int discriminantDimension;
424
425                /**
426                 * The minimum bounds of this node
427                 */
428                public short[] minBounds;
429
430                /**
431                 * The maximum bounds of this node
432                 */
433                public short[] maxBounds;
434
435                /**
436                 * The leaf only holds the indices of the original data
437                 */
438                public int[] indices;
439
440                /**
441                 * Construct a new node with the given data
442                 * 
443                 * @param pnts
444                 *            the data for the node and its children
445                 * @param inds
446                 *            a list of indices that point to the relevant parts of the
447                 *            pnts array that should be used
448                 * @param split
449                 *            the {@link SplitChooser} to use when constructing
450                 *            the tree
451                 */
452                public KDTreeNode(final short[][] pnts, IntArrayView inds, SplitChooser split) {
453                        this(pnts, inds, split, 0, null, true);
454                }
455
456                private KDTreeNode(final short[][] pnts, IntArrayView inds, SplitChooser split, int depth,
457                                KDTreeNode parent, boolean isLeft)
458                {
459                        // set the bounds of this node
460                        if (parent == null) {
461                                this.minBounds = new short[pnts[0].length];
462                                this.maxBounds = new short[pnts[0].length];
463
464                                Arrays.fill(minBounds, Short.MAX_VALUE);
465                                Arrays.fill(maxBounds, (short)(-Short.MAX_VALUE));
466
467                                for (int y = 0; y < pnts.length; y++) {
468                                        for (int x = 0; x < pnts[0].length; x++) {
469                                                if (minBounds[x] > pnts[y][x])
470                                                        minBounds[x] = pnts[y][x];
471                                                if (maxBounds[x] < pnts[y][x])
472                                                        maxBounds[x] = pnts[y][x];
473                                        }
474                                }
475                                Arrays.fill(minBounds, (short)(-Short.MAX_VALUE));
476                                Arrays.fill(maxBounds, Short.MAX_VALUE);
477                        } else {
478                                this.minBounds = parent.minBounds.clone();
479                                this.maxBounds = parent.maxBounds.clone();
480
481                                if (isLeft) {
482                                        maxBounds[parent.discriminantDimension] = parent.discriminant;
483                                } else {
484                                        minBounds[parent.discriminantDimension] = parent.discriminant;
485                                }
486                        }
487
488                        // test to see where/if we should split
489                        final IntShortPair spl = split.chooseSplit(pnts, inds, depth, minBounds, maxBounds);
490
491                        if (spl == null) {
492                                // this will be a leaf node
493                                indices = inds.toArray();
494                        } else {
495                                discriminantDimension = spl.first;
496                                discriminant = spl.second;
497
498                                // partially sort the inds so that all the data with
499                                // data[discriminantDimension] < discriminant is on one side
500                                final int N = inds.size();
501                                int l = 0;
502                                int r = N;
503                                while (l != r) {
504                                        if (pnts[inds.getFast(l)][discriminantDimension] < discriminant)
505                                                l++;
506                                        else {
507                                                r--;
508                                                final int t = inds.getFast(l);
509                                                inds.setFast(l, inds.getFast(r));
510                                                inds.setFast(r, t);
511                                        }
512                                }
513
514                                // If either partition is empty then the are vectors identical.
515                                // Choose the midpoint to keep the O(nlog(n)) performance.
516                                if (l == 0 || l == N) {
517                                        // l = N / 2;
518                                        this.discriminant = 0;
519                                        this.discriminantDimension = 0;
520                                        this.indices = inds.toArray();
521                                } else {
522                                        // create the child nodes
523                                        left = new KDTreeNode(pnts, inds.subView(0, l), split, depth + 1, this, true);
524                                        right = new KDTreeNode(pnts, inds.subView(l, N), split, depth + 1, this, false);
525                                }
526                        }
527                }
528
529                /**
530                 * Test to see if this node is a leaf node (i.e.
531                 * <code>{@link #indices} != null</code>)
532                 * 
533                 * @return true if this is a leaf node; false otherwise
534                 */
535                public boolean isLeaf() {
536                        return indices != null;
537                }
538
539                private final boolean inRange(short value, short min, short max) {
540                        return (value >= min) && (value <= max);
541                }
542
543                /**
544                 * Test whether the bounds of this node are disjoint from the
545                 * hyperrectangle described by the given bounds.
546                 * 
547                 * @param lowerExtreme
548                 *            the lower bounds of the hyperrectangle
549                 * @param upperExtreme
550                 *            the upper bounds of the hyperrectangle
551                 * @return true if disjoint; false otherwise
552                 */
553                public boolean isDisjointFrom(short[] lowerExtreme, short[] upperExtreme) {
554                        for (int i = 0; i < lowerExtreme.length; i++) {
555                                if (!(inRange(minBounds[i], lowerExtreme[i], upperExtreme[i]) || inRange(lowerExtreme[i], minBounds[i],
556                                                maxBounds[i])))
557                                        return true;
558                        }
559
560                        return false;
561                }
562
563                /**
564                 * Test whether the bounds of this node are fully contained by the
565                 * hyperrectangle described by the given bounds.
566                 * 
567                 * @param lowerExtreme
568                 *            the lower bounds of the hyperrectangle
569                 * @param upperExtreme
570                 *            the upper bounds of the hyperrectangle
571                 * @return true if fully contained; false otherwise
572                 */
573                public boolean isContainedBy(short[] lowerExtreme, short[] upperExtreme) {
574                        for (int i = 0; i < lowerExtreme.length; i++) {
575                                if (minBounds[i] < lowerExtreme[i] || maxBounds[i] > upperExtreme[i])
576                                        return false;
577                        }
578                        return true;
579                }
580        }
581
582        /** The tree roots */
583        public final KDTreeNode root;
584
585        /** The underlying data array */
586        public final short[][] data;
587
588        /**
589         * Construct with the given data and default splitting strategy ({@link BBFMedianSplit})
590         * 
591         * @param data
592         *            the data
593         */
594        public ShortKDTree(short[][] data) {
595                this.data = data;
596                this.root = new KDTreeNode(data, new IntArrayView(ArrayUtils.range(0, data.length - 1)), new BBFMedianSplit());
597        }
598
599        /**
600         * Construct with the given data and splitting strategy
601         * 
602         * @param data
603         *            the data
604         * @param split
605         *            the splitting strategy
606         */
607        public ShortKDTree(short[][] data, SplitChooser split) {
608                this.data = data;
609                this.root = new KDTreeNode(data, new IntArrayView(ArrayUtils.range(0, data.length - 1)), split);
610        }
611
612        /**
613         * Search the tree for all points contained within the hyperrectangle
614         * defined by the given upper and lower extremes.
615         * 
616         * @param lowerExtreme
617         *            the lower extreme of the hyperrectangle
618         * @param upperExtreme
619         *            the upper extreme of the hyperrectangle
620         * @return the points within the given bounds
621         */
622        public short[][] coordinateRangeSearch(short[] lowerExtreme, short[] upperExtreme) {
623                final List<short[]> results = new ArrayList<short[]>();
624
625                rangeSearch(lowerExtreme, upperExtreme, new TIntObjectProcedure<short[]>() {
626                        @Override
627                        public boolean execute(int a, short[] b) {
628                                results.add(b);
629
630                                return true;
631                        }
632                });
633
634                return results.toArray(new short[results.size()][]);
635        }
636
637        /**
638         * Search the tree for all points contained within the hyperrectangle
639         * defined by the given upper and lower extremes.
640         * 
641         * @param lowerExtreme
642         *            the lower extreme of the hyperrectangle
643         * @param upperExtreme
644         *            the upper extreme of the hyperrectangle
645         * @return the points within the given bounds
646         */
647        public int[] indexRangeSearch(short[] lowerExtreme, short[] upperExtreme) {
648                final TIntArrayList results = new TIntArrayList();
649
650                rangeSearch(lowerExtreme, upperExtreme, new TIntObjectProcedure<short[]>() {
651                        @Override
652                        public boolean execute(int a, short[] b) {
653                                results.add(a);
654
655                                return true;
656                        }
657                });
658
659                return results.toArray();
660        }
661
662        /**
663         * Search the tree for the indexes of all points contained within the
664         * hypersphere defined by the given centre and radius.
665         * 
666         * @param centre
667         *            the centre point
668         * @param radius
669         *            the radius
670         * @return the points within the given bounds
671         */
672        public int[] indexRadiusSearch(short[] centre, short radius) {
673                final TIntArrayList results = new TIntArrayList();
674
675                this.radiusSearch(centre, radius, new TIntObjectProcedure<short[]>() {
676                        @Override
677                        public boolean execute(int a, short[] b) {
678                                results.add(a);
679
680                                return true;
681                        }
682                });
683
684                return results.toArray();
685        }
686
687        /**
688         * Find all the points within the given radius of the given point.
689         * Internally this works by finding the points in the hyper-square
690         * encompassing the hyper-circle and then filtering. Each valid point that
691         * is found is reported to the given processor together with its index.
692         * <p>
693         * The search can be stopped early by returning false from the
694         * {@link TIntObjectProcedure#execute(int, Object)} method.
695         * 
696         * @param centre
697         *            the centre point
698         * @param radius
699         *            the radius
700         * @param proc
701         *            the process
702         */
703        public void radiusSearch(final short[] centre, short radius, final TIntObjectProcedure<short[]> proc)
704        {
705                final short[] lower = centre.clone();
706                final short[] upper = centre.clone();
707
708                for (int i = 0; i < centre.length; i++) {
709                        lower[i] -= radius;
710                        upper[i] += radius;
711                }
712
713                final float radSq = radius * radius;
714                rangeSearch(lower, upper, new TIntObjectProcedure<short[]>() {
715                        @Override
716                        public boolean execute(int idx, short[] point) {
717                                final float d = distance(centre, point);
718                                if (d <= radSq)
719                                        return proc.execute(idx, point);
720
721                                return true;
722                        }
723                });
724        }
725
726        /**
727         * Search the tree for all points contained within the hyperrectangle
728         * defined by the given upper and lower extremes. Each valid point that is
729         * found is reported to the given processor together with its index in the
730         * original data.
731         * <p>
732         * The search can be stopped early by returning false from the
733         * {@link TIntObjectProcedure#execute(int, Object)} method.
734         * 
735         * @param lowerExtreme
736         *            the lower extreme of the hyperrectangle
737         * @param upperExtreme
738         *            the upper extreme of the hyperrectangle
739         * @param proc
740         *            the processor
741         */
742        public void rangeSearch(short[] lowerExtreme, short[] upperExtreme, TIntObjectProcedure<short[]> proc) {
743                final Deque<KDTreeNode> stack = new ArrayDeque<KDTreeNode>();
744
745                if (root == null)
746                        return;
747
748                stack.push(root);
749
750                while (!stack.isEmpty()) {
751                        final KDTreeNode tmpNode = stack.pop();
752
753                        if (tmpNode.isLeaf()) {
754                                for (int i = 0; i < tmpNode.indices.length; i++) {
755                                        final int idx = tmpNode.indices[i];
756                                        final short[] vec = data[idx];
757                                        if (isContained(vec, lowerExtreme, upperExtreme))
758                                                if (!proc.execute(idx, vec))
759                                                        return;
760                                }
761                        } else {
762                                if (tmpNode.isDisjointFrom(lowerExtreme, upperExtreme)) {
763                                        continue;
764                                }
765
766                                if (tmpNode.isContainedBy(lowerExtreme, upperExtreme)) {
767                                        reportSubtree(tmpNode, proc);
768                                } else {
769                                        if (tmpNode.left != null)
770                                                stack.push(tmpNode.left);
771                                        if (tmpNode.right != null)
772                                                stack.push(tmpNode.right);
773                                }
774                        }
775                }
776        }
777
778        /**
779         * Determines if a point is contained within a given k-dimensional bounding
780         * box.
781         */
782        private final boolean isContained(short[] point, short[] lower, short[] upper)
783        {
784                for (int i = 0; i < point.length; i++) {
785                        if (point[i] < lower[i] || point[i] > upper[i])
786                                return false;
787                }
788
789                return true;
790        }
791
792        /**
793         * Report all the child items of the given subtree to the process
794         * 
795         * @param root
796         *            the root of the subtree
797         * @param proc
798         *            the process to apply
799         */
800        private void reportSubtree(KDTreeNode root, TIntObjectProcedure<short[]> proc) {
801                final Deque<KDTreeNode> stack = new ArrayDeque<KDTreeNode>();
802                stack.push(root);
803
804                while (!stack.isEmpty()) {
805                        final KDTreeNode tmpNode = stack.pop();
806
807                        if (tmpNode.isLeaf()) {
808                                for (int i = 0; i < tmpNode.indices.length; i++) {
809                                        final int idx = tmpNode.indices[i];
810                                        if (!proc.execute(idx, data[idx]))
811                                                return;
812                                }
813                        } else {
814                                if (tmpNode.left != null)
815                                        stack.push(tmpNode.left);
816                                if (tmpNode.right != null)
817                                        stack.push(tmpNode.right);
818                        }
819                }
820        }
821
822        /**
823         * Nearest-neighbour search
824         * 
825         * @param qu
826         *            the query point
827         * @param n
828         *            the number of neighbours to find
829         * @return the indices and distances
830         */
831        public List<IntFloatPair> nearestNeighbours(final short[] qu, int n) {
832                final BoundedPriorityQueue<IntFloatPair> queue = new BoundedPriorityQueue<IntFloatPair>(n,
833                                IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
834
835                searchSubTree(qu, root, queue);
836
837                return queue.toOrderedListDestructive();
838        }
839
840        /**
841         * Nearest-neighbour search
842         * 
843         * @param qu
844         *            the query point
845         * @return the indices and distances
846         */
847        public IntFloatPair nearestNeighbour(final short[] qu) {
848                final BoundedPriorityQueue<IntFloatPair> queue = new BoundedPriorityQueue<IntFloatPair>(1,
849                                IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
850
851                searchSubTree(qu, root, queue);
852
853                return queue.peek();
854        }
855
856        /**
857         * Find all the points within the given radius of the given point.
858         * Internally this works by finding the points in the hyper-square
859         * encompassing the hyper-circle and then filtering.
860         * 
861         * @param centre
862         *            the centre point
863         * @param radius
864         *            the radius
865         * @return the points
866         */
867        public short[][] coordinateRadiusSearch(short[] centre, short radius) {
868                final List<short[]> radiusList = new ArrayList<short[]>();
869
870                coordinateRadiusSearch(centre, radius, new TObjectFloatProcedure<short[]>() {
871                        @Override
872                        public boolean execute(short[] a, float b) {
873                                radiusList.add(a);
874                                return true;
875                        }
876                });
877
878                return radiusList.toArray(new short[radiusList.size()][]);
879        }
880
881        /**
882         * Find all the points within the given radius of the given point.
883         * Internally this works by finding the points in the hyper-square
884         * encompassing the hyper-circle and then filtering. Each valid point that
885         * is found is reported to the given processor together with its distance
886         * from the centre.
887         * <p>
888         * The search can be stopped early by returning false from the
889         * {@link TIntObjectProcedure#execute(int, Object)} method.
890         * 
891         * @param centre
892         *            the centre point
893         * @param radius
894         *            the radius
895         * @param proc
896         *            the process
897         */
898        public void coordinateRadiusSearch(final short[] centre, short radius, final TObjectFloatProcedure<short[]> proc)
899        {
900                final short[] lower = centre.clone();
901                final short[] upper = centre.clone();
902
903                for (int i = 0; i < centre.length; i++) {
904                        lower[i] -= radius;
905                        upper[i] += radius;
906                }
907
908                final float radSq = radius * radius;
909                rangeSearch(lower, upper, new TIntObjectProcedure<short[]>() {
910                        @Override
911                        public boolean execute(int idx, short[] point) {
912                                final float d = distance(centre, point);
913                                if (d <= radSq)
914                                        return proc.execute(point, d);
915
916                                return true;
917                        }
918                });
919        }
920
921        private void searchSubTree(final short[] qu, KDTreeNode cur, BoundedPriorityQueue<IntFloatPair> queue) {
922                final Deque<KDTreeNode> stack = new ArrayDeque<KDTreeNode>();
923                while (!cur.isLeaf()) {
924                        stack.push(cur);
925
926                        final float diff = qu[cur.discriminantDimension] - cur.discriminant;
927
928                        if (diff < 0) {
929                                cur = cur.left;
930                        } else {
931                                cur = cur.right;
932                        }
933                }
934
935                for (int i = 0; i < cur.indices.length; i++) {
936                        final int idx = cur.indices[i];
937                        final short[] vec = data[idx];
938                        final float dist = distance(qu, vec);
939                        queue.add(new IntFloatPair(idx, dist));
940                }
941
942                while (!stack.isEmpty()) {
943                        cur = stack.pop();
944                        final float diff = qu[cur.discriminantDimension] - cur.discriminant;
945
946                        final float worstDist = queue.peekTail().second;
947
948                        if (diff * diff <= worstDist || !queue.isFull()) {
949                                // need to search subtree
950                                if (diff < 0) {
951                                        searchSubTree(qu, cur.right, queue);
952                                } else {
953                                        searchSubTree(qu, cur.left, queue);
954                                }
955                        }
956                }
957        }
958
959        private float distance(short[] qu, short[] vec) {
960                float d = 0;
961                for (int i = 0; i < qu.length; i++)
962                        d += (qu[i] - vec[i]) * (qu[i] - vec[i]);
963                return d;
964        }
965        
966        /**
967         * Find all the indices seperated by leaves
968         * @return all the leaves
969         */
970        public List<int[]> leafIndices() {
971                List<int[]> leafInds = new ArrayList<int[]>();
972                Deque<KDTreeNode> nodes = new ArrayDeque<KDTreeNode>();
973                nodes.push(root);
974                while(nodes.size()!=0){
975                        KDTreeNode node = nodes.pop();
976                        if(node.isLeaf()){
977                                leafInds.add(node.indices);
978                        }
979                        else{
980                                nodes.push(node.left);
981                                nodes.push(node.right);
982                        }
983                }
984                
985                return leafInds;
986        }
987}