View Javadoc

1   /**
2    * Copyright (c) 2011, The University of Southampton and the individual contributors.
3    * All rights reserved.
4    *
5    * Redistribution and use in source and binary forms, with or without modification,
6    * are permitted provided that the following conditions are met:
7    *
8    *   * 	Redistributions of source code must retain the above copyright notice,
9    * 	this list of conditions and the following disclaimer.
10   *
11   *   *	Redistributions in binary form must reproduce the above copyright notice,
12   * 	this list of conditions and the following disclaimer in the documentation
13   * 	and/or other materials provided with the distribution.
14   *
15   *   *	Neither the name of the University of Southampton nor the names of its
16   * 	contributors may be used to endorse or promote products derived from this
17   * 	software without specific prior written permission.
18   *
19   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21   * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22   * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
23   * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24   * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26   * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27   * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28   * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29   */
30  package org.openimaj.image.objectdetection.haar.training;
31  
32  import java.io.File;
33  import java.io.FileOutputStream;
34  import java.io.IOException;
35  import java.io.ObjectOutputStream;
36  import java.util.ArrayList;
37  import java.util.List;
38  
39  import org.openimaj.image.FImage;
40  import org.openimaj.image.ImageUtilities;
41  import org.openimaj.image.analysis.algorithm.SummedSqTiltAreaTable;
42  import org.openimaj.image.objectdetection.haar.HaarFeature;
43  import org.openimaj.image.objectdetection.haar.HaarFeatureClassifier;
44  import org.openimaj.image.objectdetection.haar.Stage;
45  import org.openimaj.image.objectdetection.haar.StageTreeClassifier;
46  import org.openimaj.image.objectdetection.haar.ValueClassifier;
47  import org.openimaj.io.IOUtils;
48  import org.openimaj.ml.classification.StumpClassifier;
49  import org.openimaj.ml.classification.boosting.AdaBoost;
50  import org.openimaj.util.pair.ObjectFloatPair;
51  
52  public class Testing {
53  	List<HaarFeature> features;
54  	List<SummedSqTiltAreaTable> positive = new ArrayList<SummedSqTiltAreaTable>();
55  	List<SummedSqTiltAreaTable> negative = new ArrayList<SummedSqTiltAreaTable>();
56  
57  	void createFeatures(int width, int height) {
58  		features = HaarFeatureType.generateFeatures(width, height, HaarFeatureType.CORE);
59  
60  		final float invArea = 1f / ((width - 2) * (height - 2));
61  		for (final HaarFeature f : features) {
62  			f.setScale(1, invArea);
63  		}
64  	}
65  
66  	// void loadPositive(boolean tilted) throws IOException {
67  	// final String base = "/Users/jsh2/Data/att_faces/s%d/%d.pgm";
68  	//
69  	// for (int j = 1; j <= 40; j++) {
70  	// for (int i = 1; i <= 10; i++) {
71  	// final File file = new File(String.format(base, j, i));
72  	//
73  	// FImage img = ImageUtilities.readF(file);
74  	// img = img.extractCenter(50, 50);
75  	// img = ResizeProcessor.resample(img, 19, 19);
76  	// positive.add(new SummedSqTiltAreaTable(img, tilted));
77  	// }
78  	// }
79  	// }
80  	//
81  	// void loadNegative(boolean tilted) throws IOException {
82  	// final File dir = new File(
83  	// "/Volumes/Raid/face_databases/haartraining/tutorial-haartraining.googlecode.com/svn/trunk/data/negatives/");
84  	//
85  	// for (final File f : dir.listFiles()) {
86  	// if (f.getName().endsWith(".jpg")) {
87  	// FImage img = ImageUtilities.readF(f);
88  	//
89  	// final int minwh = Math.min(img.width, img.height);
90  	//
91  	// img = img.extractCenter(minwh, minwh);
92  	// img = ResizeProcessor.resample(img, 19, 19);
93  	//
94  	// negative.add(new SummedSqTiltAreaTable(img, tilted));
95  	// }
96  	// }
97  	// }
98  
99  	void loadImage(File image, List<SummedSqTiltAreaTable> sats, boolean
100 			tilted) throws IOException
101 	{
102 		final FImage img = ImageUtilities.readF(image);
103 
104 		sats.add(new SummedSqTiltAreaTable(img, false));
105 	}
106 
107 	void loadPositive(boolean tilted) throws IOException {
108 		for (final File file : new File("/Users/jsh2/Data/cbcl-faces/train/face").listFiles()) {
109 			if (file.getName().endsWith(".pgm")) {
110 				loadImage(file, positive, tilted);
111 			}
112 		}
113 	}
114 
115 	void loadNegative(boolean tilted) throws IOException {
116 		for (final File file : new File("/Users/jsh2/Data/cbcl-faces/train/non-face").listFiles()) {
117 			if (file.getName().endsWith(".pgm")) {
118 				loadImage(file, negative, tilted);
119 			}
120 		}
121 	}
122 
123 	void perform() throws IOException {
124 		System.out.println("Creating feature set");
125 		createFeatures(19, 19);
126 
127 		System.out.println("Loading positive images and computing SATs");
128 		loadPositive(false);
129 
130 		System.out.println("Loading negative images and computing SATs");
131 		loadNegative(false);
132 
133 		System.out.println("+ve: " + positive.size());
134 		System.out.println("-ve: " + negative.size());
135 		System.out.println("features: " + features.size());
136 
137 		System.out.println("Computing cached feature sets");
138 		final CachedTrainingData data = new CachedTrainingData(positive, negative, features);
139 
140 		System.out.println("Starting Training");
141 		final AdaBoost boost = new AdaBoost();
142 		final List<ObjectFloatPair<StumpClassifier>> ensemble = boost.learn(data, 500);
143 
144 		System.out.println("Training complete. Ensemble has " + ensemble.size() + " classifiers.");
145 
146 		for (float threshold = 3; threshold >= -3; threshold -= 0.25f) {
147 			System.out.println("Threshold = " + threshold);
148 			boost.printClassificationQuality(data, ensemble, threshold);
149 		}
150 
151 		final Stage root = createStage(ensemble);
152 		final StageTreeClassifier classifier = new StageTreeClassifier(19, 19, "test cascade", false, root);
153 		classifier.setScale(1);
154 
155 		for (int i = 0; i < positive.size(); i++) {
156 			if ((classifier.classify(positive.get(i), 0, 0) == 1) != AdaBoost.classify(data.getInstanceFeature(i),
157 					ensemble))
158 				System.out.println("ERROR");
159 		}
160 		for (int i = 0; i < negative.size(); i++) {
161 			if ((classifier.classify(negative.get(i), 0, 0) == 1) != AdaBoost.classify(
162 					data.getInstanceFeature(i + positive.size()), ensemble))
163 			{
164 				System.out.println(classifier.classify(negative.get(i), 0, 0) + " " + AdaBoost.classify(
165 						data.getInstanceFeature(i + positive.size()), ensemble));
166 				System.out.println("ERROR2");
167 			}
168 		}
169 
170 		final ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(new File("test-classifier.bin")));
171 		IOUtils.write(classifier, oos);
172 		oos.close();
173 	}
174 
175 	/**
176 	 * Create a {@link Stage} from a trained ensemble.
177 	 *
178 	 * @param ensemble
179 	 *            the ensemble
180 	 * @return the stage
181 	 */
182 	private Stage createStage(final List<ObjectFloatPair<StumpClassifier>> ensemble) {
183 		final HaarFeatureClassifier[] trees = new HaarFeatureClassifier[ensemble.size()];
184 
185 		for (int i = 0; i < trees.length; i++) {
186 			final ObjectFloatPair<StumpClassifier> wc = ensemble.get(i);
187 			final StumpClassifier c = wc.first;
188 			final float alpha = wc.second;
189 			final float threshold = c.threshold;
190 			final float leftValue = c.sign > 0 ? -alpha : alpha;
191 			final HaarFeature feature = features.get(c.dimension);
192 
193 			final ValueClassifier left = new ValueClassifier(leftValue);
194 			final ValueClassifier right = new ValueClassifier(-leftValue);
195 
196 			trees[i] = new HaarFeatureClassifier(feature, threshold, left, right);
197 		}
198 
199 		final Stage root = new Stage(0, trees, null, null);
200 		return root;
201 	}
202 
203 	public static void main(String[] args) throws IOException {
204 		new Testing().perform();
205 	}
206 }