001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030/**
031 *
032 */
033package org.openimaj.ml.annotation.svm;
034
035import java.io.File;
036import java.io.IOException;
037import java.util.Collection;
038import java.util.Collections;
039import java.util.HashMap;
040import java.util.HashSet;
041import java.util.List;
042import java.util.Set;
043
044import libsvm.svm;
045import libsvm.svm_model;
046import libsvm.svm_node;
047import libsvm.svm_parameter;
048import libsvm.svm_problem;
049
050import org.openimaj.feature.FeatureExtractor;
051import org.openimaj.feature.FeatureVector;
052import org.openimaj.ml.annotation.Annotated;
053import org.openimaj.ml.annotation.BatchAnnotator;
054import org.openimaj.ml.annotation.ScoredAnnotation;
055import org.openimaj.ml.annotation.utils.AnnotatedListHelper;
056import org.openimaj.util.array.ArrayUtils;
057
058/**
059 *      Wraps the libsvm SVM and provides basic positive/negative
060 *      annotation for a single class.
061 *
062 *      @author David Dupplaw (dpd@ecs.soton.ac.uk)
063 *  @created 14 May 2013
064 *
065 *      @param <OBJECT> The object being annotated
066 *      @param <ANNOTATION> The type of the annotation
067 */
068public class SVMAnnotator<OBJECT,ANNOTATION> extends BatchAnnotator<OBJECT,ANNOTATION>
069{
070        /** The input to the SVM model for positive classes */
071        public static final int POSITIVE_CLASS = +1;
072
073        /** The input to the SVM model for negative classes */
074        public static final int NEGATIVE_CLASS = -1;
075
076        /** Stores the mapping between the positive and negative class and the annotation */
077        public HashMap<Integer,ANNOTATION> classMap = new HashMap<Integer,ANNOTATION>();
078
079        /** The libsvm SVM model */
080        private svm_model model = null;
081
082        /** The feature extractor being used to extract features from OBJECT */
083        private FeatureExtractor<? extends FeatureVector, OBJECT> extractor = null;
084
085        /** The file to save the model to after it is trained (or null for no saving) */
086        private File saveModel = null;
087
088        /**
089         *      Constructor that takes the feature extractor to use.
090         *      @param extractor The feature extractor
091         */
092        public SVMAnnotator( final FeatureExtractor<? extends FeatureVector, OBJECT> extractor )
093        {
094                this.extractor = extractor;
095        }
096
097        /**
098         *      {@inheritDoc}
099         *      @see org.openimaj.ml.training.BatchTrainer#train(java.util.List)
100         */
101        @Override
102        public void train( final List<? extends Annotated<OBJECT, ANNOTATION>> data )
103        {
104                // Check the data has 2 classes and update the class map.
105                if( this.checkInputDataOK( data ) )
106                {
107                        // Setup the SVM problem
108                        final svm_parameter param = SVMAnnotator.getDefaultSVMParameters();
109                        final svm_problem prob = this.getSVMProblem( data, param, this.extractor );
110
111                        // Train the SVM
112                        this.model = libsvm.svm.svm_train( prob, param );
113
114                        // Save the model if we're going to do that.
115                        if( this.saveModel != null ) try
116                        {
117                                svm.svm_save_model( this.saveModel.getAbsolutePath(), this.model );
118                        }
119                        catch( final IOException e )
120                        {
121                                e.printStackTrace();
122                        }
123                }
124        }
125
126        /**
127         *      Checks that the input data only has 2 classes. The method also assigns
128         *      those classes to positive or negative values.
129         *
130         *      @param data The data
131         *      @return TRUE if and only if there are 2 classes
132         */
133        private boolean checkInputDataOK( final List<? extends Annotated<OBJECT, ANNOTATION>> data )
134        {
135                // Clear the class map. We don't want any old annotations left in there.
136                this.classMap.clear();
137
138                // Loop over the data and check for valid data.
139                int i = 0;
140                for( final Annotated<OBJECT,ANNOTATION> x : data )
141                {
142                        // Get the annotations for the object.
143                        final Collection<ANNOTATION> anns = x.getAnnotations();
144
145                        // Check there is only one annotation on each object.
146                        if( anns.size() != 1 )
147                                throw new IllegalArgumentException( "Data contained an object with more than one annotation" );
148
149                        // Get the only annotation.
150                        final ANNOTATION onlyAnnotation = anns.iterator().next();
151
152                        // Check if it's already been seen.
153                        if( !this.classMap.values().contains( onlyAnnotation ) )
154                        {
155                                // Key will be -1, +1, +3...
156                                final int key = i * 2 -1;
157                                i++;
158
159                                // Put the first annotation into the map at the appropriate place.
160                                this.classMap.put( key, onlyAnnotation );
161                        }
162                }
163
164                // If the data didn't contain 2 classes (i.e. positive and negative) we cannot go on
165                if( this.classMap.keySet().size() != 2 )
166                {
167                        throw new IllegalArgumentException( "Data did not contain exactly 2 classes. It had "+this.classMap.keySet().size()+". They were "+this.classMap );
168                }
169
170                return true;
171        }
172
173        /**
174         *      {@inheritDoc}
175         *      @see org.openimaj.ml.annotation.Annotator#getAnnotations()
176         */
177        @Override
178        public Set<ANNOTATION> getAnnotations()
179        {
180                final HashSet<ANNOTATION> hs = new HashSet<ANNOTATION>();
181                hs.addAll( this.classMap.values() );
182                return hs;
183        }
184
185        /**
186         *      {@inheritDoc}
187         *      @see org.openimaj.ml.annotation.Annotator#annotate(java.lang.Object)
188         */
189        @Override
190        public List<ScoredAnnotation<ANNOTATION>> annotate( final OBJECT object )
191        {
192                // Extract the feature and convert to a svm_node[]
193                final svm_node[] nodes = SVMAnnotator.featureToNode( this.extractor.extractFeature( object ) );
194
195                // Use the trained SVM model to predict the new buffer's annotation
196                final double x = svm.svm_predict( this.model, nodes );
197
198                // Create a singleton list to contain the classified annotation.
199                return Collections.singletonList( new ScoredAnnotation<ANNOTATION>( x > 0 ?
200                                this.classMap.get(SVMAnnotator.POSITIVE_CLASS) : this.classMap.get(SVMAnnotator.NEGATIVE_CLASS),
201                                1.0f ) );
202        }
203
204        /**
205         *      Set whether to save the SVM model to disk.
206         *      @param saveModel The file name to save to, or null to disable saving.
207         */
208        public void setSaveModel( final File saveModel )
209        {
210                this.saveModel = saveModel;
211        }
212
213        /**
214         *      Load an existing svm model.
215         *
216         *      @param loadModel The model to load from
217         *      @throws IOException If the loading does not complete
218         */
219        public void loadModel( final File loadModel ) throws IOException
220        {
221                this.model = svm.svm_load_model( loadModel.getAbsolutePath() );
222        }
223
224        /**
225         *      Performs cross-validation on the SVM.
226         *
227         *      @param data The data
228         *      @param numFold The number of folds
229         *      @return The calculated accuracy
230         */
231        public double crossValidation( final List<? extends Annotated<OBJECT, ANNOTATION>> data,
232                        final int numFold )
233        {
234                // Setup the SVM problem
235                final svm_parameter param = SVMAnnotator.getDefaultSVMParameters();
236                final svm_problem prob = this.getSVMProblem( data, param, this.extractor );
237
238                return SVMAnnotator.crossValidation( prob, param, numFold );
239        }
240
241        /**
242         *      Performs cross-validation on the SVM.
243         *
244         *      @param prob The problem
245         *      @param param The parameters
246         *      @param numFold The number of folds
247         *      @return The accuracy
248         */
249        static public double crossValidation( final svm_problem prob,
250                        final svm_parameter param, final int numFold )
251        {
252                // The target array in which the final classifications are put
253                final double[] target = new double[prob.l];
254
255                // Perform the cross-validation.
256                svm.svm_cross_validation( prob, param, numFold, target );
257
258                // Work out how many classifications were correct.
259                int totalCorrect = 0;
260                for( int i = 0; i < prob.l; i++ )
261                        if( target[i] == prob.y[i] )
262                                totalCorrect++;
263
264                // Calculate the accuracy
265                final double accuracy = 100.0 * totalCorrect / prob.l;
266                System.out.print("Cross Validation Accuracy = "+accuracy+"%\n");
267
268                return accuracy;
269        }
270
271        // ========================================================================================= //
272        // Static methods below here.
273        // ========================================================================================= //
274
275        /**
276         *      Returns the default set of SVM parameters.
277         *      @return The default set of SVM parameters
278         */
279        static private svm_parameter getDefaultSVMParameters()
280        {
281                // These default values came from:
282                // https://github.com/arnaudsj/libsvm/blob/master/java/svm_train.java
283
284                final svm_parameter param = new svm_parameter();
285                param.svm_type = svm_parameter.C_SVC;
286                param.kernel_type = svm_parameter.RBF;
287                param.degree = 3;
288                param.gamma = 0;        // 1/num_features
289                param.coef0 = 0;
290                param.nu = 0.5;
291                param.cache_size = 100;
292                param.C = 1;
293                param.eps = 1e-3;
294                param.p = 0.1;
295                param.shrinking = 1;
296                param.probability = 0;
297                param.nr_weight = 0;
298                param.weight_label = new int[0];
299                param.weight = new double[0];
300
301                return param;
302        }
303
304        /**
305         *      Returns an svm_problem for the given dataset. This function will return
306         *      a new SVM problem and will also side-affect the gamma member of the
307         *      param argument.
308         *
309         *      @param trainingCorpus The corpus
310         *      @param param The SVM parameters
311         *      @param featureExtractor The feature extractor to use
312         *      @param positiveClass The name of the positive class in the dataset
313         *      @param negativeClasses The names of the negative classes in the dataset
314         *      @return A new SVM problem.
315         */
316        private svm_problem getSVMProblem( final List<? extends Annotated<OBJECT, ANNOTATION>> data,
317                final svm_parameter param, final FeatureExtractor<? extends FeatureVector, OBJECT> extractor  )
318        {
319                // Get all the nodes for the features
320                final svm_node[][] positiveNodes = this.computeFeature(
321                                data, this.classMap.get( SVMAnnotator.POSITIVE_CLASS ) );
322                final svm_node[][] negativeNodes = this.computeFeature(
323                                data, this.classMap.get( SVMAnnotator.NEGATIVE_CLASS ) );
324
325                // Work out how long the problem is
326                final int nSamples = positiveNodes.length + negativeNodes.length;
327
328                // The array that determines whether a sample is positive or negative.
329                final double[] flagArray = new double[nSamples];
330                ArrayUtils.fill( flagArray, SVMAnnotator.POSITIVE_CLASS, 0, positiveNodes.length );
331                ArrayUtils.fill( flagArray, SVMAnnotator.NEGATIVE_CLASS, positiveNodes.length, negativeNodes.length );
332
333                // Concatenate the samples to a single array
334                final svm_node[][] sampleArray = ArrayUtils.concatenate(
335                                positiveNodes, negativeNodes );
336
337                // Create the svm problem to solve
338                final svm_problem prob = new svm_problem();
339
340                // Setup the problem
341                prob.l = nSamples;
342                prob.x = sampleArray;
343                prob.y = flagArray;
344                param.gamma = 1.0 / SVMAnnotator.getMaxIndex( sampleArray );
345
346                return prob;
347        }
348
349        /**
350         *      Computes all the features for a given annotation in the data set and returns
351         *      a set of SVM nodes that represent those features.
352         *
353         *      @param data The data
354         *      @param annotation The annotation of the objects to pick out
355         *      @return
356         */
357        private svm_node[][] computeFeature(
358                        final List<? extends Annotated<OBJECT, ANNOTATION>> data,
359                        final ANNOTATION annotation )
360        {
361                // Extract the features for the given annotation
362                final AnnotatedListHelper<OBJECT,ANNOTATION> alh = new AnnotatedListHelper<OBJECT,ANNOTATION>(data);
363                final List<? extends FeatureVector> f = alh.extractFeatures( annotation, this.extractor );
364
365                // Create the output value - a 2D array of svm_nodes.
366                final svm_node[][] n = new svm_node[ f.size() ][];
367
368                // Loop over each feature and convert it to an svm_node[]
369                int i = 0;
370                for( final FeatureVector feature : f )
371                        n[i++] = SVMAnnotator.featureToNode( feature );
372
373                return n;
374        }
375
376        /**
377         *      Returns the maximum index value from all the svm_nodes in the array.
378         *      @param sampleArray The array of training samples
379         *      @return The max feature index
380         */
381        static private int getMaxIndex( final svm_node[][] sampleArray )
382        {
383                int max = 0;
384                for( final svm_node[] x : sampleArray )
385                        for( int j = 0; j < x.length; j++ )
386                                max = Math.max( max, x[j].index );
387                return max;
388        }
389
390        /**
391         *      Takes a {@link FeatureVector} and converts it into an array of {@link svm_node}s
392         *      for the svm library.
393         *
394         *      @param featureVector The feature vector to convert
395         *      @return The equivalent svm_node[]
396         */
397        static private svm_node[] featureToNode( final FeatureVector featureVector )
398        {
399//              if( featureVector instanceof SparseFeatureVector )
400//              {
401//
402//              }
403//              else
404                {
405                        final double[] fv = featureVector.asDoubleVector();
406                        final svm_node[] nodes = new svm_node[fv.length];
407
408                        for( int i = 0; i < fv.length; i++ )
409                        {
410                                nodes[i] = new svm_node();
411                                nodes[i].index = i;
412                                nodes[i].value = fv[i];
413                        }
414
415                        return nodes;
416                }
417        }
418}