001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.classification;
031
032import org.openimaj.util.function.Operation;
033import org.openimaj.util.pair.ObjectFloatPair;
034import org.openimaj.util.parallel.Parallel;
035import org.openimaj.util.parallel.Parallel.IntRange;
036
037public class StumpClassifier {
038        public static class WeightedLearner {
039                // Trains using Error = \sum_{i=1}^{N} D_i * [y_i != h(x_i)]
040                // and h(x) = classifier.sign * (2 * [xclassifier.dimension >
041                // classifier.threshold] - 1)
042                public ObjectFloatPair<StumpClassifier> learn(final LabelledDataProvider trainingSet, final float[] _weights) {
043                        final StumpClassifier classifier = new StumpClassifier();
044
045                        // Search for minimum training set error
046                        final float[] minimumError = { Float.POSITIVE_INFINITY };
047
048                        final boolean[] classes = trainingSet.getClasses();
049                        final int nInstances = trainingSet.numInstances();
050
051                        // Determine total potential error
052                        float totalErrorC = 0.0f;
053                        for (int i = 0; i < nInstances; i++)
054                                totalErrorC += _weights[i];
055                        final float totalError = totalErrorC;
056
057                        // Initialise search error
058                        float initialErrorC = 0.0f;
059                        for (int i = 0; i < nInstances; i++)
060                                initialErrorC += !classes[i] ? _weights[i] : 0.0;
061                                final float initialError = initialErrorC;
062
063                                // Loop over possible dimensions
064                                // for (int d = 0; d < trainingSet.numFeatures(); d++) {
065                                Parallel.forRange(0, trainingSet.numDimensions(), 1, new Operation<IntRange>() {
066                                        @Override
067                                        public void perform(IntRange rng) {
068                                                final StumpClassifier currClassifier = new StumpClassifier();
069                                                currClassifier.dimension = -1;
070                                                currClassifier.threshold = Float.NaN;
071                                                currClassifier.sign = 0;
072
073                                                float currMinimumError = Float.POSITIVE_INFINITY;
074
075                                                for (int d = rng.start; d < rng.stop; d += rng.incr) {
076                                                        // Pre-sort data-items in dimension for efficient
077                                                        // threshold
078                                                        // search
079                                                        final float[] data = trainingSet.getFeatureResponse(d);
080                                                        final int[] indices = trainingSet.getSortedResponseIndices(d);
081
082                                                        // Initialise search error
083                                                        float currentError = initialError;
084
085                                                        // Search through the sorted list to determine best
086                                                // threshold
087                                                        for (int i = 0; i < nInstances - 1; i++) {
088                                                                // Update current error
089                                                                final int index = indices[i];
090                                                                if (classes[index])
091                                                                        currentError += _weights[index];
092                                                                else
093                                                                        currentError -= _weights[index];
094
095                                                                // Check for repeated values
096                                                                if (data[indices[i]] == data[indices[i + 1]])
097                                                                        continue;
098
099                                                                // Compute the test threshold - maximises the margin
100                                                                // between potential thresholds
101                                                                final float testThreshold = (data[indices[i]] + data[indices[i + 1]]) / 2.0f;
102
103                                                                // Compare to current best
104                                                                if (currentError < currMinimumError)
105                                                                {
106                                                                        // Good classifier with classifier.sign = +1
107                                                                        currMinimumError = currentError;
108                                                                        currClassifier.dimension = d;
109                                                                        currClassifier.threshold = testThreshold;
110                                                                        currClassifier.sign = +1;
111                                                                }
112                                                                if ((totalError - currentError) < currMinimumError)
113                                                                {
114                                                                        // Good classifier with classifier.sign = -1
115                                                                        currMinimumError = (totalError - currentError);
116                                                                        currClassifier.dimension = d;
117                                                                        currClassifier.threshold = testThreshold;
118                                                                        currClassifier.sign = -1;
119                                                                }
120                                                        }
121                                                }
122
123                                                synchronized (classifier) {
124                                                        if (currMinimumError < minimumError[0]) {
125                                                                minimumError[0] = currMinimumError;
126                                                                classifier.dimension = currClassifier.dimension;
127                                                                classifier.sign = currClassifier.sign;
128                                                                classifier.threshold = currClassifier.threshold;
129                                                        }
130                                                }
131                                        }
132                                });
133
134                                return new ObjectFloatPair<StumpClassifier>(classifier, minimumError[0]);
135                }
136        }
137
138        /**
139         * The dimension of the feature on which this stump operates
140         */
141        public int dimension;
142
143        /**
144         * The threshold which is the feature is tested against
145         */
146        public float threshold;
147
148        /**
149         * The sign of the stump (determines which side of the threshold corresponds
150         * to positive)
151         */
152        public int sign;
153
154        public boolean classify(float[] instanceFeature) {
155                return (instanceFeature[dimension] > threshold ? sign : -sign) == 1;
156        }
157
158        public boolean classify(float f) {
159                return (f > threshold ? sign : -sign) == 1;
160        }
161}