001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.annotation.basic;
031
032import gnu.trove.list.array.TDoubleArrayList;
033import gnu.trove.map.hash.TIntIntHashMap;
034import gnu.trove.map.hash.TObjectIntHashMap;
035import gnu.trove.procedure.TObjectIntProcedure;
036
037import java.util.ArrayList;
038import java.util.Collection;
039import java.util.HashSet;
040import java.util.List;
041import java.util.Set;
042
043import org.openimaj.ml.annotation.Annotated;
044import org.openimaj.ml.annotation.BatchAnnotator;
045import org.openimaj.ml.annotation.ScoredAnnotation;
046import org.openimaj.ml.annotation.basic.util.NumAnnotationsChooser;
047
048import cern.jet.random.Empirical;
049import cern.jet.random.EmpiricalWalker;
050import cern.jet.random.engine.MersenneTwister;
051
052/**
053 * Annotator that randomly assigns annotations, but takes account of the prior
054 * probability of each annotation based on the proportion of times it occurred
055 * in training. Annotations that occurred less in training are less likely to be
056 * picked. The number of annotations produced is set by the type of
057 * {@link NumAnnotationsChooser} used.
058 * 
059 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
060 * 
061 * @param <OBJECT>
062 *            Type of object being annotated
063 * @param <ANNOTATION>
064 *            Type of annotation.
065 */
066public class IndependentPriorRandomAnnotator<OBJECT, ANNOTATION> extends BatchAnnotator<OBJECT, ANNOTATION> {
067        protected List<ANNOTATION> annotations;
068        protected NumAnnotationsChooser numAnnotations;
069        protected EmpiricalWalker annotationProbability;
070
071        /**
072         * Construct with the given {@link NumAnnotationsChooser} to determine how
073         * many annotations are produced by calls to {@link #annotate(Object)}.
074         * 
075         * @param chooser
076         *            the {@link NumAnnotationsChooser} to use.
077         */
078        public IndependentPriorRandomAnnotator(NumAnnotationsChooser chooser) {
079                this.numAnnotations = chooser;
080        }
081
082        @Override
083        public void train(List<? extends Annotated<OBJECT, ANNOTATION>> data) {
084                final TIntIntHashMap nAnnotationCounts = new TIntIntHashMap();
085                final TObjectIntHashMap<ANNOTATION> annotationCounts = new TObjectIntHashMap<ANNOTATION>();
086                int maxVal = 0;
087
088                for (final Annotated<OBJECT, ANNOTATION> sample : data) {
089                        final Collection<ANNOTATION> annos = sample.getAnnotations();
090
091                        for (final ANNOTATION s : annos) {
092                                annotationCounts.adjustOrPutValue(s, 1, 1);
093                        }
094
095                        nAnnotationCounts.adjustOrPutValue(annos.size(), 1, 1);
096
097                        if (annos.size() > maxVal)
098                                maxVal = annos.size();
099                }
100
101                // build distribution and rng for each annotation
102                annotations = new ArrayList<ANNOTATION>();
103                final TDoubleArrayList probs = new TDoubleArrayList();
104                annotationCounts.forEachEntry(new TObjectIntProcedure<ANNOTATION>() {
105                        @Override
106                        public boolean execute(ANNOTATION a, int b) {
107                                annotations.add(a);
108                                probs.add(b);
109                                return true;
110                        }
111                });
112                annotationProbability = new EmpiricalWalker(probs.toArray(), Empirical.NO_INTERPOLATION, new MersenneTwister());
113
114                numAnnotations.train(data);
115        }
116
117        @Override
118        public List<ScoredAnnotation<ANNOTATION>> annotate(OBJECT image) {
119                final int nAnnotations = numAnnotations.numAnnotations();
120
121                final List<ScoredAnnotation<ANNOTATION>> annos = new ArrayList<ScoredAnnotation<ANNOTATION>>();
122
123                for (int i = 0; i < nAnnotations; i++) {
124                        final int annotationIdx = annotationProbability.nextInt();
125                        annos.add(new ScoredAnnotation<ANNOTATION>(annotations.get(annotationIdx), (float) annotationProbability
126                                        .pdf(annotationIdx + 1)));
127                }
128
129                return annos;
130        }
131
132        @Override
133        public Set<ANNOTATION> getAnnotations() {
134                return new HashSet<ANNOTATION>(annotations);
135        }
136}