001/*
002        AUTOMATICALLY GENERATED BY jTemp FROM
003        /Users/jsh2/Work/openimaj/target/checkout/machine-learning/nearest-neighbour/src/main/jtemp/org/openimaj/knn/pq/Incremental#T#ADCNearestNeighbours.jtemp
004*/
005/**
006 * Copyright (c) 2011, The University of Southampton and the individual contributors.
007 * All rights reserved.
008 *
009 * Redistribution and use in source and binary forms, with or without modification,
010 * are permitted provided that the following conditions are met:
011 *
012 *   *  Redistributions of source code must retain the above copyright notice,
013 *      this list of conditions and the following disclaimer.
014 *
015 *   *  Redistributions in binary form must reproduce the above copyright notice,
016 *      this list of conditions and the following disclaimer in the documentation
017 *      and/or other materials provided with the distribution.
018 *
019 *   *  Neither the name of the University of Southampton nor the names of its
020 *      contributors may be used to endorse or promote products derived from this
021 *      software without specific prior written permission.
022 *
023 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
024 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
025 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
026 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
027 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
028 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
029 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
030 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
031 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
032 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
033 */
034 
035 package org.openimaj.knn.pq;
036
037import java.io.DataInput;
038import java.io.DataOutput;
039import java.io.IOException;
040import java.util.ArrayList;
041import java.util.Arrays;
042import java.util.List;
043
044import org.openimaj.citation.annotation.Reference;
045import org.openimaj.citation.annotation.ReferenceType;
046import org.openimaj.data.DataSource;
047import org.openimaj.io.IOUtils;
048import org.openimaj.io.ReadWriteableBinary;
049import org.openimaj.knn.ShortNearestNeighbours;
050import org.openimaj.knn.IncrementalNearestNeighbours;
051import org.openimaj.util.pair.IntFloatPair;
052import org.openimaj.util.queue.BoundedPriorityQueue;
053
054/**
055 * Incremental Nearest-neighbours using Asymmetric Distance Computation (ADC) 
056 * on Product Quantised vectors. In ADC, only the database points are quantised.
057 * The queries themselves are not quantised. The overall distance is computed
058 * as the summed distance of each subvector of the query to each corresponding
059 * centroids of each database vector.
060 * <p>
061 * For efficiency, the distance of each sub-vector of a query is computed to
062 * every centroid (for the sub-vector under consideration) only once, and is
063 * then cached for the lookup during the computation of the distance to each
064 * database vector.
065 * 
066 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
067 */
068@Reference(
069                type = ReferenceType.Article,
070                author = { "Jegou, Herve", "Douze, Matthijs", "Schmid, Cordelia" },
071                title = "Product Quantization for Nearest Neighbor Search",
072                year = "2011",
073                journal = "IEEE Trans. Pattern Anal. Mach. Intell.",
074                pages = { "117", "", "128" },
075                url = "http://dx.doi.org/10.1109/TPAMI.2010.57",
076                month = "January",
077                number = "1",
078                publisher = "IEEE Computer Society",
079                volume = "33",
080                customData = {
081                                "issn", "0162-8828",
082                                "numpages", "12",
083                                "doi", "10.1109/TPAMI.2010.57",
084                                "acmid", "1916695",
085                                "address", "Washington, DC, USA",
086                                "keywords", "High-dimensional indexing, High-dimensional indexing, image indexing, very large databases, approximate search., approximate search., image indexing, very large databases"
087                })
088public class IncrementalShortADCNearestNeighbours 
089        extends 
090                ShortNearestNeighbours 
091        implements 
092                IncrementalNearestNeighbours<short[], float[], IntFloatPair>,
093                ReadWriteableBinary 
094{
095        protected ShortProductQuantiser pq;
096        protected int ndims;
097        protected List<byte[]> data;
098
099    protected IncrementalShortADCNearestNeighbours() {
100        //for deserialization
101    }
102
103        /**
104         * Construct the ADC with the given quantiser and data points.
105         * 
106         * @param pq
107         *            the Product Quantiser
108         * @param dataPoints
109         *            the data points to index
110         */
111        public IncrementalShortADCNearestNeighbours(ShortProductQuantiser pq, short[][] dataPoints) {
112                this.pq = pq;
113                this.ndims = dataPoints[0].length;
114
115                this.data = new ArrayList<byte[]>(dataPoints.length);
116                for (int i = 0; i < dataPoints.length; i++) {
117                        data.add(pq.quantise(dataPoints[i]));
118                }
119        }
120        
121        /**
122         * Construct the ADC with the given quantiser and data points.
123         * 
124         * @param pq
125         *            the Product Quantiser
126         * @param dataPoints
127         *            the data points to index
128         */
129        public IncrementalShortADCNearestNeighbours(ShortProductQuantiser pq, List<short[]> dataPoints) {
130                this.pq = pq;
131                this.ndims = dataPoints.get(0).length;
132                
133                final int size = dataPoints.size();
134                this.data = new ArrayList<byte[]>(size);
135                for (int i = 0; i < size; i++) {
136                        data.add(pq.quantise(dataPoints.get(i)));
137                }
138        }
139        
140        /**
141         * Construct the ADC with the given quantiser and data points.
142         * 
143         * @param pq
144         *            the Product Quantiser
145         * @param dataPoints
146         *            the data points to index
147         */
148        public IncrementalShortADCNearestNeighbours(ShortProductQuantiser pq, DataSource<short[]> dataPoints) {
149                this.pq = pq;
150                this.ndims = dataPoints.getData(0).length;
151
152                final int size = dataPoints.size();
153                this.data = new ArrayList<byte[]>(size);
154                for (int i = 0; i < size; i++) {
155                        data.add(pq.quantise(dataPoints.getData(i)));
156                }
157        }
158        
159        /**
160         * Construct an empty ADC with the given quantiser.
161         * 
162         * @param pq
163         *            the Product Quantiser
164         * @param ndims
165         *            the data dimensionality
166         */
167        public IncrementalShortADCNearestNeighbours(ShortProductQuantiser pq, int ndims) {
168                this.pq = pq;
169                this.ndims = ndims;
170
171                this.data = new ArrayList<byte[]>();
172        }
173        
174        /**
175         * Construct an empty ADC with the given quantiser.
176         * 
177         * @param pq
178         *            the Product Quantiser
179         * @param ndims
180         *            the data dimensionality
181         * @param nitems
182         *            the expected number of data items
183         */
184        public IncrementalShortADCNearestNeighbours(ShortProductQuantiser pq, int ndims, int nitems) {
185                this.pq = pq;
186                this.ndims = ndims;
187
188                this.data = new ArrayList<byte[]>(nitems);
189        }
190        
191        @Override
192        public int[] addAll(List<short[]> d) {
193                final int[] indexes = new int[d.size()];
194
195                for (int i = 0; i < indexes.length; i++) {
196                        indexes[i] = add(d.get(i));
197                }
198
199                return indexes;
200        }
201
202        @Override
203        public int add(short[] o) {
204                final int ret = data.size();
205                data.add(pq.quantise(o));
206                return ret;
207        }
208
209        @Override
210        public int numDimensions() {
211                return ndims;
212        }
213
214        @Override
215        public int size() {
216                return data.size();
217        }
218        
219        @Override
220        public void readBinary(DataInput in) throws IOException {
221                pq = IOUtils.read(in);
222                ndims = in.readInt();
223
224                int size = in.readInt();
225                int dim = pq.assigners.length;
226                data = new ArrayList<byte[]>(size);
227                for (int i=0; i<size; i++) {
228                        byte[] bytes = new byte[dim];
229                        in.readFully(bytes);
230                        data.add(bytes);
231                }
232        }
233
234        @Override
235        public byte[] binaryHeader() {
236                return "IShortADCNN".getBytes();
237        }
238
239        @Override
240        public void writeBinary(DataOutput out) throws IOException {
241                IOUtils.write(pq, out);
242                out.writeInt(ndims);
243
244                int size = data.size();
245                out.writeInt(size);
246
247                for (int i=0; i<size; i++)
248                        out.write(data.get(i));
249        }
250        
251        @Override
252        public void searchNN(final short [][] qus, int [] indices, float [] distances) {
253                final int N = qus.length;
254                
255                final BoundedPriorityQueue<IntFloatPair> queue =
256                                new BoundedPriorityQueue<IntFloatPair>(1, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
257
258        //prepare working data
259                List<IntFloatPair> list = new ArrayList<IntFloatPair>(2);
260                list.add(new IntFloatPair());
261                list.add(new IntFloatPair());
262                
263                for (int n=0; n < N; ++n) {
264                        List<IntFloatPair> result = search(qus[n], queue, list);
265                        
266                        final IntFloatPair p = result.get(0);
267                        indices[n] = p.first;
268                        distances[n] = p.second;
269                }
270        }
271
272        @Override
273        public void searchKNN(final short [][] qus, int K, int [][] indices, float [][] distances) {
274                // Fix for when the user asks for too many points.
275                K = Math.min(K, data.size());
276
277                final int N = qus.length;
278
279                final BoundedPriorityQueue<IntFloatPair> queue =
280                                new BoundedPriorityQueue<IntFloatPair>(K, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
281
282        //prepare working data
283                List<IntFloatPair> list = new ArrayList<IntFloatPair>(K + 1);
284                for (int i = 0; i < K + 1; i++) {
285                        list.add(new IntFloatPair());
286                }
287
288        // search on each query
289                for (int n = 0; n < N; ++n) {
290                        List<IntFloatPair> result = search(qus[n], queue, list);
291                        
292                        for (int k = 0; k < K; ++k) {
293                                final IntFloatPair p = result.get(k);
294                                indices[n][k] = p.first;
295                                distances[n][k] = p.second;
296                        }
297                }
298        }
299        
300        @Override
301        public void searchNN(final List<short[]> qus, int [] indices, float [] distances) {
302                final int N = qus.size();
303                
304                final BoundedPriorityQueue<IntFloatPair> queue =
305                                new BoundedPriorityQueue<IntFloatPair>(1, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
306
307        //prepare working data
308                List<IntFloatPair> list = new ArrayList<IntFloatPair>(2);
309                list.add(new IntFloatPair());
310                list.add(new IntFloatPair());
311                
312                for (int n=0; n < N; ++n) {
313                        List<IntFloatPair> result = search(qus.get(n), queue, list);
314                        
315                        final IntFloatPair p = result.get(0);
316                        indices[n] = p.first;
317                        distances[n] = p.second;
318                }
319        }
320
321        @Override
322        public void searchKNN(final List<short[]> qus, int K, int [][] indices, float [][] distances) {
323                // Fix for when the user asks for too many points.
324                K = Math.min(K, data.size());
325
326                final int N = qus.size();
327
328                final BoundedPriorityQueue<IntFloatPair> queue =
329                                new BoundedPriorityQueue<IntFloatPair>(K, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
330
331        //prepare working data
332                List<IntFloatPair> list = new ArrayList<IntFloatPair>(K + 1);
333                for (int i = 0; i < K + 1; i++) {
334                        list.add(new IntFloatPair());
335                }
336
337        // search on each query
338                for (int n = 0; n < N; ++n) {
339                        List<IntFloatPair> result = search(qus.get(n), queue, list);
340                        
341                        for (int k = 0; k < K; ++k) {
342                                final IntFloatPair p = result.get(k);
343                                indices[n][k] = p.first;
344                                distances[n][k] = p.second;
345                        }
346                }
347        }
348
349    @Override
350        public List<IntFloatPair> searchKNN(short[] query, int K) {
351                // Fix for when the user asks for too many points.
352                K = Math.min(K, data.size());
353
354                final BoundedPriorityQueue<IntFloatPair> queue =
355                                new BoundedPriorityQueue<IntFloatPair>(K, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
356
357        //prepare working data
358                List<IntFloatPair> list = new ArrayList<IntFloatPair>(K + 1);
359                for (int i = 0; i < K + 1; i++) {
360                        list.add(new IntFloatPair());
361                }
362
363        // search
364        return search(query, queue, list);
365        }
366
367        @Override
368        public IntFloatPair searchNN(final short[] query) {
369                final BoundedPriorityQueue<IntFloatPair> queue =
370                                new BoundedPriorityQueue<IntFloatPair>(1, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
371
372        //prepare working data
373                List<IntFloatPair> list = new ArrayList<IntFloatPair>(2);
374                list.add(new IntFloatPair());
375                list.add(new IntFloatPair());
376                
377                return search(query, queue, list).get(0);
378        }
379
380    private List<IntFloatPair> search(short[] query, BoundedPriorityQueue<IntFloatPair> queue, List<IntFloatPair> results) {
381        IntFloatPair wp = null;
382        
383        // reset all values in the queue to MAX, -1
384                for (final IntFloatPair p : results) {
385                        p.second = Float.MAX_VALUE;
386                        p.first = -1;
387                        wp = queue.offerItem(p);
388                }
389
390        // perform the search
391                computeDistances(query, queue, wp);
392                
393        return queue.toOrderedListDestructive();
394    }
395    
396    protected void computeDistances(short[] fullQuery, BoundedPriorityQueue<IntFloatPair> queue, IntFloatPair wp) {
397                final float[][] distances = new float[pq.assigners.length][];
398
399                for (int j = 0, from = 0; j < this.pq.assigners.length; j++) {
400                        final ShortNearestNeighbours nn = this.pq.assigners[j];
401                        final int to = nn.numDimensions();
402                        final int K = nn.size();
403
404                        final short[][] qus = { Arrays.copyOfRange(fullQuery, from, from + to) };
405                        final int[][] idx = new int[1][K];
406                        final float[][] dst = new float[1][K];
407                        nn.searchKNN(qus, K, idx, dst);
408
409                        distances[j] = new float[K];
410                        for (int k = 0; k < K; k++) {
411                                distances[j][idx[0][k]] = dst[0][k];
412                        }
413
414                        from += to;
415                }
416
417        final int size = data.size();
418                for (int i = 0; i < size; i++) {
419                        wp.first = i;
420                        wp.second = 0;
421
422                        for (int j = 0; j < this.pq.assigners.length; j++) {
423                                final int centroid = this.data.get(i)[j] + 128;
424                                wp.second += distances[j][centroid];
425                        }
426
427                        wp = queue.offerItem(wp);
428                }
429        }
430}