001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.annotation.linear;
031
032import java.util.ArrayList;
033import java.util.Collection;
034import java.util.Collections;
035import java.util.Comparator;
036import java.util.HashSet;
037import java.util.List;
038import java.util.Set;
039
040import org.openimaj.citation.annotation.Reference;
041import org.openimaj.citation.annotation.ReferenceType;
042import org.openimaj.feature.FeatureExtractor;
043import org.openimaj.feature.FeatureVector;
044import org.openimaj.math.matrix.PseudoInverse;
045import org.openimaj.ml.annotation.Annotated;
046import org.openimaj.ml.annotation.BatchAnnotator;
047import org.openimaj.ml.annotation.ScoredAnnotation;
048
049import Jama.Matrix;
050
051/**
052 * An annotator that determines a "transform" between feature vectors and
053 * vectors of annotation counts. The transform is estimated using a lossy pseudo
054 * inverse; the single parameter of the algorithm is the desired rank of the
055 * transform matrix.
056 * 
057 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
058 * 
059 * @param <OBJECT>
060 *            Type of object being annotated
061 * @param <ANNOTATION>
062 *            Type of annotation
063 */
064@Reference(
065                type = ReferenceType.Inproceedings,
066                author = { "Jonathan Hare", "Paul Lewis" },
067                title = "Semantic Retrieval and Automatic Annotation: Linear Transformations, Correlation and Semantic Spaces",
068                year = "2010",
069                booktitle = "Imaging and Printing in a Web 2.0 World; and Multimedia Content Access: Algorithms and Systems IV",
070                url = "http://eprints.soton.ac.uk/268496/",
071                note = " Event Dates: 17-21 Jan 2010",
072                month = "January",
073                publisher = "SPIE",
074                volume = "7540")
075public class DenseLinearTransformAnnotator<OBJECT, ANNOTATION>
076                extends
077                BatchAnnotator<OBJECT, ANNOTATION>
078{
079        protected List<ANNOTATION> terms;
080        protected Matrix transform;
081        protected int k = 10;
082        private FeatureExtractor<? extends FeatureVector, OBJECT> extractor;
083
084        /**
085         * Construct with the given number of dimensions and feature extractor.
086         * 
087         * @param k
088         *            the number of dimensions (rank of the pseudo-inverse)
089         * @param extractor
090         *            the feature extractor
091         */
092        public DenseLinearTransformAnnotator(int k, FeatureExtractor<? extends FeatureVector, OBJECT> extractor) {
093                this.extractor = extractor;
094                this.k = k;
095        }
096
097        @Override
098        public void train(List<? extends Annotated<OBJECT, ANNOTATION>> data) {
099                final Set<ANNOTATION> termsSet = new HashSet<ANNOTATION>();
100
101                for (final Annotated<OBJECT, ANNOTATION> d : data)
102                        termsSet.addAll(d.getAnnotations());
103                terms = new ArrayList<ANNOTATION>(termsSet);
104
105                final int termLen = terms.size();
106                final int trainingLen = data.size();
107
108                final Annotated<OBJECT, ANNOTATION> first = data.get(0);
109                final double[] fv = extractor.extractFeature(first.getObject()).asDoubleVector();
110
111                final int featureLen = fv.length;
112
113                final Matrix F = new Matrix(trainingLen, featureLen);
114                final Matrix W = new Matrix(trainingLen, termLen);
115
116                addRow(F, W, 0, fv, first.getAnnotations());
117                for (int i = 1; i < trainingLen; i++) {
118                        addRow(F, W, i, data.get(i));
119                }
120
121                final Matrix pinvF = PseudoInverse.pseudoInverse(F, k);
122                transform = pinvF.times(W);
123        }
124
125        private void addRow(Matrix F, Matrix W, int r, Annotated<OBJECT, ANNOTATION> data) {
126                final double[] fv = extractor.extractFeature(data.getObject()).asDoubleVector();
127
128                addRow(F, W, r, fv, data.getAnnotations());
129        }
130
131        private void addRow(Matrix F, Matrix W, int r, double[] fv, Collection<ANNOTATION> annotations) {
132                for (int j = 0; j < F.getColumnDimension(); j++)
133                        F.getArray()[r][j] = fv[j];
134
135                for (final ANNOTATION t : annotations) {
136                        W.getArray()[r][terms.indexOf(t)]++;
137                }
138        }
139
140        @Override
141        public List<ScoredAnnotation<ANNOTATION>> annotate(OBJECT image) {
142                final double[] fv = extractor.extractFeature(image).asDoubleVector();
143
144                final Matrix F = new Matrix(new double[][] { fv });
145
146                final Matrix res = F.times(transform);
147
148                final List<ScoredAnnotation<ANNOTATION>> ann = new ArrayList<ScoredAnnotation<ANNOTATION>>();
149                for (int i = 0; i < terms.size(); i++) {
150                        ann.add(new ScoredAnnotation<ANNOTATION>(terms.get(i), (float) res.get(0, i)));
151                }
152
153                Collections.sort(ann, new Comparator<ScoredAnnotation<ANNOTATION>>() {
154                        @Override
155                        public int compare(ScoredAnnotation<ANNOTATION> o1, ScoredAnnotation<ANNOTATION> o2) {
156                                return o1.confidence < o2.confidence ? 1 : -1;
157                        }
158                });
159
160                return ann;
161        }
162
163        @Override
164        public Set<ANNOTATION> getAnnotations() {
165                return new HashSet<ANNOTATION>(terms);
166        }
167}