001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.experiment.evaluation.classification; 031 032import gnu.trove.map.hash.TObjectDoubleHashMap; 033import gnu.trove.procedure.TObjectDoubleProcedure; 034 035import java.util.ArrayList; 036import java.util.Collections; 037import java.util.Comparator; 038import java.util.LinkedHashSet; 039import java.util.List; 040import java.util.Set; 041 042import org.openimaj.util.pair.ObjectDoublePair; 043 044/** 045 * A basic implementation of a {@link ClassificationResult} that internally 046 * maintains a map of classes to confidences. 047 * <p> 048 * A threshold is used to determine whether a class has a high-enough confidence 049 * to be considered part of the predicted set of classes. 050 * 051 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 052 * 053 * @param <CLASS> 054 * type of class predicted by the {@link Classifier} 055 */ 056public class BasicClassificationResult<CLASS> implements ClassificationResult<CLASS> { 057 private final TObjectDoubleHashMap<CLASS> data = new TObjectDoubleHashMap<CLASS>(); 058 private double threshold = 0; 059 060 /** 061 * Construct with a default threshold of 0. 062 */ 063 public BasicClassificationResult() { 064 065 } 066 067 /** 068 * Construct with the given threshold. 069 * 070 * @param threshold 071 * the threshold 072 */ 073 public BasicClassificationResult(double threshold) { 074 this.threshold = threshold; 075 } 076 077 /** 078 * Add a class/confidence pair. 079 * 080 * @param clz 081 * the class 082 * @param confidence 083 * the confidence 084 */ 085 public void put(CLASS clz, double confidence) { 086 data.put(clz, confidence); 087 } 088 089 @Override 090 public double getConfidence(CLASS clazz) { 091 return data.get(clazz); 092 } 093 094 @Override 095 public Set<CLASS> getPredictedClasses() { 096 // predicted classes are sorted by decreasing confidence 097 098 final List<ObjectDoublePair<CLASS>> toSort = new ArrayList<ObjectDoublePair<CLASS>>(); 099 100 data.forEachEntry(new TObjectDoubleProcedure<CLASS>() { 101 @Override 102 public boolean execute(CLASS key, double confidence) { 103 if (confidence > threshold) 104 toSort.add(new ObjectDoublePair<CLASS>(key, confidence)); 105 return true; 106 } 107 }); 108 109 Collections.sort(toSort, new Comparator<ObjectDoublePair<CLASS>>() { 110 @Override 111 public int compare(ObjectDoublePair<CLASS> o1, ObjectDoublePair<CLASS> o2) { 112 return -1 * Double.compare(o1.second, o2.second); 113 } 114 }); 115 116 final Set<CLASS> keys = new LinkedHashSet<CLASS>(toSort.size()); 117 118 for (final ObjectDoublePair<CLASS> p : toSort) 119 keys.add(p.first); 120 121 return keys; 122 } 123}