001/*
002        AUTOMATICALLY GENERATED BY jTemp FROM
003        /Users/jsh2/Work/openimaj/target/checkout/machine-learning/nearest-neighbour/src/main/jtemp/org/openimaj/knn/#T#NearestNeighboursExact.jtemp
004*/
005/**
006 * Copyright (c) 2011, The University of Southampton and the individual contributors.
007 * All rights reserved.
008 *
009 * Redistribution and use in source and binary forms, with or without modification,
010 * are permitted provided that the following conditions are met:
011 *
012 *   *  Redistributions of source code must retain the above copyright notice,
013 *      this list of conditions and the following disclaimer.
014 *
015 *   *  Redistributions in binary form must reproduce the above copyright notice,
016 *      this list of conditions and the following disclaimer in the documentation
017 *      and/or other materials provided with the distribution.
018 *
019 *   *  Neither the name of the University of Southampton nor the names of its
020 *      contributors may be used to endorse or promote products derived from this
021 *      software without specific prior written permission.
022 *
023 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
024 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
025 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
026 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
027 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
028 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
029 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
030 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
031 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
032 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
033 */
034package org.openimaj.knn;
035
036import java.util.ArrayList;
037import java.util.List;
038
039import org.openimaj.feature.FloatFVComparison;
040import org.openimaj.feature.FloatFVComparator;
041import org.openimaj.util.pair.IntFloatPair;
042import org.openimaj.util.queue.BoundedPriorityQueue;
043
044/**
045 * Exact (brute-force) k-nearest-neighbour implementation.
046 * 
047 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
048 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
049 */
050public class FloatNearestNeighboursExact extends FloatNearestNeighbours {
051    /**
052         * {@link NearestNeighboursFactory} for producing
053         * {@link FloatNearestNeighboursExact}s.
054         * 
055         * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
056         */
057    public static final class Factory implements NearestNeighboursFactory<FloatNearestNeighboursExact, float[]> {
058        private final FloatFVComparator distance;
059        
060        /**
061         * Construct the factory using Euclidean distance for the 
062         * produced FloatNearestNeighbours instances.
063         */
064        public Factory() {
065            this.distance = null;
066        }
067        
068        /**
069         * Construct the factory with the given distance function
070         * for the produced FloatNearestNeighbours instances.
071         * 
072                 * @param distance
073                 *            the distance function
074         */
075        public Factory(FloatFVComparator distance) {
076            this.distance = distance;
077        }
078        
079        @Override
080        public FloatNearestNeighboursExact create(float[][] data) {
081            return new FloatNearestNeighboursExact(data, distance);
082        }
083    }
084    
085        protected final float[][] pnts;
086        protected final FloatFVComparator distance;
087
088        /**
089         * Construct the FloatNearestNeighboursExact over the provided
090         * dataset and using Euclidean distance.
091         * @param pnts the dataset
092         */
093        public FloatNearestNeighboursExact(final float [][] pnts) {
094                this(pnts, null);
095        }
096
097        /**
098         * Construct the FloatNearestNeighboursExact over the provided
099         * dataset with the given distance function. 
100         * <p>
101         * Note: If the distance function provides similarities rather
102         * than distances they are automatically inverted.
103         *  
104         * @param pnts the dataset
105         * @param distance the distance function
106         */
107        public FloatNearestNeighboursExact(final float [][] pnts, final FloatFVComparator distance) {
108                this.pnts = pnts;
109                this.distance = distance;
110        }
111        
112        @Override
113        public void searchNN(final float [][] qus, int [] indices, float [] distances) {
114                final int N = qus.length;
115                
116                final BoundedPriorityQueue<IntFloatPair> queue =
117                                new BoundedPriorityQueue<IntFloatPair>(1, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
118
119        //prepare working data
120                List<IntFloatPair> list = new ArrayList<IntFloatPair>(2);
121                list.add(new IntFloatPair());
122                list.add(new IntFloatPair());
123                
124                for (int n=0; n < N; ++n) {
125                        List<IntFloatPair> result = search(qus[n], queue, list);
126                        
127                        final IntFloatPair p = result.get(0);
128                        indices[n] = p.first;
129                        distances[n] = p.second;
130                }
131        }
132
133        @Override
134        public void searchKNN(final float [][] qus, int K, int [][] indices, float [][] distances) {
135                // Fix for when the user asks for too many points.
136                K = Math.min(K, pnts.length);
137
138                final int N = qus.length;
139
140                final BoundedPriorityQueue<IntFloatPair> queue =
141                                new BoundedPriorityQueue<IntFloatPair>(K, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
142
143        //prepare working data
144                List<IntFloatPair> list = new ArrayList<IntFloatPair>(K + 1);
145                for (int i = 0; i < K + 1; i++) {
146                        list.add(new IntFloatPair());
147                }
148
149        // search on each query
150                for (int n = 0; n < N; ++n) {
151                        List<IntFloatPair> result = search(qus[n], queue, list);
152                        
153                        for (int k = 0; k < K; ++k) {
154                                final IntFloatPair p = result.get(k);
155                                indices[n][k] = p.first;
156                                distances[n][k] = p.second;
157                        }
158                }
159        }
160        
161        @Override
162        public void searchNN(final List<float[]> qus, int [] indices, float [] distances) {
163                final int N = qus.size();
164                
165                final BoundedPriorityQueue<IntFloatPair> queue =
166                                new BoundedPriorityQueue<IntFloatPair>(1, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
167
168        //prepare working data
169                List<IntFloatPair> list = new ArrayList<IntFloatPair>(2);
170                list.add(new IntFloatPair());
171                list.add(new IntFloatPair());
172                
173                for (int n=0; n < N; ++n) {
174                        List<IntFloatPair> result = search(qus.get(n), queue, list);
175                        
176                        final IntFloatPair p = result.get(0);
177                        indices[n] = p.first;
178                        distances[n] = p.second;
179                }
180        }
181
182        @Override
183        public void searchKNN(final List<float[]> qus, int K, int [][] indices, float [][] distances) {
184                // Fix for when the user asks for too many points.
185                K = Math.min(K, pnts.length);
186
187                final int N = qus.size();
188
189                final BoundedPriorityQueue<IntFloatPair> queue =
190                                new BoundedPriorityQueue<IntFloatPair>(K, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
191
192        //prepare working data
193                List<IntFloatPair> list = new ArrayList<IntFloatPair>(K + 1);
194                for (int i = 0; i < K + 1; i++) {
195                        list.add(new IntFloatPair());
196                }
197
198        // search on each query
199                for (int n = 0; n < N; ++n) {
200                        List<IntFloatPair> result = search(qus.get(n), queue, list);
201                        
202                        for (int k = 0; k < K; ++k) {
203                                final IntFloatPair p = result.get(k);
204                                indices[n][k] = p.first;
205                                distances[n][k] = p.second;
206                        }
207                }
208        }
209
210    @Override
211        public List<IntFloatPair> searchKNN(float[] query, int K) {
212                // Fix for when the user asks for too many points.
213                K = Math.min(K, pnts.length);
214
215                final BoundedPriorityQueue<IntFloatPair> queue =
216                                new BoundedPriorityQueue<IntFloatPair>(K, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
217
218        //prepare working data
219                List<IntFloatPair> list = new ArrayList<IntFloatPair>(K + 1);
220                for (int i = 0; i < K + 1; i++) {
221                        list.add(new IntFloatPair());
222                }
223
224        // search
225        return search(query, queue, list);
226        }
227
228        @Override
229        public IntFloatPair searchNN(final float[] query) {
230                final BoundedPriorityQueue<IntFloatPair> queue =
231                                new BoundedPriorityQueue<IntFloatPair>(1, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
232
233        //prepare working data
234                List<IntFloatPair> list = new ArrayList<IntFloatPair>(2);
235                list.add(new IntFloatPair());
236                list.add(new IntFloatPair());
237                
238                return search(query, queue, list).get(0);
239        }
240
241    private List<IntFloatPair> search(float[] query, BoundedPriorityQueue<IntFloatPair> queue, List<IntFloatPair> results) {
242        IntFloatPair wp = null;
243        
244        // reset all values in the queue to MAX, -1
245                for (final IntFloatPair p : results) {
246                        p.second = Float.MAX_VALUE;
247                        p.first = -1;
248                        wp = queue.offerItem(p);
249                }
250
251        // perform the search
252                for (int i = 0; i < this.pnts.length; i++) {
253                        wp.second = distanceFunc(distance, query, pnts[i]);
254                        wp.first = i;
255                        wp = queue.offerItem(wp);
256                }
257                
258        return queue.toOrderedListDestructive();
259    }
260
261        @Override
262        public int numDimensions() {
263                return pnts[0].length;
264        }
265
266        @Override
267        public int size() {
268                return pnts.length;
269        }
270        
271        /**
272         * Get the underlying data points.
273         * 
274         * @return the data points
275         */
276        public float[][] getPoints() {
277                return this.pnts;
278        }
279
280        /**
281         * Compute the distance between two vectors using the underlying distance
282         * comparison used by this class.
283         * 
284         * @param a
285         *            the first vector
286         * @param b
287         *            the second vector
288         * @return the distance between the two vectors
289         */
290        public float computeDistance(float[] a, float[] b) {
291                if (distance == null)
292                        return (float) FloatFVComparison.SUM_SQUARE.compare(a, b);
293                return (float) distance.compare(a, b);
294        }
295
296        /**
297         * Get the distance comparator
298         * 
299         * @return the distance comparator
300         */
301        public FloatFVComparator distanceComparator() {
302                return this.distance;
303        }
304}