001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.annotation.basic;
031
import gnu.trove.iterator.TObjectIntIterator;
import gnu.trove.map.hash.TObjectIntHashMap;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.openimaj.feature.FeatureExtractor;
import org.openimaj.knn.ObjectNearestNeighbours;
import org.openimaj.knn.ObjectNearestNeighboursExact;
import org.openimaj.ml.annotation.Annotated;
import org.openimaj.ml.annotation.IncrementalAnnotator;
import org.openimaj.ml.annotation.ScoredAnnotation;
import org.openimaj.util.comparator.DistanceComparator;
048
049/**
050 * Annotator based on a multi-class k-nearest-neighbour classifier. Uses a
051 * {@link ObjectNearestNeighboursExact} to perform the kNN search, so is
052 * applicable to any objects that can be compared with a
053 * {@link DistanceComparator}.
054 * 
055 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
056 * 
057 * @param <OBJECT>
058 *            Type of object being annotated
059 * @param <ANNOTATION>
060 *            Type of annotation
061 * @param <FEATURE>
062 *            Type of feature produced by extractor
063 */
064public class KNNAnnotator<OBJECT, ANNOTATION, FEATURE>
065                extends
066                IncrementalAnnotator<OBJECT, ANNOTATION>
067{
068        protected int k = 1;
069        protected final List<FEATURE> features = new ArrayList<FEATURE>();
070        protected final List<Collection<ANNOTATION>> annotations = new ArrayList<Collection<ANNOTATION>>();
071        protected final Set<ANNOTATION> annotationsSet = new HashSet<ANNOTATION>();
072        protected ObjectNearestNeighbours<FEATURE> nn;
073        protected DistanceComparator<? super FEATURE> comparator;
074        protected final float threshold;
075        protected FeatureExtractor<FEATURE, OBJECT> extractor;
076
077        /**
078         * Construct with the given extractor, comparator and threshold. The number
079         * of neighbours is set to 1.
080         * <p>
081         * If the comparator defines a distance, then only scores below the distance
082         * will be accepted. If the threshold defines a similarity, then only scores
083         * above the threshold will be accepted.
084         * 
085         * @param extractor
086         *            the extractor
087         * @param comparator
088         *            the comparator
089         * @param threshold
090         *            the threshold for successful matches
091         */
092        public KNNAnnotator(final FeatureExtractor<FEATURE, OBJECT> extractor,
093                        final DistanceComparator<? super FEATURE> comparator,
094                        final float threshold)
095        {
096                this(extractor, comparator, 1, threshold);
097        }
098
099        /**
100         * Construct with the given extractor and comparator. The number of
101         * neighbours is set to 1. The threshold test is disabled.
102         * 
103         * @param extractor
104         *            the extractor
105         * @param comparator
106         *            the comparator
107         */
108        public KNNAnnotator(final FeatureExtractor<FEATURE, OBJECT> extractor,
109                        final DistanceComparator<? super FEATURE> comparator)
110        {
111                this(extractor, comparator, 1, Float.MAX_VALUE);
112        }
113
114        /**
115         * Construct with the given extractor, comparator and number of neighbours.
116         * The distance threshold is disabled.
117         * 
118         * @param extractor
119         *            the extractor
120         * @param comparator
121         *            the comparator
122         * @param k
123         *            the number of neighbours
124         */
125        public KNNAnnotator(final FeatureExtractor<FEATURE, OBJECT> extractor,
126                        final DistanceComparator<? super FEATURE> comparator, final int k)
127        {
128                this(extractor, comparator, k, Float.MAX_VALUE);
129        }
130
131        /**
132         * Construct with the given extractor, comparator, number of neighbours and
133         * threshold.
134         * <p>
135         * If the comparator defines a distance, then only scores below the distance
136         * will be accepted. If the threshold defines a similarity, then only scores
137         * above the threshold will be accepted.
138         * 
139         * @param extractor
140         *            the extractor
141         * @param comparator
142         *            the comparator
143         * @param k
144         *            the number of neighbours
145         * @param threshold
146         *            the threshold on distance for successful matches
147         */
148        public KNNAnnotator(final FeatureExtractor<FEATURE, OBJECT> extractor,
149                        final DistanceComparator<? super FEATURE> comparator, final int k,
150                        final float threshold)
151        {
152                this.extractor = extractor;
153                this.comparator = comparator;
154                this.k = k;
155                this.threshold = comparator.isDistance() ? threshold : -threshold;
156        }
157
158        /**
159         * Create a new {@link KNNAnnotator} with the given extractor, comparator
160         * and threshold. The number of neighbours is set to 1.
161         * <p>
162         * If the comparator defines a distance, then only scores below the distance
163         * will be accepted. If the threshold defines a similarity, then only scores
164         * above the threshold will be accepted.
165         * 
166         * @param <OBJECT>
167         *            Type of object being annotated
168         * @param <ANNOTATION>
169         *            Type of annotation
170         * @param <EXTRACTOR>
171         *            Type of feature extractor
172         * @param <FEATURE>
173         *            Type of feature produced by extractor
174         * 
175         * @param extractor
176         *            the extractor
177         * @param comparator
178         *            the comparator
179         * @param threshold
180         *            the threshold for successful matches
181         * @return new {@link KNNAnnotator}
182         */
183        public static <OBJECT, ANNOTATION, EXTRACTOR extends FeatureExtractor<FEATURE, OBJECT>, FEATURE>
184                        KNNAnnotator<OBJECT, ANNOTATION, FEATURE> create(final EXTRACTOR extractor,
185                                        final DistanceComparator<FEATURE> comparator, final float threshold)
186        {
187                return new KNNAnnotator<OBJECT, ANNOTATION, FEATURE>(
188                                extractor, comparator, threshold);
189        }
190
191        /**
192         * Create a new {@link KNNAnnotator} with the given extractor and
193         * comparator. The number of neighbours is set to 1. The threshold test is
194         * disabled.
195         * 
196         * @param <OBJECT>
197         *            Type of object being annotated
198         * @param <ANNOTATION>
199         *            Type of annotation
200         * @param <EXTRACTOR>
201         *            Type of feature extractor
202         * @param <FEATURE>
203         *            Type of feature produced by extractor
204         * 
205         * @param extractor
206         *            the extractor
207         * @param comparator
208         *            the comparator
209         * @return new {@link KNNAnnotator}
210         */
211        public static <OBJECT, ANNOTATION, EXTRACTOR extends FeatureExtractor<FEATURE, OBJECT>, FEATURE>
212                        KNNAnnotator<OBJECT, ANNOTATION, FEATURE> create(final EXTRACTOR extractor,
213                                        final DistanceComparator<FEATURE> comparator)
214        {
215                return new KNNAnnotator<OBJECT, ANNOTATION, FEATURE>(
216                                extractor, comparator);
217        }
218
219        /**
220         * Create a new {@link KNNAnnotator} with the given extractor, comparator
221         * and number of neighbours. The distance threshold is disabled.
222         * 
223         * @param <OBJECT>
224         *            Type of object being annotated
225         * @param <ANNOTATION>
226         *            Type of annotation
227         * @param <EXTRACTOR>
228         *            Type of feature extractor
229         * @param <FEATURE>
230         *            Type of feature produced by extractor
231         * 
232         * @param extractor
233         *            the extractor
234         * @param comparator
235         *            the comparator
236         * @param k
237         *            the number of neighbours
238         * @return new {@link KNNAnnotator}
239         */
240        public static <OBJECT, ANNOTATION, EXTRACTOR extends FeatureExtractor<FEATURE, OBJECT>, FEATURE>
241                        KNNAnnotator<OBJECT, ANNOTATION, FEATURE> create(final EXTRACTOR extractor,
242                                        final DistanceComparator<FEATURE> comparator, final int k)
243        {
244                return new KNNAnnotator<OBJECT, ANNOTATION, FEATURE>(
245                                extractor, comparator, k);
246        }
247
248        /**
249         * Create a new {@link KNNAnnotator} with the given extractor, comparator,
250         * number of neighbours and threshold.
251         * <p>
252         * If the comparator defines a distance, then only scores below the distance
253         * will be accepted. If the threshold defines a similarity, then only scores
254         * above the threshold will be accepted.
255         * 
256         * @param <OBJECT>
257         *            Type of object being annotated
258         * @param <ANNOTATION>
259         *            Type of annotation
260         * @param <EXTRACTOR>
261         *            Type of feature extractor
262         * @param <FEATURE>
263         *            Type of feature produced by extractor
264         * 
265         * @param extractor
266         *            the extractor
267         * @param comparator
268         *            the comparator
269         * @param k
270         *            the number of neighbours
271         * @param threshold
272         *            the threshold on distance for successful matches
273         * @return new {@link KNNAnnotator}
274         */
275        public static <OBJECT, ANNOTATION, EXTRACTOR extends FeatureExtractor<FEATURE, OBJECT>, FEATURE>
276                        KNNAnnotator<OBJECT, ANNOTATION, FEATURE> create(final EXTRACTOR extractor,
277                                        final DistanceComparator<FEATURE> comparator, final int k, final float threshold)
278        {
279                return new KNNAnnotator<OBJECT, ANNOTATION, FEATURE>(
280                                extractor, comparator, k, threshold);
281        }
282
283        @Override
284        public void train(final Annotated<OBJECT, ANNOTATION> annotated) {
285                this.nn = null;
286
287                this.features.add(this.extractor.extractFeature(annotated.getObject()));
288
289                final Collection<ANNOTATION> anns = annotated.getAnnotations();
290                this.annotations.add(anns);
291                this.annotationsSet.addAll(anns);
292        }
293
294        @Override
295        public void reset() {
296                this.nn = null;
297                this.features.clear();
298                this.annotations.clear();
299                this.annotationsSet.clear();
300        }
301
302        @Override
303        public Set<ANNOTATION> getAnnotations() {
304                return this.annotationsSet;
305        }
306
307        @Override
308        public List<ScoredAnnotation<ANNOTATION>> annotate(final OBJECT object) {
309                if (this.nn == null)
310                        this.nn = new ObjectNearestNeighboursExact<FEATURE>(this.features, this.comparator);
311
312                final TObjectIntHashMap<ANNOTATION> selected = new TObjectIntHashMap<ANNOTATION>();
313
314                final List<FEATURE> queryfv = new ArrayList<FEATURE>(1);
315                queryfv.add(this.extractor.extractFeature(object));
316
317                final int[][] indices = new int[1][this.k];
318                final float[][] distances = new float[1][this.k];
319
320                this.nn.searchKNN(queryfv, this.k, indices, distances);
321
322                int count = 0;
323                for (int i = 0; i < this.k; i++) {
324                        // Distance check
325                        if (distances[0][i] > this.threshold) {
326                                continue;
327                        }
328
329                        final Collection<ANNOTATION> anns = this.annotations.get(indices[0][i]);
330
331                        for (final ANNOTATION ann : anns) {
332                                selected.adjustOrPutValue(ann, 1, 1);
333                                count++;
334                        }
335                }
336
337                final TObjectIntIterator<ANNOTATION> iterator = selected.iterator();
338                final List<ScoredAnnotation<ANNOTATION>> result = new ArrayList<ScoredAnnotation<ANNOTATION>>(selected.size());
339                while (iterator.hasNext()) {
340                        iterator.advance();
341
342                        result.add(new ScoredAnnotation<ANNOTATION>(iterator.key(), (float) iterator.value() / (float) count));
343                }
344
345                return result;
346        }
347
348        /**
349         * @return the number of neighbours to search for
350         */
351        public int getK() {
352                return this.k;
353        }
354
355        /**
356         * Set the number of neighbours
357         * 
358         * @param k
359         *            the number of neighbours
360         */
361        public void setK(final int k) {
362                this.k = k;
363        }
364}