/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.classification.boosting;

import java.util.ArrayList;
import java.util.List;

import org.openimaj.ml.classification.LabelledDataProvider;
import org.openimaj.ml.classification.StumpClassifier;
import org.openimaj.util.pair.ObjectFloatPair;

/**
 * Boosted ensemble learner built from {@link StumpClassifier}s.
 * <p>
 * {@link #learn(LabelledDataProvider, int)} trains one stump per round against
 * re-weighted training instances and returns each stump paired with its vote
 * weight (alpha). The static {@code classify} methods evaluate such an ensemble
 * by summing the signed, alpha-weighted votes and comparing against a
 * threshold.
 */
public class AdaBoost {
	/** Weak learner used to train one stump per boosting round. */
	StumpClassifier.WeightedLearner factory = new StumpClassifier.WeightedLearner();

	/**
	 * Train a boosted ensemble of stumps.
	 * <p>
	 * Training stops early when a round's weighted error reaches 0.5 (that
	 * stump is discarded) or when a stump classifies every training instance
	 * correctly (that stump is kept, with an alpha of positive infinity, since
	 * alpha is {@code 0.5 * log((1 - epsilon) / epsilon)} and epsilon is 0).
	 * Progress is printed to standard output once per round.
	 *
	 * @param trainingSet
	 *            the labelled training data
	 * @param numberOfRounds
	 *            the maximum number of boosting rounds; no rounds are run if
	 *            this is not positive
	 * @return the learned stumps, in training order, each paired with its
	 *         alpha; possibly empty
	 */
	public List<ObjectFloatPair<StumpClassifier>> learn(LabelledDataProvider trainingSet, int numberOfRounds) {
		// NOTE(review): assumes numInstances() is stable for the duration of
		// this call, so it is read once instead of on every loop test.
		final int numInstances = trainingSet.numInstances();

		// Initialise weights uniformly
		final float[] weights = new float[numInstances];
		for (int i = 0; i < numInstances; i++)
			weights[i] = 1.0f / numInstances;

		final boolean[] actualClasses = trainingSet.getClasses();

		final List<ObjectFloatPair<StumpClassifier>> ensemble = new ArrayList<ObjectFloatPair<StumpClassifier>>();

		// Perform the learning
		for (int t = 0; t < numberOfRounds; t++) {
			System.out.println("Iteration: " + t);

			// Create the weak learner and train it on the current weights
			final ObjectFloatPair<StumpClassifier> h = factory.learn(trainingSet, weights);
			final StumpClassifier stump = h.first;

			// Compute the classifications and the weighted training error
			final boolean[] hClassification = new boolean[numInstances];
			final float[] responses = trainingSet.getFeatureResponse(stump.dimension);
			double epsilon = 0.0;
			for (int i = 0; i < numInstances; i++) {
				hClassification[i] = stump.classify(responses[i]);
				if (hClassification[i] != actualClasses[i])
					epsilon += weights[i];
			}

			// Check stopping condition: weighted error of 0.5 or more
			if (epsilon >= 0.5)
				break;

			// Calculate alpha and store the weak learner with it
			final float alpha = (float) (0.5 * Math.log((1 - epsilon) / epsilon));
			ensemble.add(new ObjectFloatPair<StumpClassifier>(stump, alpha));

			// Break if perfectly classifying data. This must happen before
			// the weight update: alpha is infinite here, so every weight would
			// be scaled to zero and the normalisation would compute 0/0.
			if (epsilon == 0.0)
				break;

			// Update the weights. The exponent is -alpha when the stump agrees
			// with the label and +alpha when it does not, so only two scale
			// factors are needed per round.
			final double agreeScale = Math.exp(-alpha);
			final double disagreeScale = Math.exp(alpha);
			float weightsSum = 0.0f;
			for (int i = 0; i < numInstances; i++) {
				weights[i] *= (hClassification[i] == actualClasses[i]) ? agreeScale : disagreeScale;
				weightsSum += weights[i];
			}

			// Normalise so the weights sum to one again
			for (int i = 0; i < numInstances; i++)
				weights[i] /= weightsSum;
		}

		return ensemble;
	}

	/**
	 * Classify every instance of a data set with the given ensemble and print
	 * the confusion counts (TP, FN, FP, TN) followed by the false-positive and
	 * true-positive rates to standard output.
	 * <p>
	 * A rate is printed as NaN when its denominator is zero, i.e. when the
	 * data contains no negative (FPR) or no positive (TPR) instances.
	 *
	 * @param data
	 *            the labelled data to evaluate against
	 * @param ensemble
	 *            the stumps and their alpha weights
	 * @param threshold
	 *            the decision threshold passed to
	 *            {@link #classify(float[], List, float)}
	 */
	public void printClassificationQuality(LabelledDataProvider data, List<ObjectFloatPair<StumpClassifier>> ensemble,
			float threshold)
	{
		int tp = 0;
		int fn = 0;
		int tn = 0;
		int fp = 0;

		final int ninstances = data.numInstances();
		final boolean[] classes = data.getClasses();
		for (int i = 0; i < ninstances; i++) {
			final float[] feature = data.getInstanceFeature(i);

			final boolean predicted = AdaBoost.classify(feature, ensemble, threshold);
			final boolean actual = classes[i];

			if (actual) {
				if (predicted)
					tp++; // TP
				else
					fn++; // FN
			} else {
				if (predicted)
					fp++; // FP
				else
					tn++; // TN
			}
		}

		System.out.format("TP: %d\tFN: %d\tFP: %d\tTN: %d\n", tp, fn, fp, tn);

		final float fpr = (float) fp / (float) (fp + tn);
		final float tpr = (float) tp / (float) (tp + fn);

		System.out.format("FPR: %2.2f\tTPR: %2.2f\n", fpr, tpr);
	}

	/**
	 * Classify a feature vector with the ensemble using a decision threshold
	 * of zero.
	 *
	 * @param data
	 *            the feature vector
	 * @param ensemble
	 *            the stumps and their alpha weights
	 * @return {@code true} if the alpha-weighted vote is strictly positive
	 */
	public static boolean classify(float[] data, List<ObjectFloatPair<StumpClassifier>> ensemble) {
		return classify(data, ensemble, 0.0f);
	}

	/**
	 * Classify a feature vector with the ensemble. Each stump contributes
	 * {@code +alpha} when it classifies the vector as positive and
	 * {@code -alpha} otherwise.
	 *
	 * @param data
	 *            the feature vector
	 * @param ensemble
	 *            the stumps and their alpha weights
	 * @param threshold
	 *            the decision threshold
	 * @return {@code true} if the summed vote is strictly greater than
	 *         {@code threshold}; always {@code false} for an empty ensemble
	 *         with a non-negative threshold
	 */
	public static boolean classify(float[] data, List<ObjectFloatPair<StumpClassifier>> ensemble, float threshold) {
		double classification = 0.0;

		// Call the weak learner classify methods and combine results
		for (final ObjectFloatPair<StumpClassifier> member : ensemble)
			classification += member.second * (member.first.classify(data) ? 1 : -1);

		// Return the thresholded classification
		return classification > threshold;
	}
}