001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.knn; 031 032import java.io.DataInput; 033import java.io.DataOutput; 034import java.io.IOException; 035import java.io.PrintWriter; 036import java.lang.reflect.Array; 037import java.util.ArrayList; 038import java.util.Arrays; 039import java.util.Collection; 040import java.util.List; 041import java.util.PriorityQueue; 042import java.util.Scanner; 043import java.util.Stack; 044 045import org.openimaj.math.geometry.point.Coordinate; 046 047class KDNode<T extends Coordinate> { 048 int _discriminate; 049 T _point; 050 KDNode<T> _left, _right; 051 052 KDNode(T point, int discriminate) { 053 _point = point; 054 _left = _right = null; 055 _discriminate = discriminate; 056 } 057} 058 059/** 060 * Implementation of a simple KDTree with range search. The KDTree allows fast 061 * search for points in relatively low-dimension spaces. 062 * 063 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 064 * 065 * @param <T> 066 * the type of Coordinate. 067 */ 068public class CoordinateKDTree<T extends Coordinate> implements CoordinateIndex<T> { 069 KDNode<T> _root; 070 071 /** 072 * Create an empty KDTree object 073 */ 074 public CoordinateKDTree() { 075 _root = null; 076 } 077 078 /** 079 * Create a KDTree object and populate it with the given data. 080 * 081 * @param coords 082 * the data to populate the index with. 083 */ 084 public CoordinateKDTree(Collection<T> coords) { 085 _root = null; 086 insertAll(coords); 087 } 088 089 /** 090 * Insert all the points from the given collection into the index. 091 * 092 * @param coords 093 * The points to add. 094 */ 095 public void insertAll(Collection<T> coords) { 096 for (final T c : coords) 097 insert(c); 098 } 099 100 /** 101 * Inserts a point into the tree, preserving the spatial ordering. 102 * 103 * @param point 104 * Point to insert. 105 */ 106 @Override 107 public void insert(T point) { 108 109 if (_root == null) 110 _root = new KDNode<T>(point, 0); 111 else { 112 int discriminate, dimensions; 113 KDNode<T> curNode, tmpNode; 114 double ordinate1, ordinate2; 115 116 curNode = _root; 117 118 do { 119 tmpNode = curNode; 120 discriminate = tmpNode._discriminate; 121 122 ordinate1 = point.getOrdinate(discriminate).doubleValue(); 123 ordinate2 = tmpNode._point.getOrdinate(discriminate).doubleValue(); 124 125 if (ordinate1 > ordinate2) 126 curNode = tmpNode._right; 127 else 128 curNode = tmpNode._left; 129 } while (curNode != null); 130 131 dimensions = point.getDimensions(); 132 133 if (++discriminate >= dimensions) 134 discriminate = 0; 135 136 if (ordinate1 > ordinate2) 137 tmpNode._right = new KDNode<T>(point, discriminate); 138 else 139 tmpNode._left = new KDNode<T>(point, discriminate); 140 } 141 } 142 143 /** 144 * Determines if a point is contained within a given k-dimensional bounding 145 * box. 146 */ 147 static final boolean isContained( 148 Coordinate point, Coordinate lower, Coordinate upper) 149 { 150 int dimensions; 151 double ordinate1, ordinate2, ordinate3; 152 153 dimensions = point.getDimensions(); 154 155 for (int i = 0; i < dimensions; ++i) { 156 ordinate1 = point.getOrdinate(i).doubleValue(); 157 ordinate2 = lower.getOrdinate(i).doubleValue(); 158 ordinate3 = upper.getOrdinate(i).doubleValue(); 159 160 if (ordinate1 < ordinate2 || ordinate1 > ordinate3) 161 return false; 162 } 163 164 return true; 165 } 166 167 /** 168 * Searches the tree for all points contained within a given k-dimensional 169 * bounding box and stores them in a Collection. 170 * 171 * @param results 172 * @param lowerExtreme 173 * @param upperExtreme 174 */ 175 @Override 176 public void rangeSearch(Collection<T> results, Coordinate lowerExtreme, Coordinate upperExtreme) 177 { 178 KDNode<T> tmpNode; 179 final Stack<KDNode<T>> stack = new Stack<KDNode<T>>(); 180 int discriminate; 181 double ordinate1, ordinate2; 182 183 if (_root == null) 184 return; 185 186 stack.push(_root); 187 188 while (!stack.empty()) { 189 tmpNode = stack.pop(); 190 discriminate = tmpNode._discriminate; 191 192 ordinate1 = tmpNode._point.getOrdinate(discriminate).doubleValue(); 193 ordinate2 = lowerExtreme.getOrdinate(discriminate).doubleValue(); 194 195 if (ordinate1 > ordinate2 && tmpNode._left != null) 196 stack.push(tmpNode._left); 197 198 ordinate2 = upperExtreme.getOrdinate(discriminate).doubleValue(); 199 200 if (ordinate1 < ordinate2 && tmpNode._right != null) 201 stack.push(tmpNode._right); 202 203 if (isContained(tmpNode._point, lowerExtreme, upperExtreme)) 204 results.add(tmpNode._point); 205 } 206 } 207 208 protected static final float distance(Coordinate a, Coordinate b) { 209 float s = 0; 210 211 for (int i = 0; i < a.getDimensions(); i++) { 212 final float fa = a.getOrdinate(i).floatValue(); 213 final float fb = b.getOrdinate(i).floatValue(); 214 s += (fa - fb) * (fa - fb); 215 } 216 return s; 217 } 218 219 class NNState implements Comparable<NNState> { 220 T best; 221 float bestDist; 222 223 @Override 224 public int compareTo(NNState o) { 225 if (bestDist < o.bestDist) 226 return 1; 227 if (bestDist > o.bestDist) 228 return -1; 229 return 0; 230 } 231 232 @Override 233 public String toString() { 234 return bestDist + ""; 235 } 236 } 237 238 /** 239 * Find the nearest neighbour. Only one neighbour will be returned - if 240 * multiple neighbours share the same location, or are equidistant, then 241 * this might not be the one you expect. 242 * 243 * @param query 244 * query coordinate 245 * @return nearest neighbour 246 */ 247 @Override 248 public T nearestNeighbour(Coordinate query) { 249 final Stack<KDNode<T>> stack = walkdown(query); 250 final NNState state = new NNState(); 251 state.best = stack.peek()._point; 252 state.bestDist = distance(query, state.best); 253 254 if (state.bestDist == 0) 255 return state.best; 256 257 while (!stack.isEmpty()) { 258 final KDNode<T> current = stack.pop(); 259 260 checkSubtree(current, query, state); 261 } 262 263 return state.best; 264 } 265 266 @Override 267 public void kNearestNeighbour(Collection<T> result, Coordinate query, int k) { 268 final Stack<KDNode<T>> stack = walkdown(query); 269 final PriorityQueue<NNState> state = new PriorityQueue<NNState>(k); 270 271 final NNState initialState = new NNState(); 272 initialState.best = stack.peek()._point; 273 initialState.bestDist = distance(query, initialState.best); 274 state.add(initialState); 275 276 while (!stack.isEmpty()) { 277 final KDNode<T> current = stack.pop(); 278 279 checkSubtreeK(current, query, state, k); 280 } 281 282 @SuppressWarnings("unchecked") 283 final NNState[] stateList = state.toArray((NNState[]) Array.newInstance(NNState.class, state.size())); 284 Arrays.sort(stateList); 285 for (int i = stateList.length - 1; i >= 0; i--) 286 result.add(stateList[i].best); 287 } 288 289 /* 290 * Check a subtree for a closer match 291 */ 292 private void checkSubtree(KDNode<T> node, Coordinate query, NNState state) { 293 if (node == null) 294 return; 295 296 final float dist = distance(query, node._point); 297 if (dist < state.bestDist) { 298 state.best = node._point; 299 state.bestDist = dist; 300 } 301 302 if (state.bestDist == 0) 303 return; 304 305 final float d = node._point.getOrdinate(node._discriminate).floatValue() - 306 query.getOrdinate(node._discriminate).floatValue(); 307 if (d * d > state.bestDist) { 308 // check subtree 309 final double ordinate1 = query.getOrdinate(node._discriminate).doubleValue(); 310 final double ordinate2 = node._point.getOrdinate(node._discriminate).doubleValue(); 311 312 if (ordinate1 > ordinate2) 313 checkSubtree(node._right, query, state); 314 else 315 checkSubtree(node._left, query, state); 316 } else { 317 checkSubtree(node._left, query, state); 318 checkSubtree(node._right, query, state); 319 } 320 } 321 322 private void checkSubtreeK(KDNode<T> node, Coordinate query, PriorityQueue<NNState> state, int k) { 323 if (node == null) 324 return; 325 326 final float dist = distance(query, node._point); 327 328 boolean cont = false; 329 for (final NNState s : state) 330 if (s.best.equals(node._point)) { 331 cont = true; 332 break; 333 } 334 335 if (!cont) { 336 if (state.size() < k) { 337 // collect this node 338 final NNState s = new NNState(); 339 s.best = node._point; 340 s.bestDist = dist; 341 state.add(s); 342 } else if (dist < state.peek().bestDist) { 343 // replace last node 344 final NNState s = state.poll(); 345 s.best = node._point; 346 s.bestDist = dist; 347 state.add(s); 348 } 349 } 350 351 final float d = node._point.getOrdinate(node._discriminate).floatValue() - 352 query.getOrdinate(node._discriminate).floatValue(); 353 if (d * d > state.peek().bestDist) { 354 // check subtree 355 final double ordinate1 = query.getOrdinate(node._discriminate).doubleValue(); 356 final double ordinate2 = node._point.getOrdinate(node._discriminate).doubleValue(); 357 358 if (ordinate1 > ordinate2) 359 checkSubtreeK(node._right, query, state, k); 360 else 361 checkSubtreeK(node._left, query, state, k); 362 } else { 363 checkSubtreeK(node._left, query, state, k); 364 checkSubtreeK(node._right, query, state, k); 365 } 366 } 367 368 /* 369 * walk down the tree until we hit a leaf, and return the path taken 370 */ 371 private Stack<KDNode<T>> walkdown(Coordinate point) { 372 if (_root == null) 373 return null; 374 else { 375 final Stack<KDNode<T>> stack = new Stack<KDNode<T>>(); 376 int discriminate, dimensions; 377 KDNode<T> curNode, tmpNode; 378 double ordinate1, ordinate2; 379 380 curNode = _root; 381 382 do { 383 tmpNode = curNode; 384 stack.push(tmpNode); 385 if (tmpNode._point == point) 386 return stack; 387 discriminate = tmpNode._discriminate; 388 389 ordinate1 = point.getOrdinate(discriminate).doubleValue(); 390 ordinate2 = tmpNode._point.getOrdinate(discriminate).doubleValue(); 391 392 if (ordinate1 > ordinate2) 393 curNode = tmpNode._right; 394 else 395 curNode = tmpNode._left; 396 } while (curNode != null); 397 398 dimensions = point.getDimensions(); 399 400 if (++discriminate >= dimensions) 401 discriminate = 0; 402 403 return stack; 404 } 405 } 406 407 class Coord implements Coordinate { 408 float[] coords; 409 410 public Coord(int i) { 411 coords = new float[i]; 412 } 413 414 public Coord(Coordinate c) { 415 this(c.getDimensions()); 416 for (int i = 0; i < coords.length; i++) 417 coords[i] = c.getOrdinate(i).floatValue(); 418 } 419 420 @Override 421 public int getDimensions() { 422 return coords.length; 423 } 424 425 @Override 426 public Float getOrdinate(int dimension) { 427 return coords[dimension]; 428 } 429 430 @Override 431 public void readASCII(Scanner in) throws IOException { 432 throw new RuntimeException("not implemented"); 433 } 434 435 @Override 436 public String asciiHeader() { 437 throw new RuntimeException("not implemented"); 438 } 439 440 @Override 441 public void readBinary(DataInput in) throws IOException { 442 throw new RuntimeException("not implemented"); 443 } 444 445 @Override 446 public byte[] binaryHeader() { 447 throw new RuntimeException("not implemented"); 448 } 449 450 @Override 451 public void writeASCII(PrintWriter out) throws IOException { 452 throw new RuntimeException("not implemented"); 453 } 454 455 @Override 456 public void writeBinary(DataOutput out) throws IOException { 457 throw new RuntimeException("not implemented"); 458 } 459 460 @Override 461 public void setOrdinate(int dimension, Number value) { 462 coords[dimension] = value.floatValue(); 463 } 464 } 465 466 /** 467 * Faster implementation of K-nearest-neighbours. 468 * 469 * @param result 470 * Collection to hold the found coordinates. 471 * @param query 472 * The query coordinate. 473 * @param k 474 * The number of neighbours to find. 475 */ 476 public void fastKNN(Collection<T> result, Coordinate query, int k) { 477 final List<T> tmp = new ArrayList<T>(); 478 final Coord lowerExtreme = new Coord(query); 479 final Coord upperExtreme = new Coord(query); 480 481 while (tmp.size() < k) { 482 tmp.clear(); 483 for (int i = 0; i < lowerExtreme.getDimensions(); i++) 484 lowerExtreme.coords[i] -= k; 485 for (int i = 0; i < upperExtreme.getDimensions(); i++) 486 upperExtreme.coords[i] += k; 487 rangeSearch(tmp, lowerExtreme, upperExtreme); 488 } 489 490 final CoordinateBruteForce<T> bf = new CoordinateBruteForce<T>(tmp); 491 bf.kNearestNeighbour(result, query, k); 492 } 493}