001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.ml.annotation.basic; 031 032import gnu.trove.iterator.TObjectIntIterator; 033import gnu.trove.map.hash.TObjectIntHashMap; 034 035import java.util.ArrayList; 036import java.util.Collection; 037import java.util.HashSet; 038import java.util.List; 039import java.util.Set; 040 041import org.openimaj.feature.FeatureExtractor; 042import org.openimaj.knn.ObjectNearestNeighbours; 043import org.openimaj.knn.ObjectNearestNeighboursExact; 044import org.openimaj.ml.annotation.Annotated; 045import org.openimaj.ml.annotation.IncrementalAnnotator; 046import org.openimaj.ml.annotation.ScoredAnnotation; 047import org.openimaj.util.comparator.DistanceComparator; 048 049/** 050 * Annotator based on a multi-class k-nearest-neighbour classifier. Uses a 051 * {@link ObjectNearestNeighboursExact} to perform the kNN search, so is 052 * applicable to any objects that can be compared with a 053 * {@link DistanceComparator}. 054 * 055 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 056 * 057 * @param <OBJECT> 058 * Type of object being annotated 059 * @param <ANNOTATION> 060 * Type of annotation 061 * @param <FEATURE> 062 * Type of feature produced by extractor 063 */ 064public class KNNAnnotator<OBJECT, ANNOTATION, FEATURE> 065 extends 066 IncrementalAnnotator<OBJECT, ANNOTATION> 067{ 068 protected int k = 1; 069 protected final List<FEATURE> features = new ArrayList<FEATURE>(); 070 protected final List<Collection<ANNOTATION>> annotations = new ArrayList<Collection<ANNOTATION>>(); 071 protected final Set<ANNOTATION> annotationsSet = new HashSet<ANNOTATION>(); 072 protected ObjectNearestNeighbours<FEATURE> nn; 073 protected DistanceComparator<? super FEATURE> comparator; 074 protected final float threshold; 075 protected FeatureExtractor<FEATURE, OBJECT> extractor; 076 077 /** 078 * Construct with the given extractor, comparator and threshold. The number 079 * of neighbours is set to 1. 080 * <p> 081 * If the comparator defines a distance, then only scores below the distance 082 * will be accepted. If the threshold defines a similarity, then only scores 083 * above the threshold will be accepted. 084 * 085 * @param extractor 086 * the extractor 087 * @param comparator 088 * the comparator 089 * @param threshold 090 * the threshold for successful matches 091 */ 092 public KNNAnnotator(final FeatureExtractor<FEATURE, OBJECT> extractor, 093 final DistanceComparator<? super FEATURE> comparator, 094 final float threshold) 095 { 096 this(extractor, comparator, 1, threshold); 097 } 098 099 /** 100 * Construct with the given extractor and comparator. The number of 101 * neighbours is set to 1. The threshold test is disabled. 102 * 103 * @param extractor 104 * the extractor 105 * @param comparator 106 * the comparator 107 */ 108 public KNNAnnotator(final FeatureExtractor<FEATURE, OBJECT> extractor, 109 final DistanceComparator<? super FEATURE> comparator) 110 { 111 this(extractor, comparator, 1, Float.MAX_VALUE); 112 } 113 114 /** 115 * Construct with the given extractor, comparator and number of neighbours. 116 * The distance threshold is disabled. 117 * 118 * @param extractor 119 * the extractor 120 * @param comparator 121 * the comparator 122 * @param k 123 * the number of neighbours 124 */ 125 public KNNAnnotator(final FeatureExtractor<FEATURE, OBJECT> extractor, 126 final DistanceComparator<? super FEATURE> comparator, final int k) 127 { 128 this(extractor, comparator, k, Float.MAX_VALUE); 129 } 130 131 /** 132 * Construct with the given extractor, comparator, number of neighbours and 133 * threshold. 134 * <p> 135 * If the comparator defines a distance, then only scores below the distance 136 * will be accepted. If the threshold defines a similarity, then only scores 137 * above the threshold will be accepted. 138 * 139 * @param extractor 140 * the extractor 141 * @param comparator 142 * the comparator 143 * @param k 144 * the number of neighbours 145 * @param threshold 146 * the threshold on distance for successful matches 147 */ 148 public KNNAnnotator(final FeatureExtractor<FEATURE, OBJECT> extractor, 149 final DistanceComparator<? super FEATURE> comparator, final int k, 150 final float threshold) 151 { 152 this.extractor = extractor; 153 this.comparator = comparator; 154 this.k = k; 155 this.threshold = comparator.isDistance() ? threshold : -threshold; 156 } 157 158 /** 159 * Create a new {@link KNNAnnotator} with the given extractor, comparator 160 * and threshold. The number of neighbours is set to 1. 161 * <p> 162 * If the comparator defines a distance, then only scores below the distance 163 * will be accepted. If the threshold defines a similarity, then only scores 164 * above the threshold will be accepted. 165 * 166 * @param <OBJECT> 167 * Type of object being annotated 168 * @param <ANNOTATION> 169 * Type of annotation 170 * @param <EXTRACTOR> 171 * Type of feature extractor 172 * @param <FEATURE> 173 * Type of feature produced by extractor 174 * 175 * @param extractor 176 * the extractor 177 * @param comparator 178 * the comparator 179 * @param threshold 180 * the threshold for successful matches 181 * @return new {@link KNNAnnotator} 182 */ 183 public static <OBJECT, ANNOTATION, EXTRACTOR extends FeatureExtractor<FEATURE, OBJECT>, FEATURE> 184 KNNAnnotator<OBJECT, ANNOTATION, FEATURE> create(final EXTRACTOR extractor, 185 final DistanceComparator<FEATURE> comparator, final float threshold) 186 { 187 return new KNNAnnotator<OBJECT, ANNOTATION, FEATURE>( 188 extractor, comparator, threshold); 189 } 190 191 /** 192 * Create a new {@link KNNAnnotator} with the given extractor and 193 * comparator. The number of neighbours is set to 1. The threshold test is 194 * disabled. 195 * 196 * @param <OBJECT> 197 * Type of object being annotated 198 * @param <ANNOTATION> 199 * Type of annotation 200 * @param <EXTRACTOR> 201 * Type of feature extractor 202 * @param <FEATURE> 203 * Type of feature produced by extractor 204 * 205 * @param extractor 206 * the extractor 207 * @param comparator 208 * the comparator 209 * @return new {@link KNNAnnotator} 210 */ 211 public static <OBJECT, ANNOTATION, EXTRACTOR extends FeatureExtractor<FEATURE, OBJECT>, FEATURE> 212 KNNAnnotator<OBJECT, ANNOTATION, FEATURE> create(final EXTRACTOR extractor, 213 final DistanceComparator<FEATURE> comparator) 214 { 215 return new KNNAnnotator<OBJECT, ANNOTATION, FEATURE>( 216 extractor, comparator); 217 } 218 219 /** 220 * Create a new {@link KNNAnnotator} with the given extractor, comparator 221 * and number of neighbours. The distance threshold is disabled. 222 * 223 * @param <OBJECT> 224 * Type of object being annotated 225 * @param <ANNOTATION> 226 * Type of annotation 227 * @param <EXTRACTOR> 228 * Type of feature extractor 229 * @param <FEATURE> 230 * Type of feature produced by extractor 231 * 232 * @param extractor 233 * the extractor 234 * @param comparator 235 * the comparator 236 * @param k 237 * the number of neighbours 238 * @return new {@link KNNAnnotator} 239 */ 240 public static <OBJECT, ANNOTATION, EXTRACTOR extends FeatureExtractor<FEATURE, OBJECT>, FEATURE> 241 KNNAnnotator<OBJECT, ANNOTATION, FEATURE> create(final EXTRACTOR extractor, 242 final DistanceComparator<FEATURE> comparator, final int k) 243 { 244 return new KNNAnnotator<OBJECT, ANNOTATION, FEATURE>( 245 extractor, comparator, k); 246 } 247 248 /** 249 * Create a new {@link KNNAnnotator} with the given extractor, comparator, 250 * number of neighbours and threshold. 251 * <p> 252 * If the comparator defines a distance, then only scores below the distance 253 * will be accepted. If the threshold defines a similarity, then only scores 254 * above the threshold will be accepted. 255 * 256 * @param <OBJECT> 257 * Type of object being annotated 258 * @param <ANNOTATION> 259 * Type of annotation 260 * @param <EXTRACTOR> 261 * Type of feature extractor 262 * @param <FEATURE> 263 * Type of feature produced by extractor 264 * 265 * @param extractor 266 * the extractor 267 * @param comparator 268 * the comparator 269 * @param k 270 * the number of neighbours 271 * @param threshold 272 * the threshold on distance for successful matches 273 * @return new {@link KNNAnnotator} 274 */ 275 public static <OBJECT, ANNOTATION, EXTRACTOR extends FeatureExtractor<FEATURE, OBJECT>, FEATURE> 276 KNNAnnotator<OBJECT, ANNOTATION, FEATURE> create(final EXTRACTOR extractor, 277 final DistanceComparator<FEATURE> comparator, final int k, final float threshold) 278 { 279 return new KNNAnnotator<OBJECT, ANNOTATION, FEATURE>( 280 extractor, comparator, k, threshold); 281 } 282 283 @Override 284 public void train(final Annotated<OBJECT, ANNOTATION> annotated) { 285 this.nn = null; 286 287 this.features.add(this.extractor.extractFeature(annotated.getObject())); 288 289 final Collection<ANNOTATION> anns = annotated.getAnnotations(); 290 this.annotations.add(anns); 291 this.annotationsSet.addAll(anns); 292 } 293 294 @Override 295 public void reset() { 296 this.nn = null; 297 this.features.clear(); 298 this.annotations.clear(); 299 this.annotationsSet.clear(); 300 } 301 302 @Override 303 public Set<ANNOTATION> getAnnotations() { 304 return this.annotationsSet; 305 } 306 307 @Override 308 public List<ScoredAnnotation<ANNOTATION>> annotate(final OBJECT object) { 309 if (this.nn == null) 310 this.nn = new ObjectNearestNeighboursExact<FEATURE>(this.features, this.comparator); 311 312 final TObjectIntHashMap<ANNOTATION> selected = new TObjectIntHashMap<ANNOTATION>(); 313 314 final List<FEATURE> queryfv = new ArrayList<FEATURE>(1); 315 queryfv.add(this.extractor.extractFeature(object)); 316 317 final int[][] indices = new int[1][this.k]; 318 final float[][] distances = new float[1][this.k]; 319 320 this.nn.searchKNN(queryfv, this.k, indices, distances); 321 322 int count = 0; 323 for (int i = 0; i < this.k; i++) { 324 // Distance check 325 if (distances[0][i] > this.threshold) { 326 continue; 327 } 328 329 final Collection<ANNOTATION> anns = this.annotations.get(indices[0][i]); 330 331 for (final ANNOTATION ann : anns) { 332 selected.adjustOrPutValue(ann, 1, 1); 333 count++; 334 } 335 } 336 337 final TObjectIntIterator<ANNOTATION> iterator = selected.iterator(); 338 final List<ScoredAnnotation<ANNOTATION>> result = new ArrayList<ScoredAnnotation<ANNOTATION>>(selected.size()); 339 while (iterator.hasNext()) { 340 iterator.advance(); 341 342 result.add(new ScoredAnnotation<ANNOTATION>(iterator.key(), (float) iterator.value() / (float) count)); 343 } 344 345 return result; 346 } 347 348 /** 349 * @return the number of neighbours to search for 350 */ 351 public int getK() { 352 return this.k; 353 } 354 355 /** 356 * Set the number of neighbours 357 * 358 * @param k 359 * the number of neighbours 360 */ 361 public void setK(final int k) { 362 this.k = k; 363 } 364}