001/*
002        AUTOMATICALLY GENERATED BY jTemp FROM
003        /Users/jsh2/Work/openimaj/target/checkout/core/core/src/main/jtemp/org/openimaj/util/tree/Incremental#T#KDTree.jtemp
004*/
005/**
006 * Copyright (c) 2011, The University of Southampton and the individual contributors.
007 * All rights reserved.
008 *
009 * Redistribution and use in source and binary forms, with or without modification,
010 * are permitted provided that the following conditions are met:
011 *
012 *   *  Redistributions of source code must retain the above copyright notice,
013 *      this list of conditions and the following disclaimer.
014 *
015 *   *  Redistributions in binary form must reproduce the above copyright notice,
016 *      this list of conditions and the following disclaimer in the documentation
017 *      and/or other materials provided with the distribution.
018 *
019 *   *  Neither the name of the University of Southampton nor the names of its
020 *      contributors may be used to endorse or promote products derived from this
021 *      software without specific prior written permission.
022 *
023 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
024 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
025 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
026 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
027 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
028 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
029 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
030 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
031 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
032 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
033 */
034package org.openimaj.util.tree;
035
036import java.util.ArrayList;
037import java.util.Collection;
038import java.util.List;
039import java.util.PriorityQueue;
040import java.util.Stack;
041
042import org.openimaj.util.pair.ObjectDoublePair;
043import org.openimaj.util.queue.BoundedPriorityQueue;
044
045/**
046 * Implementation of a simple incremental KDTree for <code>short[]</code>s. Includes
047 * support for range search, neighbour search, and radius search. The tree created 
048 * by this class will usually be rather unbalanced.
049 * <p>
050 * The KDTree allows fast search for points in relatively low-dimension spaces.
051 * 
052 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
053 */
054public class IncrementalShortKDTree {
055    private static class KDNode {
056        int discriminateDim;
057        short[] point;
058        KDNode left, right;
059
060        KDNode(short[] point, int discriminate) {
061                this.point = point;
062                this.left = right = null;
063                this.discriminateDim = discriminate;
064        }
065    }
066    
067        KDNode _root;
068
069        /**
070         * Create an empty KDTree object
071         */
072        public IncrementalShortKDTree() {
073                _root = null;
074        }
075
076        /**
077         * Create a KDTree object and populate it with the given data.
078         * 
079         * @param coords
080         *            the data to populate the index with.
081         */
082        public IncrementalShortKDTree(Collection<short[]> coords) {
083                _root = null;
084                insertAll(coords);
085        }
086        
087        /**
088         * Create a KDTree object and populate it with the given data.
089         * 
090         * @param coords
091         *            the data to populate the index with.
092         */
093        public IncrementalShortKDTree(short[][] coords) {
094                _root = null;
095                insertAll(coords);
096        }
097
098        /**
099         * Insert all the points from the given collection into the index.
100         * 
101         * @param coords
102         *            The points to add.
103         */
104        public void insertAll(Collection<short[]> coords) {
105                for (final short[] c : coords)
106                        insert(c);
107        }
108        
109        /**
110         * Insert all the points from the given collection into the index.
111         * 
112         * @param coords
113         *            The points to add.
114         */
115        public void insertAll(short[][] coords) {
116                for (final short[] c : coords)
117                        insert(c);
118        }
119
120        /**
121         * Inserts a point into the tree, preserving the spatial ordering.
122         * 
123         * @param point
124         *            Point to insert.
125         */
126        public void insert(short[] point) {
127
128                if (_root == null)
129                        _root = new KDNode(point, 0);
130                else {
131                        int discriminate;
132                        KDNode curNode, tmpNode;
133                        double ordinate1, ordinate2;
134
135                        curNode = _root;
136
137                        do {
138                                tmpNode = curNode;
139                                discriminate = tmpNode.discriminateDim;
140
141                                ordinate1 = point[discriminate];
142                                ordinate2 = tmpNode.point[discriminate];
143
144                                if (ordinate1 > ordinate2)
145                                        curNode = tmpNode.right;
146                                else
147                                        curNode = tmpNode.left;
148                        } while (curNode != null);
149
150                        if (++discriminate >= point.length)
151                                discriminate = 0;
152
153                        if (ordinate1 > ordinate2)
154                                tmpNode.right = new KDNode(point, discriminate);
155                        else
156                                tmpNode.left = new KDNode(point, discriminate);
157                }
158        }
159
160        /**
161         * Determines if a point is contained within a given k-dimensional bounding
162         * box.
163         */
164        static final boolean isContained(
165                        short[] point, short[] lower, short[] upper)
166        {
167                double ordinate1, ordinate2, ordinate3;
168
169                for (int i = 0; i < point.length; i++) {
170                        ordinate1 = point[i];
171                        ordinate2 = lower[i];
172                        ordinate3 = upper[i];
173
174                        if (ordinate1 < ordinate2 || ordinate1 > ordinate3)
175                                return false;
176                }
177
178                return true;
179        }
180
181        /**
182         * Searches the tree for all points contained within the bounding box
183         * defined by the given upper and lower extremes
184         * 
185         * @param lowerExtreme
186         * @param upperExtreme
187         * @return the points within the given bounds
188         */
189        public List<short[]> rangeSearch(short[] lowerExtreme, short[] upperExtreme) {
190                final ArrayList<short[]> results = new ArrayList<short[]>(1000);
191                final Stack<KDNode> stack = new Stack<KDNode>();
192                KDNode tmpNode;
193                int discriminate;
194                double ordinate1, ordinate2;
195
196                if (_root == null)
197                        return results;
198
199                stack.push(_root);
200
201                while (!stack.empty()) {
202                        tmpNode = stack.pop();
203                        discriminate = tmpNode.discriminateDim;
204
205                        ordinate1 = tmpNode.point[discriminate];
206                        ordinate2 = lowerExtreme[discriminate];
207
208                        if (ordinate1 >= ordinate2 && tmpNode.left != null)
209                                stack.push(tmpNode.left);
210
211                        ordinate2 = upperExtreme[discriminate];
212
213                        if (ordinate1 <= ordinate2 && tmpNode.right != null)
214                                stack.push(tmpNode.right);
215
216                        if (isContained(tmpNode.point, lowerExtreme, upperExtreme))
217                                results.add(tmpNode.point);
218                }
219
220                return results;
221        }
222
223        protected static final double distance(short[] a, short[] b) {
224                double s = 0;
225
226                for (int i = 0; i < a.length; i++) {
227                        final double fa = a[i];
228                        final double fb = b[i];
229                        s += (fa - fb) * (fa - fb);
230                }
231                return s;
232        }
233
234        /**
235         * Find the nearest neighbour. Only one neighbour will be returned - if
236         * multiple neighbours share the same location, or are equidistant, then
237         * this might not be the one you expect.
238         * 
239         * @param query
240         *            query coordinate
241         * @return nearest neighbour
242         */
243        public ObjectDoublePair<short[]> findNearestNeighbour(short[] query) {
244                final Stack<KDNode> stack = walkdown(query);
245                final ObjectDoublePair<short[]> state = new ObjectDoublePair<short[]>();
246                state.first = stack.peek().point;
247                state.second = distance(query, state.first);
248
249                if (state.second == 0)
250                        return state;
251
252                while (!stack.isEmpty()) {
253                        final KDNode current = stack.pop();
254
255                        checkSubtree(current, query, state);
256                }
257
258                return state;
259        }
260
261        /**
262         * Find the K nearest neighbours.
263         * 
264         * @param query
265         *            query coordinate
266         * @param k
267         *            the number of neighbours to find
268         * @return nearest neighbours
269         */
270        public List<ObjectDoublePair<short[]>> findNearestNeighbours(short[] query, int k) {
271                final Stack<KDNode> stack = walkdown(query);
272                final BoundedPriorityQueue<ObjectDoublePair<short[]>> state = new BoundedPriorityQueue<ObjectDoublePair<short[]>>(
273                                k, ObjectDoublePair.SECOND_ITEM_ASCENDING_COMPARATOR);
274
275                final ObjectDoublePair<short[]> initialState = new ObjectDoublePair<short[]>();
276                initialState.first = stack.peek().point;
277                initialState.second = distance(query, initialState.first);
278                state.add(initialState);
279
280                while (!stack.isEmpty()) {
281                        final KDNode current = stack.pop();
282
283                        checkSubtreeK(current, query, state, k);
284                }
285
286                return state.toOrderedListDestructive();
287        }
288
289        /*
290         * Check a subtree for a closer match
291         */
292        private void checkSubtree(KDNode node, short[] query, ObjectDoublePair<short[]> state) {
293                if (node == null)
294                        return;
295
296                final double dist = distance(query, node.point);
297                if (dist < state.second) {
298                        state.first = node.point;
299                        state.second = dist;
300                }
301
302                if (state.second == 0)
303                        return;
304
305                final double d = node.point[node.discriminateDim] - query[node.discriminateDim];
306                if (d * d > state.second) {
307                        // check subtree
308                        final double ordinate1 = query[node.discriminateDim];
309                        final double ordinate2 = node.point[node.discriminateDim];
310
311                        if (ordinate1 > ordinate2)
312                                checkSubtree(node.right, query, state);
313                        else
314                                checkSubtree(node.left, query, state);
315                } else {
316                        checkSubtree(node.left, query, state);
317                        checkSubtree(node.right, query, state);
318                }
319        }
320
321        private void checkSubtreeK(KDNode node, short[] query, PriorityQueue<ObjectDoublePair<short[]>> state, int k) {
322                if (node == null)
323                        return;
324
325                final double dist = distance(query, node.point);
326
327                boolean cont = false;
328                for (final ObjectDoublePair<short[]> s : state)
329                        if (s.first.equals(node.point)) {
330                                cont = true;
331                                break;
332                        }
333
334                if (!cont) {
335                        if (state.size() < k) {
336                                // collect this node
337                                final ObjectDoublePair<short[]> s = new ObjectDoublePair<short[]>();
338                                s.first = node.point;
339                                s.second = dist;
340                                state.add(s);
341                        } else if (dist < state.peek().second) {
342                                // replace last node
343                                final ObjectDoublePair<short[]> s = state.poll();
344                                s.first = node.point;
345                                s.second = dist;
346                                state.add(s);
347                        }
348                }
349
350                final double d = node.point[node.discriminateDim] - query[node.discriminateDim];
351                if (d * d > state.peek().second) {
352                        // check subtree
353                        final double ordinate1 = query[node.discriminateDim];
354                        final double ordinate2 = node.point[node.discriminateDim];
355
356                        if (ordinate1 > ordinate2)
357                                checkSubtreeK(node.right, query, state, k);
358                        else
359                                checkSubtreeK(node.left, query, state, k);
360                } else {
361                        checkSubtreeK(node.left, query, state, k);
362                        checkSubtreeK(node.right, query, state, k);
363                }
364        }
365
366        /*
367         * walk down the tree until we hit a leaf, and return the path taken
368         */
369        private Stack<KDNode> walkdown(short[] point) {
370                if (_root == null)
371                        return null;
372                else {
373                        final Stack<KDNode> stack = new Stack<KDNode>();
374                        int discriminate;
375                        KDNode curNode, tmpNode;
376                        double ordinate1, ordinate2;
377
378                        curNode = _root;
379
380                        do {
381                                tmpNode = curNode;
382                                stack.push(tmpNode);
383                                if (tmpNode.point == point)
384                                        return stack;
385                                discriminate = tmpNode.discriminateDim;
386
387                                ordinate1 = point[discriminate];
388                                ordinate2 = tmpNode.point[discriminate];
389
390                                if (ordinate1 > ordinate2)
391                                        curNode = tmpNode.right;
392                                else
393                                        curNode = tmpNode.left;
394                        } while (curNode != null);
395
396                        if (++discriminate >= point.length)
397                                discriminate = 0;
398
399                        return stack;
400                }
401        }
402        
403        /**
404         * Find all the points within the given radius of the given point
405         * 
406         * @param centre
407         *            the centre point
408         * @param radius
409         *            the radius
410         * @return the points
411         */
412        public List<short[]> radiusSearch(short[] centre, short radius) {
413                final short[] lower = centre.clone();
414                final short[] upper = centre.clone();
415
416                for (int i = 0; i < centre.length; i++) {
417                        lower[i] -= radius;
418                        upper[i] += radius;
419                }
420
421                final List<short[]> rangeList = rangeSearch(lower, upper);
422                final List<short[]> radiusList = new ArrayList<short[]>(rangeList.size());
423                final double radSq = radius * radius;
424                for (final short[] r : rangeList) {
425                        if (distance(centre, r) < radSq)
426                                radiusList.add(r);
427                }
428
429                return radiusList;
430        }
431        
432        /**
433         * Find all the points within the given radius of the given point. 
434         * Returns the distance to the point as well as the point itself. Distance
435         * is the squared L2 distance.
436         * 
437         * @param centre
438         *            the centre point
439         * @param radius
440         *            the radius
441         * @return the points and distances
442         */
443        public List<ObjectDoublePair<short[]>> radiusDistanceSearch(short[] centre, short radius) {
444                final short[] lower = centre.clone();
445                final short[] upper = centre.clone();
446
447                for (int i = 0; i < centre.length; i++) {
448                        lower[i] -= radius;
449                        upper[i] += radius;
450                }
451
452                final List<short[]> rangeList = rangeSearch(lower, upper);
453                final List<ObjectDoublePair<short[]>> radiusList = new ArrayList<ObjectDoublePair<short[]>>(rangeList.size());
454                final double radSq = radius * radius;
455                for (final short[] r : rangeList) {
456                        double dist = distance(centre, r);
457                        if (dist < radSq)
458                                radiusList.add(new ObjectDoublePair<short[]>(r, dist));
459                }
460
461                return radiusList;
462        }
463}