1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package org.openimaj.ml.classification;
31
32 import org.openimaj.util.function.Operation;
33 import org.openimaj.util.pair.ObjectFloatPair;
34 import org.openimaj.util.parallel.Parallel;
35 import org.openimaj.util.parallel.Parallel.IntRange;
36
37 public class StumpClassifier {
38 public static class WeightedLearner {
39
40
41
42 public ObjectFloatPair<StumpClassifier> learn(final LabelledDataProvider trainingSet, final float[] _weights) {
43 final StumpClassifier classifier = new StumpClassifier();
44
45
46 final float[] minimumError = { Float.POSITIVE_INFINITY };
47
48 final boolean[] classes = trainingSet.getClasses();
49 final int nInstances = trainingSet.numInstances();
50
51
52 float totalErrorC = 0.0f;
53 for (int i = 0; i < nInstances; i++)
54 totalErrorC += _weights[i];
55 final float totalError = totalErrorC;
56
57
58 float initialErrorC = 0.0f;
59 for (int i = 0; i < nInstances; i++)
60 initialErrorC += !classes[i] ? _weights[i] : 0.0;
61 final float initialError = initialErrorC;
62
63
64
65 Parallel.forRange(0, trainingSet.numDimensions(), 1, new Operation<IntRange>() {
66 @Override
67 public void perform(IntRange rng) {
68 final StumpClassifier currClassifier = new StumpClassifier();
69 currClassifier.dimension = -1;
70 currClassifier.threshold = Float.NaN;
71 currClassifier.sign = 0;
72
73 float currMinimumError = Float.POSITIVE_INFINITY;
74
75 for (int d = rng.start; d < rng.stop; d += rng.incr) {
76
77
78
79 final float[] data = trainingSet.getFeatureResponse(d);
80 final int[] indices = trainingSet.getSortedResponseIndices(d);
81
82
83 float currentError = initialError;
84
85
86
87 for (int i = 0; i < nInstances - 1; i++) {
88
89 final int index = indices[i];
90 if (classes[index])
91 currentError += _weights[index];
92 else
93 currentError -= _weights[index];
94
95
96 if (data[indices[i]] == data[indices[i + 1]])
97 continue;
98
99
100
101 final float testThreshold = (data[indices[i]] + data[indices[i + 1]]) / 2.0f;
102
103
104 if (currentError < currMinimumError)
105 {
106
107 currMinimumError = currentError;
108 currClassifier.dimension = d;
109 currClassifier.threshold = testThreshold;
110 currClassifier.sign = +1;
111 }
112 if ((totalError - currentError) < currMinimumError)
113 {
114
115 currMinimumError = (totalError - currentError);
116 currClassifier.dimension = d;
117 currClassifier.threshold = testThreshold;
118 currClassifier.sign = -1;
119 }
120 }
121 }
122
123 synchronized (classifier) {
124 if (currMinimumError < minimumError[0]) {
125 minimumError[0] = currMinimumError;
126 classifier.dimension = currClassifier.dimension;
127 classifier.sign = currClassifier.sign;
128 classifier.threshold = currClassifier.threshold;
129 }
130 }
131 }
132 });
133
134 return new ObjectFloatPair<StumpClassifier>(classifier, minimumError[0]);
135 }
136 }
137
138
139
140
141 public int dimension;
142
143
144
145
146 public float threshold;
147
148
149
150
151
152 public int sign;
153
154 public boolean classify(float[] instanceFeature) {
155 return (instanceFeature[dimension] > threshold ? sign : -sign) == 1;
156 }
157
158 public boolean classify(float f) {
159 return (f > threshold ? sign : -sign) == 1;
160 }
161 }