1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package org.openimaj.image.objectdetection.haar.training;
31
32 import java.io.File;
33 import java.io.FileOutputStream;
34 import java.io.IOException;
35 import java.io.ObjectOutputStream;
36 import java.util.ArrayList;
37 import java.util.List;
38
39 import org.openimaj.image.FImage;
40 import org.openimaj.image.ImageUtilities;
41 import org.openimaj.image.analysis.algorithm.SummedSqTiltAreaTable;
42 import org.openimaj.image.objectdetection.haar.HaarFeature;
43 import org.openimaj.image.objectdetection.haar.HaarFeatureClassifier;
44 import org.openimaj.image.objectdetection.haar.Stage;
45 import org.openimaj.image.objectdetection.haar.StageTreeClassifier;
46 import org.openimaj.image.objectdetection.haar.ValueClassifier;
47 import org.openimaj.io.IOUtils;
48 import org.openimaj.ml.classification.StumpClassifier;
49 import org.openimaj.ml.classification.boosting.AdaBoost;
50 import org.openimaj.util.pair.ObjectFloatPair;
51
52 public class Testing {
53 List<HaarFeature> features;
54 List<SummedSqTiltAreaTable> positive = new ArrayList<SummedSqTiltAreaTable>();
55 List<SummedSqTiltAreaTable> negative = new ArrayList<SummedSqTiltAreaTable>();
56
57 void createFeatures(int width, int height) {
58 features = HaarFeatureType.generateFeatures(width, height, HaarFeatureType.CORE);
59
60 final float invArea = 1f / ((width - 2) * (height - 2));
61 for (final HaarFeature f : features) {
62 f.setScale(1, invArea);
63 }
64 }
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99 void loadImage(File image, List<SummedSqTiltAreaTable> sats, boolean
100 tilted) throws IOException
101 {
102 final FImage img = ImageUtilities.readF(image);
103
104 sats.add(new SummedSqTiltAreaTable(img, false));
105 }
106
107 void loadPositive(boolean tilted) throws IOException {
108 for (final File file : new File("/Users/jsh2/Data/cbcl-faces/train/face").listFiles()) {
109 if (file.getName().endsWith(".pgm")) {
110 loadImage(file, positive, tilted);
111 }
112 }
113 }
114
115 void loadNegative(boolean tilted) throws IOException {
116 for (final File file : new File("/Users/jsh2/Data/cbcl-faces/train/non-face").listFiles()) {
117 if (file.getName().endsWith(".pgm")) {
118 loadImage(file, negative, tilted);
119 }
120 }
121 }
122
123 void perform() throws IOException {
124 System.out.println("Creating feature set");
125 createFeatures(19, 19);
126
127 System.out.println("Loading positive images and computing SATs");
128 loadPositive(false);
129
130 System.out.println("Loading negative images and computing SATs");
131 loadNegative(false);
132
133 System.out.println("+ve: " + positive.size());
134 System.out.println("-ve: " + negative.size());
135 System.out.println("features: " + features.size());
136
137 System.out.println("Computing cached feature sets");
138 final CachedTrainingData data = new CachedTrainingData(positive, negative, features);
139
140 System.out.println("Starting Training");
141 final AdaBoost boost = new AdaBoost();
142 final List<ObjectFloatPair<StumpClassifier>> ensemble = boost.learn(data, 500);
143
144 System.out.println("Training complete. Ensemble has " + ensemble.size() + " classifiers.");
145
146 for (float threshold = 3; threshold >= -3; threshold -= 0.25f) {
147 System.out.println("Threshold = " + threshold);
148 boost.printClassificationQuality(data, ensemble, threshold);
149 }
150
151 final Stage root = createStage(ensemble);
152 final StageTreeClassifier classifier = new StageTreeClassifier(19, 19, "test cascade", false, root);
153 classifier.setScale(1);
154
155 for (int i = 0; i < positive.size(); i++) {
156 if ((classifier.classify(positive.get(i), 0, 0) == 1) != AdaBoost.classify(data.getInstanceFeature(i),
157 ensemble))
158 System.out.println("ERROR");
159 }
160 for (int i = 0; i < negative.size(); i++) {
161 if ((classifier.classify(negative.get(i), 0, 0) == 1) != AdaBoost.classify(
162 data.getInstanceFeature(i + positive.size()), ensemble))
163 {
164 System.out.println(classifier.classify(negative.get(i), 0, 0) + " " + AdaBoost.classify(
165 data.getInstanceFeature(i + positive.size()), ensemble));
166 System.out.println("ERROR2");
167 }
168 }
169
170 final ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(new File("test-classifier.bin")));
171 IOUtils.write(classifier, oos);
172 oos.close();
173 }
174
175
176
177
178
179
180
181
182 private Stage createStage(final List<ObjectFloatPair<StumpClassifier>> ensemble) {
183 final HaarFeatureClassifier[] trees = new HaarFeatureClassifier[ensemble.size()];
184
185 for (int i = 0; i < trees.length; i++) {
186 final ObjectFloatPair<StumpClassifier> wc = ensemble.get(i);
187 final StumpClassifier c = wc.first;
188 final float alpha = wc.second;
189 final float threshold = c.threshold;
190 final float leftValue = c.sign > 0 ? -alpha : alpha;
191 final HaarFeature feature = features.get(c.dimension);
192
193 final ValueClassifier left = new ValueClassifier(leftValue);
194 final ValueClassifier right = new ValueClassifier(-leftValue);
195
196 trees[i] = new HaarFeatureClassifier(feature, threshold, left, right);
197 }
198
199 final Stage root = new Stage(0, trees, null, null);
200 return root;
201 }
202
203 public static void main(String[] args) throws IOException {
204 new Testing().perform();
205 }
206 }