001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.experiment.evaluation.classification.analysers.confusionmatrix;
031
032import gov.sandia.cognition.learning.data.DefaultTargetEstimatePair;
033import gov.sandia.cognition.learning.data.TargetEstimatePair;
034import gov.sandia.cognition.learning.performance.categorization.ConfusionMatrixPerformanceEvaluator;
035
036import java.util.ArrayList;
037import java.util.HashSet;
038import java.util.LinkedHashSet;
039import java.util.List;
040import java.util.Map;
041import java.util.Set;
042
043import org.openimaj.experiment.evaluation.classification.ClassificationAnalyser;
044import org.openimaj.experiment.evaluation.classification.ClassificationResult;
045import org.openimaj.experiment.evaluation.classification.Classifier;
046
047/**
048 * A {@link ClassificationAnalyser} that creates Confusion Matrices.
049 *
050 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
051 *
052 * @param <CLASS>
053 *            The type of classes produced by the {@link Classifier}
054 * @param <OBJECT>
055 *            The type of object classifed by the {@link Classifier}
056 */
057public class CMAnalyser<OBJECT, CLASS>
058implements ClassificationAnalyser<
059CMResult<CLASS>,
060CLASS,
061OBJECT>
062{
063        /**
064         * Strategies for building confusion matrices
065         *
066         * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
067         */
068        public static enum Strategy {
069                /**
070                 * Strategy to use when there is exactly one actual class and one
071                 * predicted class.
072                 */
073                SINGLE {
074                        @Override
075                        protected <CLASS> void add(
076                                        List<TargetEstimatePair<CLASS, CLASS>> data,
077                                        Set<CLASS> predicted, Set<CLASS> actual)
078                        {
079                                data.add(DefaultTargetEstimatePair.create(
080                                                actual.size() == 0 ? null : new ArrayList<CLASS>(actual).get(0),
081                                                                predicted.size() == 0 ? null : new ArrayList<CLASS>(predicted).get(0)
082                                                ));
083                        }
084                },
085                /**
086                 * Strategy for multiple possible actual classes and predicted classes.
087                 * Deals with:
088                 * <ol>
089                 * <li>true positives (a class present in both the predicted and actual
090                 * set</li>
091                 * <li>false positives (a predicted class not being in the actual set)</li>
092                 * <li>false negatives (an actual class not being in the predicted set)</li>
093                 * </ol>
094                 * False positives and negatives are dealt with by using
095                 * <code>null</code> values for the actual/predicted class respectively.
096                 */
097                MULTIPLE {
098                        @Override
099                        protected <CLASS> void add(
100                                        List<TargetEstimatePair<CLASS, CLASS>> data,
101                                        Set<CLASS> predicted, Set<CLASS> actual)
102                        {
103                                final HashSet<CLASS> allClasses = new HashSet<CLASS>();
104                                allClasses.addAll(predicted);
105                                allClasses.addAll(actual);
106
107                                for (final CLASS clz : allClasses) {
108                                        final CLASS target = actual.contains(clz) ? clz : null;
109                                        final CLASS estimate = predicted.contains(clz) ? clz : null;
110
111                                        data.add(DefaultTargetEstimatePair.create(target, estimate));
112                                }
113                        }
114                },
115                /**
116                 * Strategy for multiple possible actual classes and predicted classes
117                 * in the case the predictions and actual classes are ordered and there
118                 * is a one-to-one correspondence.
119                 * <p>
120                 * A {@link RuntimeException} will be thrown if the sets are not the
121                 * same size and both instances of {@link LinkedHashSet}.
122                 */
123                MULTIPLE_ORDERED {
124                        @SuppressWarnings("unchecked")
125                        @Override
126                        protected <CLASS> void add(
127                                        List<TargetEstimatePair<CLASS, CLASS>> data,
128                                        Set<CLASS> predicted, Set<CLASS> actual)
129                        {
130                                final LinkedHashSet<CLASS> op = (LinkedHashSet<CLASS>) predicted;
131                                final LinkedHashSet<CLASS> ap = (LinkedHashSet<CLASS>) actual;
132
133                                if (op.size() != ap.size())
134                                        throw new RuntimeException("Sets are not the same size!");
135
136                                final Object[] opa = op.toArray();
137                                final Object[] apa = ap.toArray();
138
139                                for (int i = 0; i < opa.length; i++)
140                                        data.add(new DefaultTargetEstimatePair<CLASS, CLASS>((CLASS) opa[i], (CLASS) apa[i]));
141                        }
142                };
143
144                protected abstract <CLASS> void add(List<TargetEstimatePair<CLASS, CLASS>> data, Set<CLASS> predicted,
145                                Set<CLASS> actual);
146        }
147
148        protected Strategy strategy;
149        ConfusionMatrixPerformanceEvaluator<?, CLASS> eval = new ConfusionMatrixPerformanceEvaluator<Object, CLASS>();
150
151        /**
152         * Construct with the given strategy for building the confusion matrix
153         *
154         * @param strategy
155         *            the strategy
156         */
157        public CMAnalyser(Strategy strategy) {
158                this.strategy = strategy;
159        }
160
161        @Override
162        public CMResult<CLASS> analyse(
163                        Map<OBJECT, ClassificationResult<CLASS>> predicted,
164                        Map<OBJECT, Set<CLASS>> actual)
165                        {
166                final List<TargetEstimatePair<CLASS, CLASS>> data = new ArrayList<TargetEstimatePair<CLASS, CLASS>>();
167
168                for (final OBJECT obj : predicted.keySet()) {
169                        final Set<CLASS> pclasses = predicted.get(obj).getPredictedClasses();
170                        final Set<CLASS> aclasses = actual.get(obj);
171
172                        strategy.add(data, pclasses, aclasses);
173                }
174
175                return new CMResult<CLASS>(eval.evaluatePerformance(data));
176                        }
177}