/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.annotation.linear;

import gov.sandia.cognition.learning.algorithm.svm.PrimalEstimatedSubGradient;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.LinearBinaryCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.openimaj.feature.FeatureExtractor;
import org.openimaj.feature.FeatureVector;
import org.openimaj.ml.annotation.Annotated;
import org.openimaj.ml.annotation.Annotator;
import org.openimaj.ml.annotation.BatchAnnotator;
import org.openimaj.ml.annotation.ScoredAnnotation;
import org.openimaj.ml.annotation.utils.AnnotatedListHelper;

/**
 * An {@link Annotator} based on a set of linear SVMs (one per annotation).
 * <p>
 * The SVMs use the PEGASOS algorithm implemented by the
 * {@link PrimalEstimatedSubGradient} class.
 *
 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
 *
 * @param <OBJECT>
 *            Type of object being annotated
 * @param <ANNOTATION>
 *            Type of annotation
 */
public class LinearSVMAnnotator<OBJECT, ANNOTATION> extends BatchAnnotator<OBJECT, ANNOTATION> {
    private final Map<ANNOTATION, LinearBinaryCategorizer> classifiers = new HashMap<ANNOTATION, LinearBinaryCategorizer>();
    private Set<ANNOTATION> annotations;
    private ANNOTATION negativeClass;
    private FeatureExtractor<? extends FeatureVector, OBJECT> extractor;

    /**
     * Construct a new {@link LinearSVMAnnotator} with the given extractor and
     * the specified negative class. The negative class is excluded from the
     * predicted annotations.
     *
     * @param extractor
     *            the extractor
     * @param negativeClass
     *            the negative class to exclude from predictions
     */
    public LinearSVMAnnotator(FeatureExtractor<? extends FeatureVector, OBJECT> extractor, ANNOTATION negativeClass) {
        this.extractor = extractor;
        this.negativeClass = negativeClass;
    }

    /**
     * Construct a new {@link LinearSVMAnnotator} with the given extractor.
     *
     * @param extractor
     *            the extractor
     */
    public LinearSVMAnnotator(FeatureExtractor<? extends FeatureVector, OBJECT> extractor) {
        this(extractor, null);
    }

    @Override
    public void train(List<? extends Annotated<OBJECT, ANNOTATION>> data) {
        final AnnotatedListHelper<OBJECT, ANNOTATION> helper = new AnnotatedListHelper<OBJECT, ANNOTATION>(data);

        annotations = helper.getAnnotations();

        for (final ANNOTATION annotation : annotations) {
            final PrimalEstimatedSubGradient pegasos = new PrimalEstimatedSubGradient();

            final List<? extends FeatureVector> positive = helper.extractFeatures(annotation,
                    (FeatureExtractor<? extends FeatureVector, OBJECT>) extractor);
            final List<? extends FeatureVector> negative = helper.extractFeaturesExclude(annotation,
                    (FeatureExtractor<? extends FeatureVector, OBJECT>) extractor);

            pegasos.learn(convert(positive, negative));
            classifiers.put(annotation, pegasos.getResult());
        }
    }

    private Collection<? extends InputOutputPair<? extends Vectorizable, Boolean>>
            convert(List<? extends FeatureVector> positive, List<? extends FeatureVector> negative)
    {
        final Collection<InputOutputPair<Vectorizable, Boolean>> data =
                new ArrayList<InputOutputPair<Vectorizable, Boolean>>(positive.size() + negative.size());

        for (final FeatureVector p : positive) {
            data.add(new DefaultInputOutputPair<Vectorizable, Boolean>(convert(p), true));
        }

        for (final FeatureVector n : negative) {
            data.add(new DefaultInputOutputPair<Vectorizable, Boolean>(convert(n), false));
        }

        return data;
    }

    @Override
    public Set<ANNOTATION> getAnnotations() {
        return annotations;
    }

    @Override
    public List<ScoredAnnotation<ANNOTATION>> annotate(OBJECT object) {
        final List<ScoredAnnotation<ANNOTATION>> results = new ArrayList<ScoredAnnotation<ANNOTATION>>();

        for (final ANNOTATION annotation : annotations) {
            // skip the negative class
            if (annotation.equals(negativeClass))
                continue;

            final FeatureVector feature = extractor.extractFeature(object);
            final Vector vector = convert(feature);

            final double result = classifiers.get(annotation).evaluateAsDouble(vector);

            if (result > 0) {
                results.add(new ScoredAnnotation<ANNOTATION>(annotation, (float) Math.abs(result)));
            }
        }

        return results;
    }

    private Vector convert(FeatureVector feature) {
        return VectorFactory.getDenseDefault().copyArray(feature.asDoubleVector());
    }
}
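
/*
 * Minimal usage sketch (not part of the original source). It assumes a
 * hypothetical FeatureExtractor implementation (MyFeatureExtractor), object
 * type (MyImage) and pre-built training list (trainingData); those names are
 * placeholders for illustration only.
 *
 *     // one linear SVM is trained per annotation present in the training data;
 *     // the "negative" class is excluded from the predicted annotations
 *     final LinearSVMAnnotator<MyImage, String> annotator =
 *             new LinearSVMAnnotator<MyImage, String>(new MyFeatureExtractor(), "negative");
 *     annotator.train(trainingData); // List<? extends Annotated<MyImage, String>>
 *
 *     // each ScoredAnnotation corresponds to a per-annotation SVM that scored > 0
 *     for (final ScoredAnnotation<String> ann : annotator.annotate(queryImage)) {
 *         System.out.println(ann);
 *     }
 */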