View Javadoc

1   /**
2    * Copyright (c) 2011, The University of Southampton and the individual contributors.
3    * All rights reserved.
4    *
5    * Redistribution and use in source and binary forms, with or without modification,
6    * are permitted provided that the following conditions are met:
7    *
8    *   * 	Redistributions of source code must retain the above copyright notice,
9    * 	this list of conditions and the following disclaimer.
10   *
11   *   *	Redistributions in binary form must reproduce the above copyright notice,
12   * 	this list of conditions and the following disclaimer in the documentation
13   * 	and/or other materials provided with the distribution.
14   *
15   *   *	Neither the name of the University of Southampton nor the names of its
16   * 	contributors may be used to endorse or promote products derived from this
17   * 	software without specific prior written permission.
18   *
19   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21   * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22   * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
23   * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24   * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26   * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27   * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28   * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29   */
30  package org.openimaj.ml.classification.boosting;
31  
32  import java.util.ArrayList;
33  import java.util.List;
34  
35  import org.openimaj.ml.classification.LabelledDataProvider;
36  import org.openimaj.ml.classification.StumpClassifier;
37  import org.openimaj.util.pair.ObjectFloatPair;
38  
39  public class AdaBoost {
40  	StumpClassifier.WeightedLearner factory = new StumpClassifier.WeightedLearner();
41  
42  	public List<ObjectFloatPair<StumpClassifier>> learn(LabelledDataProvider trainingSet, int numberOfRounds) {
43  		// Initialise weights
44  		final float[] weights = new float[trainingSet.numInstances()];
45  		for (int i = 0; i < trainingSet.numInstances(); i++)
46  			weights[i] = 1.0f / trainingSet.numInstances();
47  
48  		final boolean[] actualClasses = trainingSet.getClasses();
49  
50  		final List<ObjectFloatPair<StumpClassifier>> ensemble = new ArrayList<ObjectFloatPair<StumpClassifier>>();
51  
52  		// Perform the learning
53  		for (int t = 0; t < numberOfRounds; t++) {
54  			System.out.println("Iteration: " + t);
55  
56  			// Create the weak learner and train it
57  			final ObjectFloatPair<StumpClassifier> h = factory.learn(trainingSet, weights);
58  
59  			// Compute the classifications and training error
60  			final boolean[] hClassification = new boolean[trainingSet.numInstances()];
61  			final float[] responses = trainingSet.getFeatureResponse(h.first.dimension);
62  			double epsilon = 0.0;
63  			for (int i = 0; i < trainingSet.numInstances(); i++) {
64  				hClassification[i] = h.first.classify(responses[i]);
65  				epsilon += hClassification[i] != actualClasses[i] ? weights[i] : 0.0;
66  			}
67  
68  			// Check stopping condition
69  			if (epsilon >= 0.5)
70  				break;
71  
72  			// Calculate alpha
73  			final float alpha = (float) (0.5 * Math.log((1 - epsilon) / epsilon));
74  
75  			// Update the weights
76  			float weightsSum = 0.0f;
77  			for (int i = 0; i < trainingSet.numInstances(); i++) {
78  				weights[i] *= Math.exp(-alpha * (actualClasses[i] ? 1 : -1) * (hClassification[i] ? 1 : -1));
79  				weightsSum += weights[i];
80  			}
81  
82  			// Normalise
83  			for (int i = 0; i < trainingSet.numInstances(); i++)
84  				weights[i] /= weightsSum;
85  
86  			// Store the weak learner and alpha value
87  			ensemble.add(new ObjectFloatPair<StumpClassifier>(h.first, alpha));
88  
89  			// Break if perfectly classifying data
90  			if (epsilon == 0.0)
91  				break;
92  		}
93  
94  		return ensemble;
95  	}
96  
97  	public void printClassificationQuality(LabelledDataProvider data, List<ObjectFloatPair<StumpClassifier>> ensemble,
98  			float threshold)
99  	{
100 		int tp = 0;
101 		int fn = 0;
102 		int tn = 0;
103 		int fp = 0;
104 
105 		final int ninstances = data.numInstances();
106 		final boolean[] classes = data.getClasses();
107 		for (int i = 0; i < ninstances; i++) {
108 			final float[] feature = data.getInstanceFeature(i);
109 
110 			final boolean predicted = AdaBoost.classify(feature, ensemble, threshold);
111 			final boolean actual = classes[i];
112 
113 			if (actual) {
114 				if (predicted)
115 					tp++; // TP
116 				else
117 					fn++; // FN
118 			} else {
119 				if (predicted)
120 					fp++; // FP
121 				else
122 					tn++; // TN
123 			}
124 		}
125 
126 		System.out.format("TP: %d\tFN: %d\tFP: %d\tTN: %d\n", tp, fn, fp, tn);
127 
128 		final float fpr = (float) fp / (float) (fp + tn);
129 		final float tpr = (float) tp / (float) (tp + fn);
130 
131 		System.out.format("FPR: %2.2f\tTPR: %2.2f\n", fpr, tpr);
132 	}
133 
134 	public static boolean classify(float[] data, List<ObjectFloatPair<StumpClassifier>> ensemble) {
135 		double classification = 0.0;
136 
137 		// Call the weak learner classify methods and combine results
138 		for (int t = 0; t < ensemble.size(); t++)
139 			classification += ensemble.get(t).second * (ensemble.get(t).first.classify(data) ? 1 : -1);
140 
141 		// Return the thresholded classification
142 		return classification > 0.0 ? true : false;
143 	}
144 
145 	public static boolean classify(float[] data, List<ObjectFloatPair<StumpClassifier>> ensemble, float threshold) {
146 		double classification = 0.0;
147 
148 		// Call the weak learner classify methods and combine results
149 		for (int t = 0; t < ensemble.size(); t++)
150 			classification += ensemble.get(t).second * (ensemble.get(t).first.classify(data) ? 1 : -1);
151 
152 		// Return the thresholded classification
153 		return classification > threshold ? true : false;
154 	}
155 }