/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package org.openimaj.ml.classification;

import org.openimaj.util.function.Operation;
import org.openimaj.util.pair.ObjectFloatPair;
import org.openimaj.util.parallel.Parallel;
import org.openimaj.util.parallel.Parallel.IntRange;

/**
 * A decision stump: a single test of one feature dimension against a
 * threshold. The stump's response is
 * {@code h(x) = sign * (2 * [x[dimension] > threshold] - 1)}, and an instance is
 * classified as positive when that response equals {@code 1}.
 */
public class StumpClassifier {
	/**
	 * Learns a {@link StumpClassifier} that minimises the weighted training error
	 *
	 * <pre>
	 * Error = \sum_{i=1}^{N} D_i * [y_i != h(x_i)]
	 * </pre>
	 *
	 * where {@code h(x) = sign * (2 * [x[dimension] > threshold] - 1)}.
	 */
	public static class WeightedLearner {
		/**
		 * Searches every feature dimension (in parallel) for the stump with the
		 * lowest weighted error.
		 * <p>
		 * Candidate thresholds lie between consecutive distinct values of a
		 * dimension, visited in the order given by
		 * {@code trainingSet.getSortedResponseIndices(d)}. The incremental error
		 * update assumes that order is ascending (TODO confirm against the
		 * {@code LabelledDataProvider} contract). If several stumps share the
		 * minimum error, the one in the lowest dimension is returned, so the
		 * result does not depend on thread scheduling.
		 * <p>
		 * If no candidate threshold exists (fewer than two instances, no
		 * dimensions, or every dimension is constant), the returned classifier is
		 * left untrained. Its {@code sign} is {@code 0}, so it never predicts
		 * positive, and the reported error is {@link Float#POSITIVE_INFINITY}.
		 *
		 * @param trainingSet
		 *            the labelled training data
		 * @param _weights
		 *            the per-instance weights {@code D_i}; must contain at least
		 *            {@code trainingSet.numInstances()} entries
		 * @return the best stump paired with its weighted training error
		 * @throws IllegalArgumentException
		 *             if {@code _weights} is shorter than the number of instances
		 */
		public ObjectFloatPair<StumpClassifier> learn(final LabelledDataProvider trainingSet, final float[] _weights) {
			final boolean[] classes = trainingSet.getClasses();
			final int nInstances = trainingSet.numInstances();

			if (_weights.length < nInstances)
				throw new IllegalArgumentException("Expected at least " + nInstances + " weights, but got "
						+ _weights.length);

			// totalError: the summed weight of every instance.
			// initialError: the error of a sign = +1 stump whose threshold lies below
			// every value. Such a stump predicts everything positive, so it
			// misclassifies exactly the negative instances.
			float totalErrorC = 0.0f;
			float initialErrorC = 0.0f;
			for (int i = 0; i < nInstances; i++) {
				totalErrorC += _weights[i];
				if (!classes[i])
					initialErrorC += _weights[i];
			}
			final float totalError = totalErrorC;
			final float initialError = initialErrorC;

			// Global best, shared between the parallel ranges and guarded by
			// synchronizing on 'classifier'.
			final StumpClassifier classifier = new StumpClassifier();
			final float[] minimumError = { Float.POSITIVE_INFINITY };

			Parallel.forRange(0, trainingSet.numDimensions(), 1, new Operation<IntRange>() {
				@Override
				public void perform(IntRange rng) {
					// Best stump within this range; dimension == -1 means none found yet.
					final StumpClassifier currClassifier = new StumpClassifier();
					currClassifier.dimension = -1;
					currClassifier.threshold = Float.NaN;
					currClassifier.sign = 0;

					float currMinimumError = Float.POSITIVE_INFINITY;

					for (int d = rng.start; d < rng.stop; d += rng.incr) {
						final float[] data = trainingSet.getFeatureResponse(d);
						// Visiting the instances in sorted order lets each threshold's
						// error be updated incrementally rather than recomputed.
						final int[] indices = trainingSet.getSortedResponseIndices(d);

						float currentError = initialError;

						for (int i = 0; i < nInstances - 1; i++) {
							// Moving the threshold past this instance flips its sign = +1
							// prediction from positive to negative.
							final int index = indices[i];
							if (classes[index])
								currentError += _weights[index];
							else
								currentError -= _weights[index];

							// No threshold can separate equal values.
							final float value = data[index];
							final float nextValue = data[indices[i + 1]];
							if (value == nextValue)
								continue;

							final float testThreshold = splitThreshold(value, nextValue);

							if (currentError < currMinimumError) {
								// sign = +1: values above the threshold are positive.
								currMinimumError = currentError;
								currClassifier.dimension = d;
								currClassifier.threshold = testThreshold;
								currClassifier.sign = +1;
							}
							// The sign = -1 stump gets wrong exactly the instances the
							// sign = +1 stump gets right.
							if ((totalError - currentError) < currMinimumError) {
								currMinimumError = (totalError - currentError);
								currClassifier.dimension = d;
								currClassifier.threshold = testThreshold;
								currClassifier.sign = -1;
							}
						}
					}

					// A range that found no usable threshold has nothing to contribute.
					// Skipping it also keeps the placeholder (dimension -1) from winning
					// the tie-break below.
					if (currClassifier.dimension < 0)
						return;

					synchronized (classifier) {
						final boolean lowerError = currMinimumError < minimumError[0];
						// On an exact tie, prefer the lower dimension. The outcome then no
						// longer depends on which range finishes first, and it matches a
						// sequential scan of the dimensions in ascending order.
						final boolean tieInLowerDimension = currMinimumError == minimumError[0]
								&& currClassifier.dimension < classifier.dimension;

						if (lowerError || tieInLowerDimension) {
							minimumError[0] = currMinimumError;
							classifier.dimension = currClassifier.dimension;
							classifier.sign = currClassifier.sign;
							classifier.threshold = currClassifier.threshold;
						}
					}
				}
			});

			return new ObjectFloatPair<StumpClassifier>(classifier, minimumError[0]);
		}

