001/* 002 AUTOMATICALLY GENERATED BY jTemp FROM 003 /Users/jsh2/Work/openimaj/target/checkout/core/core/src/main/jtemp/org/openimaj/util/tree/Incremental#T#KDTree.jtemp 004*/ 005/** 006 * Copyright (c) 2011, The University of Southampton and the individual contributors. 007 * All rights reserved. 008 * 009 * Redistribution and use in source and binary forms, with or without modification, 010 * are permitted provided that the following conditions are met: 011 * 012 * * Redistributions of source code must retain the above copyright notice, 013 * this list of conditions and the following disclaimer. 014 * 015 * * Redistributions in binary form must reproduce the above copyright notice, 016 * this list of conditions and the following disclaimer in the documentation 017 * and/or other materials provided with the distribution. 018 * 019 * * Neither the name of the University of Southampton nor the names of its 020 * contributors may be used to endorse or promote products derived from this 021 * software without specific prior written permission. 022 * 023 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 024 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 025 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 026 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 027 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 028 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 029 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 030 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 031 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 032 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 033 */ 034package org.openimaj.util.tree; 035 036import java.util.ArrayList; 037import java.util.Collection; 038import java.util.List; 039import java.util.PriorityQueue; 040import java.util.Stack; 041 042import org.openimaj.util.pair.ObjectDoublePair; 043import org.openimaj.util.queue.BoundedPriorityQueue; 044 045/** 046 * Implementation of a simple incremental KDTree for <code>byte[]</code>s. Includes 047 * support for range search, neighbour search, and radius search. The tree created 048 * by this class will usually be rather unbalanced. 049 * <p> 050 * The KDTree allows fast search for points in relatively low-dimension spaces. 051 * 052 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 053 */ 054public class IncrementalByteKDTree { 055 private static class KDNode { 056 int discriminateDim; 057 byte[] point; 058 KDNode left, right; 059 060 KDNode(byte[] point, int discriminate) { 061 this.point = point; 062 this.left = right = null; 063 this.discriminateDim = discriminate; 064 } 065 } 066 067 KDNode _root; 068 069 /** 070 * Create an empty KDTree object 071 */ 072 public IncrementalByteKDTree() { 073 _root = null; 074 } 075 076 /** 077 * Create a KDTree object and populate it with the given data. 078 * 079 * @param coords 080 * the data to populate the index with. 081 */ 082 public IncrementalByteKDTree(Collection<byte[]> coords) { 083 _root = null; 084 insertAll(coords); 085 } 086 087 /** 088 * Create a KDTree object and populate it with the given data. 089 * 090 * @param coords 091 * the data to populate the index with. 092 */ 093 public IncrementalByteKDTree(byte[][] coords) { 094 _root = null; 095 insertAll(coords); 096 } 097 098 /** 099 * Insert all the points from the given collection into the index. 100 * 101 * @param coords 102 * The points to add. 103 */ 104 public void insertAll(Collection<byte[]> coords) { 105 for (final byte[] c : coords) 106 insert(c); 107 } 108 109 /** 110 * Insert all the points from the given collection into the index. 111 * 112 * @param coords 113 * The points to add. 114 */ 115 public void insertAll(byte[][] coords) { 116 for (final byte[] c : coords) 117 insert(c); 118 } 119 120 /** 121 * Inserts a point into the tree, preserving the spatial ordering. 122 * 123 * @param point 124 * Point to insert. 125 */ 126 public void insert(byte[] point) { 127 128 if (_root == null) 129 _root = new KDNode(point, 0); 130 else { 131 int discriminate; 132 KDNode curNode, tmpNode; 133 double ordinate1, ordinate2; 134 135 curNode = _root; 136 137 do { 138 tmpNode = curNode; 139 discriminate = tmpNode.discriminateDim; 140 141 ordinate1 = point[discriminate]; 142 ordinate2 = tmpNode.point[discriminate]; 143 144 if (ordinate1 > ordinate2) 145 curNode = tmpNode.right; 146 else 147 curNode = tmpNode.left; 148 } while (curNode != null); 149 150 if (++discriminate >= point.length) 151 discriminate = 0; 152 153 if (ordinate1 > ordinate2) 154 tmpNode.right = new KDNode(point, discriminate); 155 else 156 tmpNode.left = new KDNode(point, discriminate); 157 } 158 } 159 160 /** 161 * Determines if a point is contained within a given k-dimensional bounding 162 * box. 163 */ 164 static final boolean isContained( 165 byte[] point, byte[] lower, byte[] upper) 166 { 167 double ordinate1, ordinate2, ordinate3; 168 169 for (int i = 0; i < point.length; i++) { 170 ordinate1 = point[i]; 171 ordinate2 = lower[i]; 172 ordinate3 = upper[i]; 173 174 if (ordinate1 < ordinate2 || ordinate1 > ordinate3) 175 return false; 176 } 177 178 return true; 179 } 180 181 /** 182 * Searches the tree for all points contained within the bounding box 183 * defined by the given upper and lower extremes 184 * 185 * @param lowerExtreme 186 * @param upperExtreme 187 * @return the points within the given bounds 188 */ 189 public List<byte[]> rangeSearch(byte[] lowerExtreme, byte[] upperExtreme) { 190 final ArrayList<byte[]> results = new ArrayList<byte[]>(1000); 191 final Stack<KDNode> stack = new Stack<KDNode>(); 192 KDNode tmpNode; 193 int discriminate; 194 double ordinate1, ordinate2; 195 196 if (_root == null) 197 return results; 198 199 stack.push(_root); 200 201 while (!stack.empty()) { 202 tmpNode = stack.pop(); 203 discriminate = tmpNode.discriminateDim; 204 205 ordinate1 = tmpNode.point[discriminate]; 206 ordinate2 = lowerExtreme[discriminate]; 207 208 if (ordinate1 >= ordinate2 && tmpNode.left != null) 209 stack.push(tmpNode.left); 210 211 ordinate2 = upperExtreme[discriminate]; 212 213 if (ordinate1 <= ordinate2 && tmpNode.right != null) 214 stack.push(tmpNode.right); 215 216 if (isContained(tmpNode.point, lowerExtreme, upperExtreme)) 217 results.add(tmpNode.point); 218 } 219 220 return results; 221 } 222 223 protected static final double distance(byte[] a, byte[] b) { 224 double s = 0; 225 226 for (int i = 0; i < a.length; i++) { 227 final double fa = a[i]; 228 final double fb = b[i]; 229 s += (fa - fb) * (fa - fb); 230 } 231 return s; 232 } 233 234 /** 235 * Find the nearest neighbour. Only one neighbour will be returned - if 236 * multiple neighbours share the same location, or are equidistant, then 237 * this might not be the one you expect. 238 * 239 * @param query 240 * query coordinate 241 * @return nearest neighbour 242 */ 243 public ObjectDoublePair<byte[]> findNearestNeighbour(byte[] query) { 244 final Stack<KDNode> stack = walkdown(query); 245 final ObjectDoublePair<byte[]> state = new ObjectDoublePair<byte[]>(); 246 state.first = stack.peek().point; 247 state.second = distance(query, state.first); 248 249 if (state.second == 0) 250 return state; 251 252 while (!stack.isEmpty()) { 253 final KDNode current = stack.pop(); 254 255 checkSubtree(current, query, state); 256 } 257 258 return state; 259 } 260 261 /** 262 * Find the K nearest neighbours. 263 * 264 * @param query 265 * query coordinate 266 * @param k 267 * the number of neighbours to find 268 * @return nearest neighbours 269 */ 270 public List<ObjectDoublePair<byte[]>> findNearestNeighbours(byte[] query, int k) { 271 final Stack<KDNode> stack = walkdown(query); 272 final BoundedPriorityQueue<ObjectDoublePair<byte[]>> state = new BoundedPriorityQueue<ObjectDoublePair<byte[]>>( 273 k, ObjectDoublePair.SECOND_ITEM_ASCENDING_COMPARATOR); 274 275 final ObjectDoublePair<byte[]> initialState = new ObjectDoublePair<byte[]>(); 276 initialState.first = stack.peek().point; 277 initialState.second = distance(query, initialState.first); 278 state.add(initialState); 279 280 while (!stack.isEmpty()) { 281 final KDNode current = stack.pop(); 282 283 checkSubtreeK(current, query, state, k); 284 } 285 286 return state.toOrderedListDestructive(); 287 } 288 289 /* 290 * Check a subtree for a closer match 291 */ 292 private void checkSubtree(KDNode node, byte[] query, ObjectDoublePair<byte[]> state) { 293 if (node == null) 294 return; 295 296 final double dist = distance(query, node.point); 297 if (dist < state.second) { 298 state.first = node.point; 299 state.second = dist; 300 } 301 302 if (state.second == 0) 303 return; 304 305 final double d = node.point[node.discriminateDim] - query[node.discriminateDim]; 306 if (d * d > state.second) { 307 // check subtree 308 final double ordinate1 = query[node.discriminateDim]; 309 final double ordinate2 = node.point[node.discriminateDim]; 310 311 if (ordinate1 > ordinate2) 312 checkSubtree(node.right, query, state); 313 else 314 checkSubtree(node.left, query, state); 315 } else { 316 checkSubtree(node.left, query, state); 317 checkSubtree(node.right, query, state); 318 } 319 } 320 321 private void checkSubtreeK(KDNode node, byte[] query, PriorityQueue<ObjectDoublePair<byte[]>> state, int k) { 322 if (node == null) 323 return; 324 325 final double dist = distance(query, node.point); 326 327 boolean cont = false; 328 for (final ObjectDoublePair<byte[]> s : state) 329 if (s.first.equals(node.point)) { 330 cont = true; 331 break; 332 } 333 334 if (!cont) { 335 if (state.size() < k) { 336 // collect this node 337 final ObjectDoublePair<byte[]> s = new ObjectDoublePair<byte[]>(); 338 s.first = node.point; 339 s.second = dist; 340 state.add(s); 341 } else if (dist < state.peek().second) { 342 // replace last node 343 final ObjectDoublePair<byte[]> s = state.poll(); 344 s.first = node.point; 345 s.second = dist; 346 state.add(s); 347 } 348 } 349 350 final double d = node.point[node.discriminateDim] - query[node.discriminateDim]; 351 if (d * d > state.peek().second) { 352 // check subtree 353 final double ordinate1 = query[node.discriminateDim]; 354 final double ordinate2 = node.point[node.discriminateDim]; 355 356 if (ordinate1 > ordinate2) 357 checkSubtreeK(node.right, query, state, k); 358 else 359 checkSubtreeK(node.left, query, state, k); 360 } else { 361 checkSubtreeK(node.left, query, state, k); 362 checkSubtreeK(node.right, query, state, k); 363 } 364 } 365 366 /* 367 * walk down the tree until we hit a leaf, and return the path taken 368 */ 369 private Stack<KDNode> walkdown(byte[] point) { 370 if (_root == null) 371 return null; 372 else { 373 final Stack<KDNode> stack = new Stack<KDNode>(); 374 int discriminate; 375 KDNode curNode, tmpNode; 376 double ordinate1, ordinate2; 377 378 curNode = _root; 379 380 do { 381 tmpNode = curNode; 382 stack.push(tmpNode); 383 if (tmpNode.point == point) 384 return stack; 385 discriminate = tmpNode.discriminateDim; 386 387 ordinate1 = point[discriminate]; 388 ordinate2 = tmpNode.point[discriminate]; 389 390 if (ordinate1 > ordinate2) 391 curNode = tmpNode.right; 392 else 393 curNode = tmpNode.left; 394 } while (curNode != null); 395 396 if (++discriminate >= point.length) 397 discriminate = 0; 398 399 return stack; 400 } 401 } 402 403 /** 404 * Find all the points within the given radius of the given point 405 * 406 * @param centre 407 * the centre point 408 * @param radius 409 * the radius 410 * @return the points 411 */ 412 public List<byte[]> radiusSearch(byte[] centre, byte radius) { 413 final byte[] lower = centre.clone(); 414 final byte[] upper = centre.clone(); 415 416 for (int i = 0; i < centre.length; i++) { 417 lower[i] -= radius; 418 upper[i] += radius; 419 } 420 421 final List<byte[]> rangeList = rangeSearch(lower, upper); 422 final List<byte[]> radiusList = new ArrayList<byte[]>(rangeList.size()); 423 final double radSq = radius * radius; 424 for (final byte[] r : rangeList) { 425 if (distance(centre, r) < radSq) 426 radiusList.add(r); 427 } 428 429 return radiusList; 430 } 431 432 /** 433 * Find all the points within the given radius of the given point. 434 * Returns the distance to the point as well as the point itself. Distance 435 * is the squared L2 distance. 436 * 437 * @param centre 438 * the centre point 439 * @param radius 440 * the radius 441 * @return the points and distances 442 */ 443 public List<ObjectDoublePair<byte[]>> radiusDistanceSearch(byte[] centre, byte radius) { 444 final byte[] lower = centre.clone(); 445 final byte[] upper = centre.clone(); 446 447 for (int i = 0; i < centre.length; i++) { 448 lower[i] -= radius; 449 upper[i] += radius; 450 } 451 452 final List<byte[]> rangeList = rangeSearch(lower, upper); 453 final List<ObjectDoublePair<byte[]>> radiusList = new ArrayList<ObjectDoublePair<byte[]>>(rangeList.size()); 454 final double radSq = radius * radius; 455 for (final byte[] r : rangeList) { 456 double dist = distance(centre, r); 457 if (dist < radSq) 458 radiusList.add(new ObjectDoublePair<byte[]>(r, dist)); 459 } 460 461 return radiusList; 462 } 463}