001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030/** 031 * 032 */ 033package org.openimaj.ml.annotation.svm; 034 035import java.io.File; 036import java.io.IOException; 037import java.util.Collection; 038import java.util.Collections; 039import java.util.HashMap; 040import java.util.HashSet; 041import java.util.List; 042import java.util.Set; 043 044import libsvm.svm; 045import libsvm.svm_model; 046import libsvm.svm_node; 047import libsvm.svm_parameter; 048import libsvm.svm_problem; 049 050import org.openimaj.feature.FeatureExtractor; 051import org.openimaj.feature.FeatureVector; 052import org.openimaj.ml.annotation.Annotated; 053import org.openimaj.ml.annotation.BatchAnnotator; 054import org.openimaj.ml.annotation.ScoredAnnotation; 055import org.openimaj.ml.annotation.utils.AnnotatedListHelper; 056import org.openimaj.util.array.ArrayUtils; 057 058/** 059 * Wraps the libsvm SVM and provides basic positive/negative 060 * annotation for a single class. 061 * 062 * @author David Dupplaw (dpd@ecs.soton.ac.uk) 063 * @created 14 May 2013 064 * 065 * @param <OBJECT> The object being annotated 066 * @param <ANNOTATION> The type of the annotation 067 */ 068public class SVMAnnotator<OBJECT,ANNOTATION> extends BatchAnnotator<OBJECT,ANNOTATION> 069{ 070 /** The input to the SVM model for positive classes */ 071 public static final int POSITIVE_CLASS = +1; 072 073 /** The input to the SVM model for negative classes */ 074 public static final int NEGATIVE_CLASS = -1; 075 076 /** Stores the mapping between the positive and negative class and the annotation */ 077 public HashMap<Integer,ANNOTATION> classMap = new HashMap<Integer,ANNOTATION>(); 078 079 /** The libsvm SVM model */ 080 private svm_model model = null; 081 082 /** The feature extractor being used to extract features from OBJECT */ 083 private FeatureExtractor<? extends FeatureVector, OBJECT> extractor = null; 084 085 /** The file to save the model to after it is trained (or null for no saving) */ 086 private File saveModel = null; 087 088 /** 089 * Constructor that takes the feature extractor to use. 090 * @param extractor The feature extractor 091 */ 092 public SVMAnnotator( final FeatureExtractor<? extends FeatureVector, OBJECT> extractor ) 093 { 094 this.extractor = extractor; 095 } 096 097 /** 098 * {@inheritDoc} 099 * @see org.openimaj.ml.training.BatchTrainer#train(java.util.List) 100 */ 101 @Override 102 public void train( final List<? extends Annotated<OBJECT, ANNOTATION>> data ) 103 { 104 // Check the data has 2 classes and update the class map. 105 if( this.checkInputDataOK( data ) ) 106 { 107 // Setup the SVM problem 108 final svm_parameter param = SVMAnnotator.getDefaultSVMParameters(); 109 final svm_problem prob = this.getSVMProblem( data, param, this.extractor ); 110 111 // Train the SVM 112 this.model = libsvm.svm.svm_train( prob, param ); 113 114 // Save the model if we're going to do that. 115 if( this.saveModel != null ) try 116 { 117 svm.svm_save_model( this.saveModel.getAbsolutePath(), this.model ); 118 } 119 catch( final IOException e ) 120 { 121 e.printStackTrace(); 122 } 123 } 124 } 125 126 /** 127 * Checks that the input data only has 2 classes. The method also assigns 128 * those classes to positive or negative values. 129 * 130 * @param data The data 131 * @return TRUE if and only if there are 2 classes 132 */ 133 private boolean checkInputDataOK( final List<? extends Annotated<OBJECT, ANNOTATION>> data ) 134 { 135 // Clear the class map. We don't want any old annotations left in there. 136 this.classMap.clear(); 137 138 // Loop over the data and check for valid data. 139 int i = 0; 140 for( final Annotated<OBJECT,ANNOTATION> x : data ) 141 { 142 // Get the annotations for the object. 143 final Collection<ANNOTATION> anns = x.getAnnotations(); 144 145 // Check there is only one annotation on each object. 146 if( anns.size() != 1 ) 147 throw new IllegalArgumentException( "Data contained an object with more than one annotation" ); 148 149 // Get the only annotation. 150 final ANNOTATION onlyAnnotation = anns.iterator().next(); 151 152 // Check if it's already been seen. 153 if( !this.classMap.values().contains( onlyAnnotation ) ) 154 { 155 // Key will be -1, +1, +3... 156 final int key = i * 2 -1; 157 i++; 158 159 // Put the first annotation into the map at the appropriate place. 160 this.classMap.put( key, onlyAnnotation ); 161 } 162 } 163 164 // If the data didn't contain 2 classes (i.e. positive and negative) we cannot go on 165 if( this.classMap.keySet().size() != 2 ) 166 { 167 throw new IllegalArgumentException( "Data did not contain exactly 2 classes. It had "+this.classMap.keySet().size()+". They were "+this.classMap ); 168 } 169 170 return true; 171 } 172 173 /** 174 * {@inheritDoc} 175 * @see org.openimaj.ml.annotation.Annotator#getAnnotations() 176 */ 177 @Override 178 public Set<ANNOTATION> getAnnotations() 179 { 180 final HashSet<ANNOTATION> hs = new HashSet<ANNOTATION>(); 181 hs.addAll( this.classMap.values() ); 182 return hs; 183 } 184 185 /** 186 * {@inheritDoc} 187 * @see org.openimaj.ml.annotation.Annotator#annotate(java.lang.Object) 188 */ 189 @Override 190 public List<ScoredAnnotation<ANNOTATION>> annotate( final OBJECT object ) 191 { 192 // Extract the feature and convert to a svm_node[] 193 final svm_node[] nodes = SVMAnnotator.featureToNode( this.extractor.extractFeature( object ) ); 194 195 // Use the trained SVM model to predict the new buffer's annotation 196 final double x = svm.svm_predict( this.model, nodes ); 197 198 // Create a singleton list to contain the classified annotation. 199 return Collections.singletonList( new ScoredAnnotation<ANNOTATION>( x > 0 ? 200 this.classMap.get(SVMAnnotator.POSITIVE_CLASS) : this.classMap.get(SVMAnnotator.NEGATIVE_CLASS), 201 1.0f ) ); 202 } 203 204 /** 205 * Set whether to save the SVM model to disk. 206 * @param saveModel The file name to save to, or null to disable saving. 207 */ 208 public void setSaveModel( final File saveModel ) 209 { 210 this.saveModel = saveModel; 211 } 212 213 /** 214 * Load an existing svm model. 215 * 216 * @param loadModel The model to load from 217 * @throws IOException If the loading does not complete 218 */ 219 public void loadModel( final File loadModel ) throws IOException 220 { 221 this.model = svm.svm_load_model( loadModel.getAbsolutePath() ); 222 } 223 224 /** 225 * Performs cross-validation on the SVM. 226 * 227 * @param data The data 228 * @param numFold The number of folds 229 * @return The calculated accuracy 230 */ 231 public double crossValidation( final List<? extends Annotated<OBJECT, ANNOTATION>> data, 232 final int numFold ) 233 { 234 // Setup the SVM problem 235 final svm_parameter param = SVMAnnotator.getDefaultSVMParameters(); 236 final svm_problem prob = this.getSVMProblem( data, param, this.extractor ); 237 238 return SVMAnnotator.crossValidation( prob, param, numFold ); 239 } 240 241 /** 242 * Performs cross-validation on the SVM. 243 * 244 * @param prob The problem 245 * @param param The parameters 246 * @param numFold The number of folds 247 * @return The accuracy 248 */ 249 static public double crossValidation( final svm_problem prob, 250 final svm_parameter param, final int numFold ) 251 { 252 // The target array in which the final classifications are put 253 final double[] target = new double[prob.l]; 254 255 // Perform the cross-validation. 256 svm.svm_cross_validation( prob, param, numFold, target ); 257 258 // Work out how many classifications were correct. 259 int totalCorrect = 0; 260 for( int i = 0; i < prob.l; i++ ) 261 if( target[i] == prob.y[i] ) 262 totalCorrect++; 263 264 // Calculate the accuracy 265 final double accuracy = 100.0 * totalCorrect / prob.l; 266 System.out.print("Cross Validation Accuracy = "+accuracy+"%\n"); 267 268 return accuracy; 269 } 270 271 // ========================================================================================= // 272 // Static methods below here. 273 // ========================================================================================= // 274 275 /** 276 * Returns the default set of SVM parameters. 277 * @return The default set of SVM parameters 278 */ 279 static private svm_parameter getDefaultSVMParameters() 280 { 281 // These default values came from: 282 // https://github.com/arnaudsj/libsvm/blob/master/java/svm_train.java 283 284 final svm_parameter param = new svm_parameter(); 285 param.svm_type = svm_parameter.C_SVC; 286 param.kernel_type = svm_parameter.RBF; 287 param.degree = 3; 288 param.gamma = 0; // 1/num_features 289 param.coef0 = 0; 290 param.nu = 0.5; 291 param.cache_size = 100; 292 param.C = 1; 293 param.eps = 1e-3; 294 param.p = 0.1; 295 param.shrinking = 1; 296 param.probability = 0; 297 param.nr_weight = 0; 298 param.weight_label = new int[0]; 299 param.weight = new double[0]; 300 301 return param; 302 } 303 304 /** 305 * Returns an svm_problem for the given dataset. This function will return 306 * a new SVM problem and will also side-affect the gamma member of the 307 * param argument. 308 * 309 * @param trainingCorpus The corpus 310 * @param param The SVM parameters 311 * @param featureExtractor The feature extractor to use 312 * @param positiveClass The name of the positive class in the dataset 313 * @param negativeClasses The names of the negative classes in the dataset 314 * @return A new SVM problem. 315 */ 316 private svm_problem getSVMProblem( final List<? extends Annotated<OBJECT, ANNOTATION>> data, 317 final svm_parameter param, final FeatureExtractor<? extends FeatureVector, OBJECT> extractor ) 318 { 319 // Get all the nodes for the features 320 final svm_node[][] positiveNodes = this.computeFeature( 321 data, this.classMap.get( SVMAnnotator.POSITIVE_CLASS ) ); 322 final svm_node[][] negativeNodes = this.computeFeature( 323 data, this.classMap.get( SVMAnnotator.NEGATIVE_CLASS ) ); 324 325 // Work out how long the problem is 326 final int nSamples = positiveNodes.length + negativeNodes.length; 327 328 // The array that determines whether a sample is positive or negative. 329 final double[] flagArray = new double[nSamples]; 330 ArrayUtils.fill( flagArray, SVMAnnotator.POSITIVE_CLASS, 0, positiveNodes.length ); 331 ArrayUtils.fill( flagArray, SVMAnnotator.NEGATIVE_CLASS, positiveNodes.length, negativeNodes.length ); 332 333 // Concatenate the samples to a single array 334 final svm_node[][] sampleArray = ArrayUtils.concatenate( 335 positiveNodes, negativeNodes ); 336 337 // Create the svm problem to solve 338 final svm_problem prob = new svm_problem(); 339 340 // Setup the problem 341 prob.l = nSamples; 342 prob.x = sampleArray; 343 prob.y = flagArray; 344 param.gamma = 1.0 / SVMAnnotator.getMaxIndex( sampleArray ); 345 346 return prob; 347 } 348 349 /** 350 * Computes all the features for a given annotation in the data set and returns 351 * a set of SVM nodes that represent those features. 352 * 353 * @param data The data 354 * @param annotation The annotation of the objects to pick out 355 * @return 356 */ 357 private svm_node[][] computeFeature( 358 final List<? extends Annotated<OBJECT, ANNOTATION>> data, 359 final ANNOTATION annotation ) 360 { 361 // Extract the features for the given annotation 362 final AnnotatedListHelper<OBJECT,ANNOTATION> alh = new AnnotatedListHelper<OBJECT,ANNOTATION>(data); 363 final List<? extends FeatureVector> f = alh.extractFeatures( annotation, this.extractor ); 364 365 // Create the output value - a 2D array of svm_nodes. 366 final svm_node[][] n = new svm_node[ f.size() ][]; 367 368 // Loop over each feature and convert it to an svm_node[] 369 int i = 0; 370 for( final FeatureVector feature : f ) 371 n[i++] = SVMAnnotator.featureToNode( feature ); 372 373 return n; 374 } 375 376 /** 377 * Returns the maximum index value from all the svm_nodes in the array. 378 * @param sampleArray The array of training samples 379 * @return The max feature index 380 */ 381 static private int getMaxIndex( final svm_node[][] sampleArray ) 382 { 383 int max = 0; 384 for( final svm_node[] x : sampleArray ) 385 for( int j = 0; j < x.length; j++ ) 386 max = Math.max( max, x[j].index ); 387 return max; 388 } 389 390 /** 391 * Takes a {@link FeatureVector} and converts it into an array of {@link svm_node}s 392 * for the svm library. 393 * 394 * @param featureVector The feature vector to convert 395 * @return The equivalent svm_node[] 396 */ 397 static private svm_node[] featureToNode( final FeatureVector featureVector ) 398 { 399// if( featureVector instanceof SparseFeatureVector ) 400// { 401// 402// } 403// else 404 { 405 final double[] fv = featureVector.asDoubleVector(); 406 final svm_node[] nodes = new svm_node[fv.length]; 407 408 for( int i = 0; i < fv.length; i++ ) 409 { 410 nodes[i] = new svm_node(); 411 nodes[i].index = i; 412 nodes[i].value = fv[i]; 413 } 414 415 return nodes; 416 } 417 } 418}