		/**
		 * Chooses a threshold that separates two consecutive distinct values,
		 * {@code lower < upper}. The midpoint is used because it maximises the
		 * margin to both values.
		 * <p>
		 * Float rounding (for adjacent values) or overflow (for values of very
		 * large magnitude) can push the computed midpoint outside
		 * {@code [lower, upper)}. It would then no longer separate the two values
		 * under the {@code x > threshold} test, so {@code lower} itself is returned
		 * instead, which always does.
		 */
		private static float splitThreshold(float lower, float upper) {
			final float mid = (lower + upper) / 2.0f;
			if (mid < lower || mid >= upper)
				return lower;
			return mid;
		}
	}

	/**
	 * The dimension of the feature on which this stump operates.
	 */
	public int dimension;

	/**
	 * The threshold against which the feature is tested.
	 */
	public float threshold;

	/**
	 * The sign of the stump, which determines which side of the threshold is
	 * positive.
	 * <ul>
	 * <li>{@code +1}: values strictly above the threshold are positive.</li>
	 * <li>{@code -1}: values at or below the threshold are positive.</li>
	 * <li>Any other value, including the default {@code 0}: nothing is classified
	 * as positive.</li>
	 * </ul>
	 */
	public int sign;

	/**
	 * Classifies a feature vector using its {@link #dimension}-th element.
	 *
	 * @param instanceFeature
	 *            the feature vector
	 * @return true if the stump predicts the positive class
	 */
	public boolean classify(float[] instanceFeature) {
		return instanceFeature[dimension] > threshold ? sign == 1 : sign == -1;
	}

	/**
	 * Classifies a single, already-extracted feature value.
	 *
	 * @param f
	 *            the value of the feature in this stump's dimension
	 * @return true if the stump predicts the positive class
	 */
	public boolean classify(float f) {
		return f > threshold ? sign == 1 : sign == -1;
	}
}