001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.knn.lsh;
031
032import gnu.trove.list.array.TIntArrayList;
033import gnu.trove.map.hash.TIntObjectHashMap;
034import gnu.trove.set.hash.TIntHashSet;
035
036import java.util.AbstractList;
037import java.util.ArrayList;
038import java.util.Collection;
039import java.util.List;
040
041import org.openimaj.knn.IncrementalNearestNeighbours;
042import org.openimaj.util.comparator.DistanceComparator;
043import org.openimaj.util.hash.HashFunction;
044import org.openimaj.util.hash.HashFunctionFactory;
045import org.openimaj.util.pair.IntFloatPair;
046import org.openimaj.util.queue.BoundedPriorityQueue;
047
048/**
049 * Nearest-neighbours based on Locality Sensitive Hashing (LSH). A number of
050 * internal hash-tables are created with an associated hash-code (which is
051 * usually a composite of multiple locality sensitive hashes). For a given
052 * query, the hash-code of the query object computed for each hash function and
053 * a lookup is made in each table. The set of matching items drawn from the
054 * tables is then combined and sorted by distance (and trimmed if necessary)
055 * before being returned.
056 * <p>
057 * Note: This object is not thread-safe. Multiple insertions or mixed insertions
058 * and searches should not be performed concurrently without external locking.
059 *
060 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
061 *
062 * @param <OBJECT>
063 *            Type of object being stored.
064 */
065public class LSHNearestNeighbours<OBJECT>
066                implements
067                IncrementalNearestNeighbours<OBJECT, float[], IntFloatPair>
068{
069        /**
070         * Encapsulates a hash table with an associated hash function and pointers
071         * to the data.
072         *
073         * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
074         *
075         * @param <OBJECT>
076         *            Type of object being hashed
077         */
078        private static class Table<OBJECT> {
079                private final TIntObjectHashMap<TIntArrayList> table;
080                HashFunction<OBJECT> function;
081
082                public Table(HashFunction<OBJECT> function) {
083                        this.function = function;
084                        table = new TIntObjectHashMap<TIntArrayList>();
085                }
086
087                /**
088                 * Insert a single point
089                 *
090                 * @param point
091                 *            the point
092                 * @param pid
093                 *            the id of the point in the data
094                 */
095                protected void insertPoint(OBJECT point, int pid) {
096                        final int hash = function.computeHashCode(point);
097
098                        TIntArrayList bucket = table.get(hash);
099                        if (bucket == null) {
100                                table.put(hash, bucket = new TIntArrayList());
101                        }
102
103                        bucket.add(pid);
104                }
105
106                /**
107                 * Search for a point in the table
108                 *
109                 * @param point
110                 *            query point
111                 * @param norm
112                 *            normalisation factor
113                 * @return ids of matched points
114                 */
115                protected TIntArrayList searchPoint(OBJECT point) {
116                        final int hash = function.computeHashCode(point);
117
118                        return table.get(hash);
119                }
120        }
121
122        protected DistanceComparator<OBJECT> distanceFcn;
123        protected List<Table<OBJECT>> tables;
124        protected List<OBJECT> data = new ArrayList<OBJECT>();
125
126        /**
127         * Construct with the given hash functions and distance function. One table
128         * will be created per hash function.
129         *
130         * @param tableHashes
131         *            The hash functions
132         * @param distanceFcn
133         *            The distance function
134         */
135        public LSHNearestNeighbours(List<HashFunction<OBJECT>> tableHashes, DistanceComparator<OBJECT> distanceFcn) {
136                final int numTables = tableHashes.size();
137                this.distanceFcn = distanceFcn;
138                this.tables = new ArrayList<Table<OBJECT>>(numTables);
139
140                for (int i = 0; i < numTables; i++) {
141                        tables.add(new Table<OBJECT>(tableHashes.get(i)));
142                }
143        }
144
145        /**
146         * Construct with the given hash function factory which will be used to
147         * initialize the requested number of hash tables.
148         *
149         * @param factory
150         *            The hash function factory.
151         * @param numTables
152         *            The number of requested tables.
153         * @param distanceFcn
154         *            The distance function.
155         */
156        public LSHNearestNeighbours(HashFunctionFactory<OBJECT> factory, int numTables, DistanceComparator<OBJECT> distanceFcn)
157        {
158                this.distanceFcn = distanceFcn;
159                this.tables = new ArrayList<Table<OBJECT>>(numTables);
160
161                for (int i = 0; i < numTables; i++) {
162                        tables.add(new Table<OBJECT>(factory.create()));
163                }
164        }
165
166        /**
167         * Get the number of hash tables
168         *
169         * @return The number of hash tables
170         */
171        public int numTables() {
172                return tables.size();
173        }
174
175        /**
176         * Insert data into the tables
177         *
178         * @param d
179         *            the data
180         */
181        public void addAll(Collection<OBJECT> d) {
182                int i = this.data.size();
183
184                for (final OBJECT point : d) {
185                        this.data.add(point);
186
187                        for (final Table<OBJECT> table : tables) {
188                                table.insertPoint(point, i);
189                        }
190
191                        i++;
192                }
193        }
194
195        /**
196         * Insert data into the tables
197         *
198         * @param d
199         *            the data
200         */
201        public void addAll(OBJECT[] d) {
202                int i = this.data.size();
203
204                for (final OBJECT point : d) {
205                        this.data.add(point);
206
207                        for (final Table<OBJECT> table : tables) {
208                                table.insertPoint(point, i);
209                        }
210
211                        i++;
212                }
213        }
214
215        @Override
216        public int add(OBJECT o) {
217                final int index = this.data.size();
218                this.data.add(o);
219
220                for (final Table<OBJECT> table : tables) {
221                        table.insertPoint(o, index);
222                }
223
224                return index;
225        }
226
227        /**
228         * Search for similar data in the underlying tables and return all matches
229         *
230         * @param data
231         *            the points
232         * @return matched ids
233         */
234        public TIntHashSet[] search(OBJECT[] data) {
235                final TIntHashSet[] pls = new TIntHashSet[data.length];
236
237                for (int i = 0; i < data.length; i++) {
238                        pls[i] = search(data[i]);
239                }
240
241                return pls;
242        }
243
244        /**
245         * Search for a similar data item in the underlying tables and return all
246         * matches
247         *
248         * @param data
249         *            the point
250         * @return matched ids
251         */
252        public TIntHashSet search(OBJECT data) {
253                final TIntHashSet pl = new TIntHashSet();
254
255                for (final Table<OBJECT> table : tables) {
256                        final TIntArrayList result = table.searchPoint(data);
257
258                        if (result != null)
259                                pl.addAll(result);
260                }
261
262                return pl;
263        }
264
265        /**
266         * Compute identifiers of the buckets in which the given points belong for
267         * all the tables.
268         *
269         * @param data
270         *            the points
271         * @return the bucket identifiers
272         */
273        public int[][] getBucketId(OBJECT[] data) {
274                final int[][] ids = new int[data.length][];
275
276                for (int i = 0; i < data.length; i++) {
277                        ids[i] = getBucketId(data[i]);
278                }
279
280                return ids;
281        }
282
283        /**
284         * Compute identifiers of the buckets in which the given point belongs for
285         * all the tables.
286         *
287         * @param point
288         *            the point
289         * @return the bucket identifiers
290         */
291        public int[] getBucketId(OBJECT point) {
292                final int[] ids = new int[tables.size()];
293
294                for (int j = 0; j < tables.size(); j++) {
295                        ids[j] = tables.get(j).function.computeHashCode(point);
296                }
297
298                return ids;
299        }
300
301        @Override
302        public void searchNN(OBJECT[] qus, int[] argmins, float[] mins) {
303                final int[][] argminsWrapper = { argmins };
304                final float[][] minsWrapper = { mins };
305
306                searchKNN(qus, 1, argminsWrapper, minsWrapper);
307        }
308
309        @Override
310        public void searchKNN(OBJECT[] qus, int K, int[][] argmins, float[][] mins) {
311                // loop on the search data
312                for (int i = 0; i < qus.length; i++) {
313                        final TIntHashSet pl = search(qus[i]);
314
315                        // now sort the selected points by distance
316                        final int[] ids = pl.toArray();
317                        final List<OBJECT> vectors = new ArrayList<OBJECT>(ids.length);
318                        for (int j = 0; j < ids.length; j++) {
319                                vectors.add(data.get(ids[j]));
320                        }
321
322                        exactNN(vectors, ids, qus[i], K, argmins[i], mins[i]);
323                }
324        }
325
326        @Override
327        public void searchNN(List<OBJECT> qus, int[] argmins, float[] mins) {
328                final int[][] argminsWrapper = { argmins };
329                final float[][] minsWrapper = { mins };
330
331                searchKNN(qus, 1, argminsWrapper, minsWrapper);
332        }
333
334        @Override
335        public void searchKNN(List<OBJECT> qus, int K, int[][] argmins, float[][] mins) {
336                final int size = qus.size();
337                // loop on the search data
338                for (int i = 0; i < size; i++) {
339                        final TIntHashSet pl = search(qus.get(i));
340
341                        // now sort the selected points by distance
342                        final int[] ids = pl.toArray();
343                        final List<OBJECT> vectors = new ArrayList<OBJECT>(ids.length);
344                        for (int j = 0; j < ids.length; j++) {
345                                vectors.add(data.get(ids[j]));
346                        }
347
348                        exactNN(vectors, ids, qus.get(i), K, argmins[i], mins[i]);
349                }
350        }
351
352        /*
353         * Exact NN on a subset
354         */
355        private void exactNN(List<OBJECT> subset, int[] ids, OBJECT query, int K, int[] argmins, float[] mins) {
356                final int size = subset.size();
357
358                // Fix for when the user asks for too many points.
359                final int actualK = Math.min(K, size);
360
361                for (int k = actualK; k < K; k++) {
362                        argmins[k] = -1;
363                        mins[k] = Float.MAX_VALUE;
364                }
365
366                if (actualK == 0)
367                        return;
368
369                final BoundedPriorityQueue<IntFloatPair> queue =
370                                new BoundedPriorityQueue<IntFloatPair>(actualK, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
371
372                // prepare working data
373                final List<IntFloatPair> list = new ArrayList<IntFloatPair>(actualK + 1);
374                for (int i = 0; i < actualK + 1; i++) {
375                        list.add(new IntFloatPair());
376                }
377
378                final List<IntFloatPair> result = search(subset, query, queue, list);
379
380                for (int k = 0; k < actualK; ++k) {
381                        final IntFloatPair p = result.get(k);
382                        argmins[k] = ids[p.first];
383                        mins[k] = p.second;
384                }
385        }
386
387        private List<IntFloatPair> search(List<OBJECT> subset, OBJECT query, BoundedPriorityQueue<IntFloatPair> queue,
388                        List<IntFloatPair> results)
389        {
390                final int size = subset.size();
391
392                IntFloatPair wp = null;
393                // reset all values in the queue to MAX, -1
394                for (final IntFloatPair p : results) {
395                        p.second = Float.MAX_VALUE;
396                        p.first = -1;
397                        wp = queue.offerItem(p);
398                }
399
400                // perform the search
401                for (int i = 0; i < size; i++) {
402                        wp.second = (float) distanceFcn.compare(query, subset.get(i));
403                        wp.first = i;
404                        wp = queue.offerItem(wp);
405                }
406
407                return queue.toOrderedListDestructive();
408        }
409
410        @Override
411        public int size() {
412                return data.size();
413        }
414
415        /**
416         * Get a read-only view of the underlying data.
417         *
418         * @return a read-only view of the underlying data.
419         */
420        public List<OBJECT> getData() {
421                return new AbstractList<OBJECT>() {
422
423                        @Override
424                        public OBJECT get(int index) {
425                                return data.get(index);
426                        }
427
428                        @Override
429                        public int size() {
430                                return data.size();
431                        }
432                };
433        }
434
435        /**
436         * Get the data item at the given index.
437         *
438         * @param i
439         *            The index
440         * @return the retrieved object
441         */
442        public OBJECT get(int i) {
443                return data.get(i);
444        }
445
446        @Override
447        public int[] addAll(List<OBJECT> d) {
448                final int[] indexes = new int[d.size()];
449
450                for (int i = 0; i < indexes.length; i++) {
451                        indexes[i] = add(d.get(i));
452                }
453
454                return indexes;
455        }
456
457        @Override
458        public List<IntFloatPair> searchKNN(OBJECT query, int K) {
459                final ArrayList<OBJECT> qus = new ArrayList<OBJECT>(1);
460                qus.add(query);
461
462                final int[][] idx = new int[1][K];
463                final float[][] dst = new float[1][K];
464
465                this.searchKNN(qus, K, idx, dst);
466
467                final List<IntFloatPair> res = new ArrayList<IntFloatPair>();
468                for (int k = 0; k < K; k++) {
469                        if (idx[0][k] != -1)
470                                res.add(new IntFloatPair(idx[0][k], dst[0][k]));
471                }
472
473                return res;
474        }
475
476        @Override
477        public IntFloatPair searchNN(OBJECT query) {
478                final ArrayList<OBJECT> qus = new ArrayList<OBJECT>(1);
479                qus.add(query);
480
481                final int[] idx = new int[1];
482                final float[] dst = new float[1];
483
484                this.searchNN(qus, idx, dst);
485
486                if (idx[0] == -1)
487                        return null;
488
489                return new IntFloatPair(idx[0], dst[0]);
490        }
491}