View Javadoc

1   /**
2    * Copyright (c) 2011, The University of Southampton and the individual contributors.
3    * All rights reserved.
4    *
5    * Redistribution and use in source and binary forms, with or without modification,
6    * are permitted provided that the following conditions are met:
7    *
8    *   * 	Redistributions of source code must retain the above copyright notice,
9    * 	this list of conditions and the following disclaimer.
10   *
11   *   *	Redistributions in binary form must reproduce the above copyright notice,
12   * 	this list of conditions and the following disclaimer in the documentation
13   * 	and/or other materials provided with the distribution.
14   *
15   *   *	Neither the name of the University of Southampton nor the names of its
16   * 	contributors may be used to endorse or promote products derived from this
17   * 	software without specific prior written permission.
18   *
19   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21   * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22   * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
23   * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24   * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26   * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27   * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28   * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29   */
30  package org.openimaj.ml.classification;
31  
32  import org.openimaj.util.function.Operation;
33  import org.openimaj.util.pair.ObjectFloatPair;
34  import org.openimaj.util.parallel.Parallel;
35  import org.openimaj.util.parallel.Parallel.IntRange;
36  
37  public class StumpClassifier {
38  	public static class WeightedLearner {
39  		// Trains using Error = \sum_{i=1}^{N} D_i * [y_i != h(x_i)]
40  		// and h(x) = classifier.sign * (2 * [xclassifier.dimension >
41  		// classifier.threshold] - 1)
42  		public ObjectFloatPair<StumpClassifier> learn(final LabelledDataProvider trainingSet, final float[] _weights) {
43  			final StumpClassifier classifier = new StumpClassifier();
44  
45  			// Search for minimum training set error
46  			final float[] minimumError = { Float.POSITIVE_INFINITY };
47  
48  			final boolean[] classes = trainingSet.getClasses();
49  			final int nInstances = trainingSet.numInstances();
50  
51  			// Determine total potential error
52  			float totalErrorC = 0.0f;
53  			for (int i = 0; i < nInstances; i++)
54  				totalErrorC += _weights[i];
55  			final float totalError = totalErrorC;
56  
57  			// Initialise search error
58  			float initialErrorC = 0.0f;
59  			for (int i = 0; i < nInstances; i++)
60  				initialErrorC += !classes[i] ? _weights[i] : 0.0;
61  				final float initialError = initialErrorC;
62  
63  				// Loop over possible dimensions
64  				// for (int d = 0; d < trainingSet.numFeatures(); d++) {
65  				Parallel.forRange(0, trainingSet.numDimensions(), 1, new Operation<IntRange>() {
66  					@Override
67  					public void perform(IntRange rng) {
68  						final StumpClassifier currClassifier = new StumpClassifier();
69  						currClassifier.dimension = -1;
70  						currClassifier.threshold = Float.NaN;
71  						currClassifier.sign = 0;
72  
73  						float currMinimumError = Float.POSITIVE_INFINITY;
74  
75  						for (int d = rng.start; d < rng.stop; d += rng.incr) {
76  							// Pre-sort data-items in dimension for efficient
77  							// threshold
78  							// search
79  							final float[] data = trainingSet.getFeatureResponse(d);
80  							final int[] indices = trainingSet.getSortedResponseIndices(d);
81  
82  							// Initialise search error
83  							float currentError = initialError;
84  
85  							// Search through the sorted list to determine best
86  						// threshold
87  							for (int i = 0; i < nInstances - 1; i++) {
88  								// Update current error
89  								final int index = indices[i];
90  								if (classes[index])
91  									currentError += _weights[index];
92  								else
93  									currentError -= _weights[index];
94  
95  								// Check for repeated values
96  								if (data[indices[i]] == data[indices[i + 1]])
97  									continue;
98  
99  								// Compute the test threshold - maximises the margin
100 								// between potential thresholds
101 								final float testThreshold = (data[indices[i]] + data[indices[i + 1]]) / 2.0f;
102 
103 								// Compare to current best
104 								if (currentError < currMinimumError)
105 								{
106 									// Good classifier with classifier.sign = +1
107 									currMinimumError = currentError;
108 									currClassifier.dimension = d;
109 									currClassifier.threshold = testThreshold;
110 									currClassifier.sign = +1;
111 								}
112 								if ((totalError - currentError) < currMinimumError)
113 								{
114 									// Good classifier with classifier.sign = -1
115 									currMinimumError = (totalError - currentError);
116 									currClassifier.dimension = d;
117 									currClassifier.threshold = testThreshold;
118 									currClassifier.sign = -1;
119 								}
120 							}
121 						}
122 
123 						synchronized (classifier) {
124 							if (currMinimumError < minimumError[0]) {
125 								minimumError[0] = currMinimumError;
126 								classifier.dimension = currClassifier.dimension;
127 								classifier.sign = currClassifier.sign;
128 								classifier.threshold = currClassifier.threshold;
129 							}
130 						}
131 					}
132 				});
133 
134 				return new ObjectFloatPair<StumpClassifier>(classifier, minimumError[0]);
135 		}
136 	}
137 
138 	/**
139 	 * The dimension of the feature on which this stump operates
140 	 */
141 	public int dimension;
142 
143 	/**
144 	 * The threshold which is the feature is tested against
145 	 */
146 	public float threshold;
147 
148 	/**
149 	 * The sign of the stump (determines which side of the threshold corresponds
150 	 * to positive)
151 	 */
152 	public int sign;
153 
154 	public boolean classify(float[] instanceFeature) {
155 		return (instanceFeature[dimension] > threshold ? sign : -sign) == 1;
156 	}
157 
158 	public boolean classify(float f) {
159 		return (f > threshold ? sign : -sign) == 1;
160 	}
161 }