001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.knn.lsh; 031 032import gnu.trove.list.array.TIntArrayList; 033import gnu.trove.map.hash.TIntObjectHashMap; 034import gnu.trove.set.hash.TIntHashSet; 035 036import java.util.AbstractList; 037import java.util.ArrayList; 038import java.util.Collection; 039import java.util.List; 040 041import org.openimaj.knn.IncrementalNearestNeighbours; 042import org.openimaj.util.comparator.DistanceComparator; 043import org.openimaj.util.hash.HashFunction; 044import org.openimaj.util.hash.HashFunctionFactory; 045import org.openimaj.util.pair.IntFloatPair; 046import org.openimaj.util.queue.BoundedPriorityQueue; 047 048/** 049 * Nearest-neighbours based on Locality Sensitive Hashing (LSH). A number of 050 * internal hash-tables are created with an associated hash-code (which is 051 * usually a composite of multiple locality sensitive hashes). For a given 052 * query, the hash-code of the query object computed for each hash function and 053 * a lookup is made in each table. The set of matching items drawn from the 054 * tables is then combined and sorted by distance (and trimmed if necessary) 055 * before being returned. 056 * <p> 057 * Note: This object is not thread-safe. Multiple insertions or mixed insertions 058 * and searches should not be performed concurrently without external locking. 059 * 060 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 061 * 062 * @param <OBJECT> 063 * Type of object being stored. 064 */ 065public class LSHNearestNeighbours<OBJECT> 066 implements 067 IncrementalNearestNeighbours<OBJECT, float[], IntFloatPair> 068{ 069 /** 070 * Encapsulates a hash table with an associated hash function and pointers 071 * to the data. 072 * 073 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 074 * 075 * @param <OBJECT> 076 * Type of object being hashed 077 */ 078 private static class Table<OBJECT> { 079 private final TIntObjectHashMap<TIntArrayList> table; 080 HashFunction<OBJECT> function; 081 082 public Table(HashFunction<OBJECT> function) { 083 this.function = function; 084 table = new TIntObjectHashMap<TIntArrayList>(); 085 } 086 087 /** 088 * Insert a single point 089 * 090 * @param point 091 * the point 092 * @param pid 093 * the id of the point in the data 094 */ 095 protected void insertPoint(OBJECT point, int pid) { 096 final int hash = function.computeHashCode(point); 097 098 TIntArrayList bucket = table.get(hash); 099 if (bucket == null) { 100 table.put(hash, bucket = new TIntArrayList()); 101 } 102 103 bucket.add(pid); 104 } 105 106 /** 107 * Search for a point in the table 108 * 109 * @param point 110 * query point 111 * @param norm 112 * normalisation factor 113 * @return ids of matched points 114 */ 115 protected TIntArrayList searchPoint(OBJECT point) { 116 final int hash = function.computeHashCode(point); 117 118 return table.get(hash); 119 } 120 } 121 122 protected DistanceComparator<OBJECT> distanceFcn; 123 protected List<Table<OBJECT>> tables; 124 protected List<OBJECT> data = new ArrayList<OBJECT>(); 125 126 /** 127 * Construct with the given hash functions and distance function. One table 128 * will be created per hash function. 129 * 130 * @param tableHashes 131 * The hash functions 132 * @param distanceFcn 133 * The distance function 134 */ 135 public LSHNearestNeighbours(List<HashFunction<OBJECT>> tableHashes, DistanceComparator<OBJECT> distanceFcn) { 136 final int numTables = tableHashes.size(); 137 this.distanceFcn = distanceFcn; 138 this.tables = new ArrayList<Table<OBJECT>>(numTables); 139 140 for (int i = 0; i < numTables; i++) { 141 tables.add(new Table<OBJECT>(tableHashes.get(i))); 142 } 143 } 144 145 /** 146 * Construct with the given hash function factory which will be used to 147 * initialize the requested number of hash tables. 148 * 149 * @param factory 150 * The hash function factory. 151 * @param numTables 152 * The number of requested tables. 153 * @param distanceFcn 154 * The distance function. 155 */ 156 public LSHNearestNeighbours(HashFunctionFactory<OBJECT> factory, int numTables, DistanceComparator<OBJECT> distanceFcn) 157 { 158 this.distanceFcn = distanceFcn; 159 this.tables = new ArrayList<Table<OBJECT>>(numTables); 160 161 for (int i = 0; i < numTables; i++) { 162 tables.add(new Table<OBJECT>(factory.create())); 163 } 164 } 165 166 /** 167 * Get the number of hash tables 168 * 169 * @return The number of hash tables 170 */ 171 public int numTables() { 172 return tables.size(); 173 } 174 175 /** 176 * Insert data into the tables 177 * 178 * @param d 179 * the data 180 */ 181 public void addAll(Collection<OBJECT> d) { 182 int i = this.data.size(); 183 184 for (final OBJECT point : d) { 185 this.data.add(point); 186 187 for (final Table<OBJECT> table : tables) { 188 table.insertPoint(point, i); 189 } 190 191 i++; 192 } 193 } 194 195 /** 196 * Insert data into the tables 197 * 198 * @param d 199 * the data 200 */ 201 public void addAll(OBJECT[] d) { 202 int i = this.data.size(); 203 204 for (final OBJECT point : d) { 205 this.data.add(point); 206 207 for (final Table<OBJECT> table : tables) { 208 table.insertPoint(point, i); 209 } 210 211 i++; 212 } 213 } 214 215 @Override 216 public int add(OBJECT o) { 217 final int index = this.data.size(); 218 this.data.add(o); 219 220 for (final Table<OBJECT> table : tables) { 221 table.insertPoint(o, index); 222 } 223 224 return index; 225 } 226 227 /** 228 * Search for similar data in the underlying tables and return all matches 229 * 230 * @param data 231 * the points 232 * @return matched ids 233 */ 234 public TIntHashSet[] search(OBJECT[] data) { 235 final TIntHashSet[] pls = new TIntHashSet[data.length]; 236 237 for (int i = 0; i < data.length; i++) { 238 pls[i] = search(data[i]); 239 } 240 241 return pls; 242 } 243 244 /** 245 * Search for a similar data item in the underlying tables and return all 246 * matches 247 * 248 * @param data 249 * the point 250 * @return matched ids 251 */ 252 public TIntHashSet search(OBJECT data) { 253 final TIntHashSet pl = new TIntHashSet(); 254 255 for (final Table<OBJECT> table : tables) { 256 final TIntArrayList result = table.searchPoint(data); 257 258 if (result != null) 259 pl.addAll(result); 260 } 261 262 return pl; 263 } 264 265 /** 266 * Compute identifiers of the buckets in which the given points belong for 267 * all the tables. 268 * 269 * @param data 270 * the points 271 * @return the bucket identifiers 272 */ 273 public int[][] getBucketId(OBJECT[] data) { 274 final int[][] ids = new int[data.length][]; 275 276 for (int i = 0; i < data.length; i++) { 277 ids[i] = getBucketId(data[i]); 278 } 279 280 return ids; 281 } 282 283 /** 284 * Compute identifiers of the buckets in which the given point belongs for 285 * all the tables. 286 * 287 * @param point 288 * the point 289 * @return the bucket identifiers 290 */ 291 public int[] getBucketId(OBJECT point) { 292 final int[] ids = new int[tables.size()]; 293 294 for (int j = 0; j < tables.size(); j++) { 295 ids[j] = tables.get(j).function.computeHashCode(point); 296 } 297 298 return ids; 299 } 300 301 @Override 302 public void searchNN(OBJECT[] qus, int[] argmins, float[] mins) { 303 final int[][] argminsWrapper = { argmins }; 304 final float[][] minsWrapper = { mins }; 305 306 searchKNN(qus, 1, argminsWrapper, minsWrapper); 307 } 308 309 @Override 310 public void searchKNN(OBJECT[] qus, int K, int[][] argmins, float[][] mins) { 311 // loop on the search data 312 for (int i = 0; i < qus.length; i++) { 313 final TIntHashSet pl = search(qus[i]); 314 315 // now sort the selected points by distance 316 final int[] ids = pl.toArray(); 317 final List<OBJECT> vectors = new ArrayList<OBJECT>(ids.length); 318 for (int j = 0; j < ids.length; j++) { 319 vectors.add(data.get(ids[j])); 320 } 321 322 exactNN(vectors, ids, qus[i], K, argmins[i], mins[i]); 323 } 324 } 325 326 @Override 327 public void searchNN(List<OBJECT> qus, int[] argmins, float[] mins) { 328 final int[][] argminsWrapper = { argmins }; 329 final float[][] minsWrapper = { mins }; 330 331 searchKNN(qus, 1, argminsWrapper, minsWrapper); 332 } 333 334 @Override 335 public void searchKNN(List<OBJECT> qus, int K, int[][] argmins, float[][] mins) { 336 final int size = qus.size(); 337 // loop on the search data 338 for (int i = 0; i < size; i++) { 339 final TIntHashSet pl = search(qus.get(i)); 340 341 // now sort the selected points by distance 342 final int[] ids = pl.toArray(); 343 final List<OBJECT> vectors = new ArrayList<OBJECT>(ids.length); 344 for (int j = 0; j < ids.length; j++) { 345 vectors.add(data.get(ids[j])); 346 } 347 348 exactNN(vectors, ids, qus.get(i), K, argmins[i], mins[i]); 349 } 350 } 351 352 /* 353 * Exact NN on a subset 354 */ 355 private void exactNN(List<OBJECT> subset, int[] ids, OBJECT query, int K, int[] argmins, float[] mins) { 356 final int size = subset.size(); 357 358 // Fix for when the user asks for too many points. 359 final int actualK = Math.min(K, size); 360 361 for (int k = actualK; k < K; k++) { 362 argmins[k] = -1; 363 mins[k] = Float.MAX_VALUE; 364 } 365 366 if (actualK == 0) 367 return; 368 369 final BoundedPriorityQueue<IntFloatPair> queue = 370 new BoundedPriorityQueue<IntFloatPair>(actualK, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR); 371 372 // prepare working data 373 final List<IntFloatPair> list = new ArrayList<IntFloatPair>(actualK + 1); 374 for (int i = 0; i < actualK + 1; i++) { 375 list.add(new IntFloatPair()); 376 } 377 378 final List<IntFloatPair> result = search(subset, query, queue, list); 379 380 for (int k = 0; k < actualK; ++k) { 381 final IntFloatPair p = result.get(k); 382 argmins[k] = ids[p.first]; 383 mins[k] = p.second; 384 } 385 } 386 387 private List<IntFloatPair> search(List<OBJECT> subset, OBJECT query, BoundedPriorityQueue<IntFloatPair> queue, 388 List<IntFloatPair> results) 389 { 390 final int size = subset.size(); 391 392 IntFloatPair wp = null; 393 // reset all values in the queue to MAX, -1 394 for (final IntFloatPair p : results) { 395 p.second = Float.MAX_VALUE; 396 p.first = -1; 397 wp = queue.offerItem(p); 398 } 399 400 // perform the search 401 for (int i = 0; i < size; i++) { 402 wp.second = (float) distanceFcn.compare(query, subset.get(i)); 403 wp.first = i; 404 wp = queue.offerItem(wp); 405 } 406 407 return queue.toOrderedListDestructive(); 408 } 409 410 @Override 411 public int size() { 412 return data.size(); 413 } 414 415 /** 416 * Get a read-only view of the underlying data. 417 * 418 * @return a read-only view of the underlying data. 419 */ 420 public List<OBJECT> getData() { 421 return new AbstractList<OBJECT>() { 422 423 @Override 424 public OBJECT get(int index) { 425 return data.get(index); 426 } 427 428 @Override 429 public int size() { 430 return data.size(); 431 } 432 }; 433 } 434 435 /** 436 * Get the data item at the given index. 437 * 438 * @param i 439 * The index 440 * @return the retrieved object 441 */ 442 public OBJECT get(int i) { 443 return data.get(i); 444 } 445 446 @Override 447 public int[] addAll(List<OBJECT> d) { 448 final int[] indexes = new int[d.size()]; 449 450 for (int i = 0; i < indexes.length; i++) { 451 indexes[i] = add(d.get(i)); 452 } 453 454 return indexes; 455 } 456 457 @Override 458 public List<IntFloatPair> searchKNN(OBJECT query, int K) { 459 final ArrayList<OBJECT> qus = new ArrayList<OBJECT>(1); 460 qus.add(query); 461 462 final int[][] idx = new int[1][K]; 463 final float[][] dst = new float[1][K]; 464 465 this.searchKNN(qus, K, idx, dst); 466 467 final List<IntFloatPair> res = new ArrayList<IntFloatPair>(); 468 for (int k = 0; k < K; k++) { 469 if (idx[0][k] != -1) 470 res.add(new IntFloatPair(idx[0][k], dst[0][k])); 471 } 472 473 return res; 474 } 475 476 @Override 477 public IntFloatPair searchNN(OBJECT query) { 478 final ArrayList<OBJECT> qus = new ArrayList<OBJECT>(1); 479 qus.add(query); 480 481 final int[] idx = new int[1]; 482 final float[] dst = new float[1]; 483 484 this.searchNN(qus, idx, dst); 485 486 if (idx[0] == -1) 487 return null; 488 489 return new IntFloatPair(idx[0], dst[0]); 490 } 491}