001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.knn;
031
032import java.io.DataInput;
033import java.io.DataOutput;
034import java.io.IOException;
035import java.io.PrintWriter;
036import java.lang.reflect.Array;
037import java.util.ArrayList;
038import java.util.Arrays;
039import java.util.Collection;
040import java.util.List;
041import java.util.PriorityQueue;
042import java.util.Scanner;
043import java.util.Stack;
044
045import org.openimaj.math.geometry.point.Coordinate;
046
047class KDNode<T extends Coordinate> {
048        int _discriminate;
049        T _point;
050        KDNode<T> _left, _right;
051
052        KDNode(T point, int discriminate) {
053                _point = point;
054                _left = _right = null;
055                _discriminate = discriminate;
056        }
057}
058
059/**
060 * Implementation of a simple KDTree with range search. The KDTree allows fast
061 * search for points in relatively low-dimension spaces.
062 *
063 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
064 *
065 * @param <T>
066 *            the type of Coordinate.
067 */
068public class CoordinateKDTree<T extends Coordinate> implements CoordinateIndex<T> {
069        KDNode<T> _root;
070
071        /**
072         * Create an empty KDTree object
073         */
074        public CoordinateKDTree() {
075                _root = null;
076        }
077
078        /**
079         * Create a KDTree object and populate it with the given data.
080         * 
081         * @param coords
082         *            the data to populate the index with.
083         */
084        public CoordinateKDTree(Collection<T> coords) {
085                _root = null;
086                insertAll(coords);
087        }
088
089        /**
090         * Insert all the points from the given collection into the index.
091         * 
092         * @param coords
093         *            The points to add.
094         */
095        public void insertAll(Collection<T> coords) {
096                for (final T c : coords)
097                        insert(c);
098        }
099
100        /**
101         * Inserts a point into the tree, preserving the spatial ordering.
102         * 
103         * @param point
104         *            Point to insert.
105         */
106        @Override
107        public void insert(T point) {
108
109                if (_root == null)
110                        _root = new KDNode<T>(point, 0);
111                else {
112                        int discriminate, dimensions;
113                        KDNode<T> curNode, tmpNode;
114                        double ordinate1, ordinate2;
115
116                        curNode = _root;
117
118                        do {
119                                tmpNode = curNode;
120                                discriminate = tmpNode._discriminate;
121
122                                ordinate1 = point.getOrdinate(discriminate).doubleValue();
123                                ordinate2 = tmpNode._point.getOrdinate(discriminate).doubleValue();
124
125                                if (ordinate1 > ordinate2)
126                                        curNode = tmpNode._right;
127                                else
128                                        curNode = tmpNode._left;
129                        } while (curNode != null);
130
131                        dimensions = point.getDimensions();
132
133                        if (++discriminate >= dimensions)
134                                discriminate = 0;
135
136                        if (ordinate1 > ordinate2)
137                                tmpNode._right = new KDNode<T>(point, discriminate);
138                        else
139                                tmpNode._left = new KDNode<T>(point, discriminate);
140                }
141        }
142
143        /**
144         * Determines if a point is contained within a given k-dimensional bounding
145         * box.
146         */
147        static final boolean isContained(
148                        Coordinate point, Coordinate lower, Coordinate upper)
149        {
150                int dimensions;
151                double ordinate1, ordinate2, ordinate3;
152
153                dimensions = point.getDimensions();
154
155                for (int i = 0; i < dimensions; ++i) {
156                        ordinate1 = point.getOrdinate(i).doubleValue();
157                        ordinate2 = lower.getOrdinate(i).doubleValue();
158                        ordinate3 = upper.getOrdinate(i).doubleValue();
159
160                        if (ordinate1 < ordinate2 || ordinate1 > ordinate3)
161                                return false;
162                }
163
164                return true;
165        }
166
167        /**
168         * Searches the tree for all points contained within a given k-dimensional
169         * bounding box and stores them in a Collection.
170         * 
171         * @param results
172         * @param lowerExtreme
173         * @param upperExtreme
174         */
175        @Override
176        public void rangeSearch(Collection<T> results, Coordinate lowerExtreme, Coordinate upperExtreme)
177        {
178                KDNode<T> tmpNode;
179                final Stack<KDNode<T>> stack = new Stack<KDNode<T>>();
180                int discriminate;
181                double ordinate1, ordinate2;
182
183                if (_root == null)
184                        return;
185
186                stack.push(_root);
187
188                while (!stack.empty()) {
189                        tmpNode = stack.pop();
190                        discriminate = tmpNode._discriminate;
191
192                        ordinate1 = tmpNode._point.getOrdinate(discriminate).doubleValue();
193                        ordinate2 = lowerExtreme.getOrdinate(discriminate).doubleValue();
194
195                        if (ordinate1 > ordinate2 && tmpNode._left != null)
196                                stack.push(tmpNode._left);
197
198                        ordinate2 = upperExtreme.getOrdinate(discriminate).doubleValue();
199
200                        if (ordinate1 < ordinate2 && tmpNode._right != null)
201                                stack.push(tmpNode._right);
202
203                        if (isContained(tmpNode._point, lowerExtreme, upperExtreme))
204                                results.add(tmpNode._point);
205                }
206        }
207
208        protected static final float distance(Coordinate a, Coordinate b) {
209                float s = 0;
210
211                for (int i = 0; i < a.getDimensions(); i++) {
212                        final float fa = a.getOrdinate(i).floatValue();
213                        final float fb = b.getOrdinate(i).floatValue();
214                        s += (fa - fb) * (fa - fb);
215                }
216                return s;
217        }
218
219        class NNState implements Comparable<NNState> {
220                T best;
221                float bestDist;
222
223                @Override
224                public int compareTo(NNState o) {
225                        if (bestDist < o.bestDist)
226                                return 1;
227                        if (bestDist > o.bestDist)
228                                return -1;
229                        return 0;
230                }
231
232                @Override
233                public String toString() {
234                        return bestDist + "";
235                }
236        }
237
238        /**
239         * Find the nearest neighbour. Only one neighbour will be returned - if
240         * multiple neighbours share the same location, or are equidistant, then
241         * this might not be the one you expect.
242         * 
243         * @param query
244         *            query coordinate
245         * @return nearest neighbour
246         */
247        @Override
248        public T nearestNeighbour(Coordinate query) {
249                final Stack<KDNode<T>> stack = walkdown(query);
250                final NNState state = new NNState();
251                state.best = stack.peek()._point;
252                state.bestDist = distance(query, state.best);
253
254                if (state.bestDist == 0)
255                        return state.best;
256
257                while (!stack.isEmpty()) {
258                        final KDNode<T> current = stack.pop();
259
260                        checkSubtree(current, query, state);
261                }
262
263                return state.best;
264        }
265
266        @Override
267        public void kNearestNeighbour(Collection<T> result, Coordinate query, int k) {
268                final Stack<KDNode<T>> stack = walkdown(query);
269                final PriorityQueue<NNState> state = new PriorityQueue<NNState>(k);
270
271                final NNState initialState = new NNState();
272                initialState.best = stack.peek()._point;
273                initialState.bestDist = distance(query, initialState.best);
274                state.add(initialState);
275
276                while (!stack.isEmpty()) {
277                        final KDNode<T> current = stack.pop();
278
279                        checkSubtreeK(current, query, state, k);
280                }
281
282                @SuppressWarnings("unchecked")
283                final NNState[] stateList = state.toArray((NNState[]) Array.newInstance(NNState.class, state.size()));
284                Arrays.sort(stateList);
285                for (int i = stateList.length - 1; i >= 0; i--)
286                        result.add(stateList[i].best);
287        }
288
289        /*
290         * Check a subtree for a closer match
291         */
292        private void checkSubtree(KDNode<T> node, Coordinate query, NNState state) {
293                if (node == null)
294                        return;
295
296                final float dist = distance(query, node._point);
297                if (dist < state.bestDist) {
298                        state.best = node._point;
299                        state.bestDist = dist;
300                }
301
302                if (state.bestDist == 0)
303                        return;
304
305                final float d = node._point.getOrdinate(node._discriminate).floatValue() -
306                                query.getOrdinate(node._discriminate).floatValue();
307                if (d * d > state.bestDist) {
308                        // check subtree
309                        final double ordinate1 = query.getOrdinate(node._discriminate).doubleValue();
310                        final double ordinate2 = node._point.getOrdinate(node._discriminate).doubleValue();
311
312                        if (ordinate1 > ordinate2)
313                                checkSubtree(node._right, query, state);
314                        else
315                                checkSubtree(node._left, query, state);
316                } else {
317                        checkSubtree(node._left, query, state);
318                        checkSubtree(node._right, query, state);
319                }
320        }
321
322        private void checkSubtreeK(KDNode<T> node, Coordinate query, PriorityQueue<NNState> state, int k) {
323                if (node == null)
324                        return;
325
326                final float dist = distance(query, node._point);
327
328                boolean cont = false;
329                for (final NNState s : state)
330                        if (s.best.equals(node._point)) {
331                                cont = true;
332                                break;
333                        }
334
335                if (!cont) {
336                        if (state.size() < k) {
337                                // collect this node
338                                final NNState s = new NNState();
339                                s.best = node._point;
340                                s.bestDist = dist;
341                                state.add(s);
342                        } else if (dist < state.peek().bestDist) {
343                                // replace last node
344                                final NNState s = state.poll();
345                                s.best = node._point;
346                                s.bestDist = dist;
347                                state.add(s);
348                        }
349                }
350
351                final float d = node._point.getOrdinate(node._discriminate).floatValue() -
352                                query.getOrdinate(node._discriminate).floatValue();
353                if (d * d > state.peek().bestDist) {
354                        // check subtree
355                        final double ordinate1 = query.getOrdinate(node._discriminate).doubleValue();
356                        final double ordinate2 = node._point.getOrdinate(node._discriminate).doubleValue();
357
358                        if (ordinate1 > ordinate2)
359                                checkSubtreeK(node._right, query, state, k);
360                        else
361                                checkSubtreeK(node._left, query, state, k);
362                } else {
363                        checkSubtreeK(node._left, query, state, k);
364                        checkSubtreeK(node._right, query, state, k);
365                }
366        }
367
368        /*
369         * walk down the tree until we hit a leaf, and return the path taken
370         */
371        private Stack<KDNode<T>> walkdown(Coordinate point) {
372                if (_root == null)
373                        return null;
374                else {
375                        final Stack<KDNode<T>> stack = new Stack<KDNode<T>>();
376                        int discriminate, dimensions;
377                        KDNode<T> curNode, tmpNode;
378                        double ordinate1, ordinate2;
379
380                        curNode = _root;
381
382                        do {
383                                tmpNode = curNode;
384                                stack.push(tmpNode);
385                                if (tmpNode._point == point)
386                                        return stack;
387                                discriminate = tmpNode._discriminate;
388
389                                ordinate1 = point.getOrdinate(discriminate).doubleValue();
390                                ordinate2 = tmpNode._point.getOrdinate(discriminate).doubleValue();
391
392                                if (ordinate1 > ordinate2)
393                                        curNode = tmpNode._right;
394                                else
395                                        curNode = tmpNode._left;
396                        } while (curNode != null);
397
398                        dimensions = point.getDimensions();
399
400                        if (++discriminate >= dimensions)
401                                discriminate = 0;
402
403                        return stack;
404                }
405        }
406
407        class Coord implements Coordinate {
408                float[] coords;
409
410                public Coord(int i) {
411                        coords = new float[i];
412                }
413
414                public Coord(Coordinate c) {
415                        this(c.getDimensions());
416                        for (int i = 0; i < coords.length; i++)
417                                coords[i] = c.getOrdinate(i).floatValue();
418                }
419
420                @Override
421                public int getDimensions() {
422                        return coords.length;
423                }
424
425                @Override
426                public Float getOrdinate(int dimension) {
427                        return coords[dimension];
428                }
429
430                @Override
431                public void readASCII(Scanner in) throws IOException {
432                        throw new RuntimeException("not implemented");
433                }
434
435                @Override
436                public String asciiHeader() {
437                        throw new RuntimeException("not implemented");
438                }
439
440                @Override
441                public void readBinary(DataInput in) throws IOException {
442                        throw new RuntimeException("not implemented");
443                }
444
445                @Override
446                public byte[] binaryHeader() {
447                        throw new RuntimeException("not implemented");
448                }
449
450                @Override
451                public void writeASCII(PrintWriter out) throws IOException {
452                        throw new RuntimeException("not implemented");
453                }
454
455                @Override
456                public void writeBinary(DataOutput out) throws IOException {
457                        throw new RuntimeException("not implemented");
458                }
459
460                @Override
461                public void setOrdinate(int dimension, Number value) {
462                        coords[dimension] = value.floatValue();
463                }
464        }
465
466        /**
467         * Faster implementation of K-nearest-neighbours.
468         *
469         * @param result
470         *            Collection to hold the found coordinates.
471         * @param query
472         *            The query coordinate.
473         * @param k
474         *            The number of neighbours to find.
475         */
476        public void fastKNN(Collection<T> result, Coordinate query, int k) {
477                final List<T> tmp = new ArrayList<T>();
478                final Coord lowerExtreme = new Coord(query);
479                final Coord upperExtreme = new Coord(query);
480
481                while (tmp.size() < k) {
482                        tmp.clear();
483                        for (int i = 0; i < lowerExtreme.getDimensions(); i++)
484                                lowerExtreme.coords[i] -= k;
485                        for (int i = 0; i < upperExtreme.getDimensions(); i++)
486                                upperExtreme.coords[i] += k;
487                        rangeSearch(tmp, lowerExtreme, upperExtreme);
488                }
489
490                final CoordinateBruteForce<T> bf = new CoordinateBruteForce<T>(tmp);
491                bf.kNearestNeighbour(result, query, k);
492        }
493}