/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   *  Redistributions of source code must retain the above copyright notice,
 *      this list of conditions and the following disclaimer.
 *
 *   *  Redistributions in binary form must reproduce the above copyright notice,
 *      this list of conditions and the following disclaimer in the documentation
 *      and/or other materials provided with the distribution.
 *
 *   *  Neither the name of the University of Southampton nor the names of its
 *      contributors may be used to endorse or promote products derived from this
 *      software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.annotation.linear;

import gov.sandia.cognition.learning.algorithm.svm.PrimalEstimatedSubGradient;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.LinearBinaryCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.openimaj.feature.FeatureExtractor;
import org.openimaj.feature.FeatureVector;
import org.openimaj.ml.annotation.Annotated;
import org.openimaj.ml.annotation.Annotator;
import org.openimaj.ml.annotation.BatchAnnotator;
import org.openimaj.ml.annotation.ScoredAnnotation;
import org.openimaj.ml.annotation.utils.AnnotatedListHelper;
/**
 * An {@link Annotator} based on a set of linear SVMs, one trained per
 * annotation in a one-versus-all manner.
 * <p>
 * The individual SVMs are learned using the PEGASOS algorithm, as implemented
 * by the {@link PrimalEstimatedSubGradient} class.
 * <p>
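 * A minimal usage sketch is shown below; the object type, extractor and
 * training list ({@code MyObject}, {@code extractor}, {@code trainingData})
 * are hypothetical placeholders, but the {@code train} and {@code annotate}
 * calls follow this class's API:
 *
 * <pre>
 * {@code
 * FeatureExtractor<? extends FeatureVector, MyObject> extractor = ...;
 * List<Annotated<MyObject, String>> trainingData = ...;
 *
 * LinearSVMAnnotator<MyObject, String> annotator =
 *         new LinearSVMAnnotator<MyObject, String>(extractor);
 * annotator.train(trainingData);
 *
 * List<ScoredAnnotation<String>> scored = annotator.annotate(someObject);
 * }
 * </pre>
 *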
 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
 *
 * @param <OBJECT>
 *            Type of object being annotated
 * @param <ANNOTATION>
 *            Type of annotation
 */
public class LinearSVMAnnotator<OBJECT, ANNOTATION>
		extends BatchAnnotator<OBJECT, ANNOTATION>
{
	/** One binary linear SVM per annotation */
	private final Map<ANNOTATION, LinearBinaryCategorizer> classifiers = new HashMap<ANNOTATION, LinearBinaryCategorizer>();

	/** The annotations seen during training */
	private Set<ANNOTATION> annotations;

	/** The annotation (possibly null) marking negative examples; never predicted */
	private final ANNOTATION negativeClass;

	/** The feature extractor applied to objects for training and classification */
	private final FeatureExtractor<? extends FeatureVector, OBJECT> extractor;

	/**
	 * Construct a new {@link LinearSVMAnnotator} with the given extractor and
	 * the specified negative class. The negative class is excluded from the
	 * predicted annotations.
	 *
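	 * For example (a sketch in which {@code "negative"} simply stands in for
	 * whatever annotation is used to label the negative training examples):
	 *
	 * <pre>
	 * {@code
	 * LinearSVMAnnotator<MyObject, String> annotator =
	 *         new LinearSVMAnnotator<MyObject, String>(extractor, "negative");
	 * }
	 * </pre>
	 *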
	 * @param extractor
	 *            the extractor
	 * @param negativeClass
	 *            the negative class to exclude from predictions
	 */
	public LinearSVMAnnotator(FeatureExtractor<? extends FeatureVector, OBJECT> extractor, ANNOTATION negativeClass) {
		this.extractor = extractor;
		this.negativeClass = negativeClass;
	}

	/**
	 * Construct a new {@link LinearSVMAnnotator} with the given extractor and
	 * no negative class.
	 *
	 * @param extractor
	 *            the extractor
	 */
	public LinearSVMAnnotator(FeatureExtractor<? extends FeatureVector, OBJECT> extractor) {
		this(extractor, null);
	}

	@Override
	public void train(List<? extends Annotated<OBJECT, ANNOTATION>> data) {
		final AnnotatedListHelper<OBJECT, ANNOTATION> helper = new AnnotatedListHelper<OBJECT, ANNOTATION>(data);

		annotations = helper.getAnnotations();

		// Train one binary linear SVM per annotation in a one-versus-all
		// fashion: items carrying the annotation are the positive examples and
		// all remaining items are the negative examples.
		for (final ANNOTATION annotation : annotations) {
			final PrimalEstimatedSubGradient pegasos = new PrimalEstimatedSubGradient();

			final List<? extends FeatureVector> positive = helper.extractFeatures(annotation, extractor);
			final List<? extends FeatureVector> negative = helper.extractFeaturesExclude(annotation, extractor);

			pegasos.learn(convert(positive, negative));
			classifiers.put(annotation, pegasos.getResult());
		}
	}

	/**
	 * Convert the positive and negative feature lists into the labelled
	 * input-output pairs expected by the PEGASOS learner.
	 */
	private Collection<? extends InputOutputPair<? extends Vectorizable, Boolean>> convert(
			List<? extends FeatureVector> positive, List<? extends FeatureVector> negative)
	{
		final Collection<InputOutputPair<Vectorizable, Boolean>> data =
				new ArrayList<InputOutputPair<Vectorizable, Boolean>>(positive.size() + negative.size());

		for (final FeatureVector p : positive) {
			data.add(new DefaultInputOutputPair<Vectorizable, Boolean>(convert(p), true));
		}
		for (final FeatureVector n : negative) {
			data.add(new DefaultInputOutputPair<Vectorizable, Boolean>(convert(n), false));
		}

		return data;
	}

	@Override
	public Set<ANNOTATION> getAnnotations() {
		return annotations;
	}

	@Override
	public List<ScoredAnnotation<ANNOTATION>> annotate(OBJECT object) {
		final List<ScoredAnnotation<ANNOTATION>> results = new ArrayList<ScoredAnnotation<ANNOTATION>>();

		// The feature depends only on the object, so extract and convert it
		// once rather than once per annotation.
		final FeatureVector feature = extractor.extractFeature(object);
		final Vector vector = convert(feature);

		for (final ANNOTATION annotation : annotations) {
			// never predict the negative class
			if (annotation.equals(negativeClass))
				continue;

			// a positive margin means the SVM for this annotation accepts the
			// object; the size of the margin is used as the confidence score
			final double result = classifiers.get(annotation).evaluateAsDouble(vector);

			if (result > 0) {
				results.add(new ScoredAnnotation<ANNOTATION>(annotation, (float) result));
			}
		}

		return results;
	}

	/**
	 * Convert an OpenIMAJ {@link FeatureVector} into a dense {@link Vector}
	 * for use with the linear classifiers.
	 */
	private Vector convert(FeatureVector feature) {
		return VectorFactory.getDenseDefault().copyArray(feature.asDoubleVector());
	}
}