001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.classification.boosting;
031
032import java.util.ArrayList;
033import java.util.List;
034
035import org.openimaj.ml.classification.LabelledDataProvider;
036import org.openimaj.ml.classification.StumpClassifier;
037import org.openimaj.util.pair.ObjectFloatPair;
038
039public class AdaBoost {
040        StumpClassifier.WeightedLearner factory = new StumpClassifier.WeightedLearner();
041
042        public List<ObjectFloatPair<StumpClassifier>> learn(LabelledDataProvider trainingSet, int numberOfRounds) {
043                // Initialise weights
044                final float[] weights = new float[trainingSet.numInstances()];
045                for (int i = 0; i < trainingSet.numInstances(); i++)
046                        weights[i] = 1.0f / trainingSet.numInstances();
047
048                final boolean[] actualClasses = trainingSet.getClasses();
049
050                final List<ObjectFloatPair<StumpClassifier>> ensemble = new ArrayList<ObjectFloatPair<StumpClassifier>>();
051
052                // Perform the learning
053                for (int t = 0; t < numberOfRounds; t++) {
054                        System.out.println("Iteration: " + t);
055
056                        // Create the weak learner and train it
057                        final ObjectFloatPair<StumpClassifier> h = factory.learn(trainingSet, weights);
058
059                        // Compute the classifications and training error
060                        final boolean[] hClassification = new boolean[trainingSet.numInstances()];
061                        final float[] responses = trainingSet.getFeatureResponse(h.first.dimension);
062                        double epsilon = 0.0;
063                        for (int i = 0; i < trainingSet.numInstances(); i++) {
064                                hClassification[i] = h.first.classify(responses[i]);
065                                epsilon += hClassification[i] != actualClasses[i] ? weights[i] : 0.0;
066                        }
067
068                        // Check stopping condition
069                        if (epsilon >= 0.5)
070                                break;
071
072                        // Calculate alpha
073                        final float alpha = (float) (0.5 * Math.log((1 - epsilon) / epsilon));
074
075                        // Update the weights
076                        float weightsSum = 0.0f;
077                        for (int i = 0; i < trainingSet.numInstances(); i++) {
078                                weights[i] *= Math.exp(-alpha * (actualClasses[i] ? 1 : -1) * (hClassification[i] ? 1 : -1));
079                                weightsSum += weights[i];
080                        }
081
082                        // Normalise
083                        for (int i = 0; i < trainingSet.numInstances(); i++)
084                                weights[i] /= weightsSum;
085
086                        // Store the weak learner and alpha value
087                        ensemble.add(new ObjectFloatPair<StumpClassifier>(h.first, alpha));
088
089                        // Break if perfectly classifying data
090                        if (epsilon == 0.0)
091                                break;
092                }
093
094                return ensemble;
095        }
096
097        public void printClassificationQuality(LabelledDataProvider data, List<ObjectFloatPair<StumpClassifier>> ensemble,
098                        float threshold)
099        {
100                int tp = 0;
101                int fn = 0;
102                int tn = 0;
103                int fp = 0;
104
105                final int ninstances = data.numInstances();
106                final boolean[] classes = data.getClasses();
107                for (int i = 0; i < ninstances; i++) {
108                        final float[] feature = data.getInstanceFeature(i);
109
110                        final boolean predicted = AdaBoost.classify(feature, ensemble, threshold);
111                        final boolean actual = classes[i];
112
113                        if (actual) {
114                                if (predicted)
115                                        tp++; // TP
116                                else
117                                        fn++; // FN
118                        } else {
119                                if (predicted)
120                                        fp++; // FP
121                                else
122                                        tn++; // TN
123                        }
124                }
125
126                System.out.format("TP: %d\tFN: %d\tFP: %d\tTN: %d\n", tp, fn, fp, tn);
127
128                final float fpr = (float) fp / (float) (fp + tn);
129                final float tpr = (float) tp / (float) (tp + fn);
130
131                System.out.format("FPR: %2.2f\tTPR: %2.2f\n", fpr, tpr);
132        }
133
134        public static boolean classify(float[] data, List<ObjectFloatPair<StumpClassifier>> ensemble) {
135                double classification = 0.0;
136
137                // Call the weak learner classify methods and combine results
138                for (int t = 0; t < ensemble.size(); t++)
139                        classification += ensemble.get(t).second * (ensemble.get(t).first.classify(data) ? 1 : -1);
140
141                // Return the thresholded classification
142                return classification > 0.0 ? true : false;
143        }
144
145        public static boolean classify(float[] data, List<ObjectFloatPair<StumpClassifier>> ensemble, float threshold) {
146                double classification = 0.0;
147
148                // Call the weak learner classify methods and combine results
149                for (int t = 0; t < ensemble.size(); t++)
150                        classification += ensemble.get(t).second * (ensemble.get(t).first.classify(data) ? 1 : -1);
151
152                // Return the thresholded classification
153                return classification > threshold ? true : false;
154        }
155}