001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.knn;
031
032import java.util.ArrayList;
033import java.util.Arrays;
034import java.util.List;
035
036import org.openimaj.util.comparator.DistanceComparator;
037import org.openimaj.util.pair.IntFloatPair;
038import org.openimaj.util.queue.BoundedPriorityQueue;
039
040/**
041 * Exact (brute-force) k-nearest-neighbour implementation for objects with a
042 * compatible {@link DistanceComparator}.
043 *
044 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
045 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
046 *
047 * @param <T>
048 *            Type of object being compared.
049 */
050public class ObjectNearestNeighboursExact<T> extends ObjectNearestNeighbours<T>
051                implements
052                IncrementalNearestNeighbours<T, float[], IntFloatPair>
053{
054        /**
055         * {@link NearestNeighboursFactory} for producing
056         * {@link ObjectNearestNeighboursExact}s.
057         *
058         * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
059         *
060         * @param <T>
061         *            Type of object being compared.
062         */
063        public static final class Factory<T> implements NearestNeighboursFactory<ObjectNearestNeighboursExact<T>, T> {
064                private final DistanceComparator<? super T> distance;
065
066                /**
067                 * Construct the factory with the given distance function for the
068                 * produced ObjectNearestNeighbours instances.
069                 *
070                 * @param distance
071                 *            the distance function
072                 */
073                public Factory(DistanceComparator<? super T> distance) {
074                        this.distance = distance;
075                }
076
077                @Override
078                public ObjectNearestNeighboursExact<T> create(T[] data) {
079                        return new ObjectNearestNeighboursExact<T>(data, distance);
080                }
081        }
082
083        protected final List<T> pnts;
084
085        /**
086         * Construct the {@link ObjectNearestNeighboursExact} over the provided
087         * dataset with the given distance function.
088         * <p>
089         * Note: If the distance function provides similarities rather than
090         * distances they are automatically inverted.
091         *
092         * @param pnts
093         *            the dataset
094         * @param distance
095         *            the distance function
096         */
097        public ObjectNearestNeighboursExact(final List<T> pnts, final DistanceComparator<? super T> distance) {
098                super(distance);
099                this.pnts = pnts;
100        }
101
102        /**
103         * Construct the {@link ObjectNearestNeighboursExact} over the provided
104         * dataset with the given distance function.
105         * <p>
106         * Note: If the distance function provides similarities rather than
107         * distances they are automatically inverted.
108         *
109         * @param pnts
110         *            the dataset
111         * @param distance
112         *            the distance function
113         */
114        public ObjectNearestNeighboursExact(final T[] pnts, final DistanceComparator<? super T> distance) {
115                super(distance);
116                this.pnts = Arrays.asList(pnts);
117        }
118
119        /**
120         * Construct any empty {@link ObjectNearestNeighboursExact} with the given
121         * distance function.
122         * <p>
123         * Note: If the distance function provides similarities rather than
124         * distances they are automatically inverted.
125         *
126         * @param distance
127         *            the distance function
128         */
129        public ObjectNearestNeighboursExact(final DistanceComparator<T> distance) {
130                super(distance);
131                this.pnts = new ArrayList<T>();
132        }
133
134        @Override
135        public void searchNN(final T[] qus, int[] indices, float[] distances) {
136                final int N = qus.length;
137
138                final BoundedPriorityQueue<IntFloatPair> queue =
139                                new BoundedPriorityQueue<IntFloatPair>(1, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
140
141                // prepare working data
142                final List<IntFloatPair> list = new ArrayList<IntFloatPair>(2);
143                list.add(new IntFloatPair());
144                list.add(new IntFloatPair());
145
146                for (int n = 0; n < N; ++n) {
147                        final List<IntFloatPair> result = search(qus[n], queue, list);
148
149                        final IntFloatPair p = result.get(0);
150                        indices[n] = p.first;
151                        distances[n] = p.second;
152                }
153        }
154
155        @Override
156        public void searchKNN(final T[] qus, int K, int[][] indices, float[][] distances) {
157                // Fix for when the user asks for too many points.
158                K = Math.min(K, pnts.size());
159
160                final int N = qus.length;
161
162                final BoundedPriorityQueue<IntFloatPair> queue =
163                                new BoundedPriorityQueue<IntFloatPair>(K, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
164
165                // prepare working data
166                final List<IntFloatPair> list = new ArrayList<IntFloatPair>(K + 1);
167                for (int i = 0; i < K + 1; i++) {
168                        list.add(new IntFloatPair());
169                }
170
171                // search on each query
172                for (int n = 0; n < N; ++n) {
173                        final List<IntFloatPair> result = search(qus[n], queue, list);
174
175                        for (int k = 0; k < K; ++k) {
176                                final IntFloatPair p = result.get(k);
177                                indices[n][k] = p.first;
178                                distances[n][k] = p.second;
179                        }
180                }
181        }
182
183        @Override
184        public void searchNN(final List<T> qus, int[] indices, float[] distances) {
185                final int N = qus.size();
186
187                final BoundedPriorityQueue<IntFloatPair> queue =
188                                new BoundedPriorityQueue<IntFloatPair>(1, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
189
190                // prepare working data
191                final List<IntFloatPair> list = new ArrayList<IntFloatPair>(2);
192                list.add(new IntFloatPair());
193                list.add(new IntFloatPair());
194
195                for (int n = 0; n < N; ++n) {
196                        final List<IntFloatPair> result = search(qus.get(n), queue, list);
197
198                        final IntFloatPair p = result.get(0);
199                        indices[n] = p.first;
200                        distances[n] = p.second;
201                }
202        }
203
204        @Override
205        public void searchKNN(final List<T> qus, int K, int[][] indices, float[][] distances) {
206                // Fix for when the user asks for too many points.
207                K = Math.min(K, pnts.size());
208
209                final int N = qus.size();
210
211                final BoundedPriorityQueue<IntFloatPair> queue =
212                                new BoundedPriorityQueue<IntFloatPair>(K, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
213
214                // prepare working data
215                final List<IntFloatPair> list = new ArrayList<IntFloatPair>(K + 1);
216                for (int i = 0; i < K + 1; i++) {
217                        list.add(new IntFloatPair());
218                }
219
220                // search on each query
221                for (int n = 0; n < N; ++n) {
222                        final List<IntFloatPair> result = search(qus.get(n), queue, list);
223
224                        for (int k = 0; k < K; ++k) {
225                                final IntFloatPair p = result.get(k);
226                                indices[n][k] = p.first;
227                                distances[n][k] = p.second;
228                        }
229                }
230        }
231
232        @Override
233        public List<IntFloatPair> searchKNN(T query, int K) {
234                // Fix for when the user asks for too many points.
235                K = Math.min(K, pnts.size());
236
237                final BoundedPriorityQueue<IntFloatPair> queue =
238                                new BoundedPriorityQueue<IntFloatPair>(K, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
239
240                // prepare working data
241                final List<IntFloatPair> list = new ArrayList<IntFloatPair>(K + 1);
242                for (int i = 0; i < K + 1; i++) {
243                        list.add(new IntFloatPair());
244                }
245
246                // search
247                return search(query, queue, list);
248        }
249
250        @Override
251        public IntFloatPair searchNN(final T query) {
252                final BoundedPriorityQueue<IntFloatPair> queue =
253                                new BoundedPriorityQueue<IntFloatPair>(1, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
254
255                // prepare working data
256                final List<IntFloatPair> list = new ArrayList<IntFloatPair>(2);
257                list.add(new IntFloatPair());
258                list.add(new IntFloatPair());
259
260                return search(query, queue, list).get(0);
261        }
262
263        private List<IntFloatPair> search(T query, BoundedPriorityQueue<IntFloatPair> queue, List<IntFloatPair> results)
264        {
265                IntFloatPair wp = null;
266
267                // reset all values in the queue to MAX, -1
268                for (final IntFloatPair p : results) {
269                        p.second = Float.MAX_VALUE;
270                        p.first = -1;
271                        wp = queue.offerItem(p);
272                }
273
274                // perform the search
275                final int size = this.pnts.size();
276                for (int i = 0; i < size; i++) {
277                        wp.second = ObjectNearestNeighbours.distanceFunc(distance, query, pnts.get(i));
278                        wp.first = i;
279                        wp = queue.offerItem(wp);
280                }
281
282                return queue.toOrderedListDestructive();
283        }
284
285        @Override
286        public int size() {
287                return this.pnts.size();
288        }
289
290        @Override
291        public int[] addAll(final List<T> d) {
292                final int[] indexes = new int[d.size()];
293
294                for (int i = 0; i < indexes.length; i++) {
295                        indexes[i] = this.add(d.get(i));
296                }
297
298                return indexes;
299        }
300
301        @Override
302        public int add(final T o) {
303                final int ret = this.pnts.size();
304                this.pnts.add(o);
305                return ret;
306        }
307}