001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.image.objectdetection.haar.training;
031
032import java.io.File;
033import java.io.FileOutputStream;
034import java.io.IOException;
035import java.io.ObjectOutputStream;
036import java.util.ArrayList;
037import java.util.List;
038
039import org.openimaj.image.FImage;
040import org.openimaj.image.ImageUtilities;
041import org.openimaj.image.analysis.algorithm.SummedSqTiltAreaTable;
042import org.openimaj.image.objectdetection.haar.HaarFeature;
043import org.openimaj.image.objectdetection.haar.HaarFeatureClassifier;
044import org.openimaj.image.objectdetection.haar.Stage;
045import org.openimaj.image.objectdetection.haar.StageTreeClassifier;
046import org.openimaj.image.objectdetection.haar.ValueClassifier;
047import org.openimaj.io.IOUtils;
048import org.openimaj.ml.classification.StumpClassifier;
049import org.openimaj.ml.classification.boosting.AdaBoost;
050import org.openimaj.util.pair.ObjectFloatPair;
051
052public class Testing {
053        List<HaarFeature> features;
054        List<SummedSqTiltAreaTable> positive = new ArrayList<SummedSqTiltAreaTable>();
055        List<SummedSqTiltAreaTable> negative = new ArrayList<SummedSqTiltAreaTable>();
056
057        void createFeatures(int width, int height) {
058                features = HaarFeatureType.generateFeatures(width, height, HaarFeatureType.CORE);
059
060                final float invArea = 1f / ((width - 2) * (height - 2));
061                for (final HaarFeature f : features) {
062                        f.setScale(1, invArea);
063                }
064        }
065
066        // void loadPositive(boolean tilted) throws IOException {
067        // final String base = "/Users/jsh2/Data/att_faces/s%d/%d.pgm";
068        //
069        // for (int j = 1; j <= 40; j++) {
070        // for (int i = 1; i <= 10; i++) {
071        // final File file = new File(String.format(base, j, i));
072        //
073        // FImage img = ImageUtilities.readF(file);
074        // img = img.extractCenter(50, 50);
075        // img = ResizeProcessor.resample(img, 19, 19);
076        // positive.add(new SummedSqTiltAreaTable(img, tilted));
077        // }
078        // }
079        // }
080        //
081        // void loadNegative(boolean tilted) throws IOException {
082        // final File dir = new File(
083        // "/Volumes/Raid/face_databases/haartraining/tutorial-haartraining.googlecode.com/svn/trunk/data/negatives/");
084        //
085        // for (final File f : dir.listFiles()) {
086        // if (f.getName().endsWith(".jpg")) {
087        // FImage img = ImageUtilities.readF(f);
088        //
089        // final int minwh = Math.min(img.width, img.height);
090        //
091        // img = img.extractCenter(minwh, minwh);
092        // img = ResizeProcessor.resample(img, 19, 19);
093        //
094        // negative.add(new SummedSqTiltAreaTable(img, tilted));
095        // }
096        // }
097        // }
098
099        void loadImage(File image, List<SummedSqTiltAreaTable> sats, boolean
100                        tilted) throws IOException
101        {
102                final FImage img = ImageUtilities.readF(image);
103
104                sats.add(new SummedSqTiltAreaTable(img, false));
105        }
106
107        void loadPositive(boolean tilted) throws IOException {
108                for (final File file : new File("/Users/jsh2/Data/cbcl-faces/train/face").listFiles()) {
109                        if (file.getName().endsWith(".pgm")) {
110                                loadImage(file, positive, tilted);
111                        }
112                }
113        }
114
115        void loadNegative(boolean tilted) throws IOException {
116                for (final File file : new File("/Users/jsh2/Data/cbcl-faces/train/non-face").listFiles()) {
117                        if (file.getName().endsWith(".pgm")) {
118                                loadImage(file, negative, tilted);
119                        }
120                }
121        }
122
123        void perform() throws IOException {
124                System.out.println("Creating feature set");
125                createFeatures(19, 19);
126
127                System.out.println("Loading positive images and computing SATs");
128                loadPositive(false);
129
130                System.out.println("Loading negative images and computing SATs");
131                loadNegative(false);
132
133                System.out.println("+ve: " + positive.size());
134                System.out.println("-ve: " + negative.size());
135                System.out.println("features: " + features.size());
136
137                System.out.println("Computing cached feature sets");
138                final CachedTrainingData data = new CachedTrainingData(positive, negative, features);
139
140                System.out.println("Starting Training");
141                final AdaBoost boost = new AdaBoost();
142                final List<ObjectFloatPair<StumpClassifier>> ensemble = boost.learn(data, 500);
143
144                System.out.println("Training complete. Ensemble has " + ensemble.size() + " classifiers.");
145
146                for (float threshold = 3; threshold >= -3; threshold -= 0.25f) {
147                        System.out.println("Threshold = " + threshold);
148                        boost.printClassificationQuality(data, ensemble, threshold);
149                }
150
151                final Stage root = createStage(ensemble);
152                final StageTreeClassifier classifier = new StageTreeClassifier(19, 19, "test cascade", false, root);
153                classifier.setScale(1);
154
155                for (int i = 0; i < positive.size(); i++) {
156                        if ((classifier.classify(positive.get(i), 0, 0) == 1) != AdaBoost.classify(data.getInstanceFeature(i),
157                                        ensemble))
158                                System.out.println("ERROR");
159                }
160                for (int i = 0; i < negative.size(); i++) {
161                        if ((classifier.classify(negative.get(i), 0, 0) == 1) != AdaBoost.classify(
162                                        data.getInstanceFeature(i + positive.size()), ensemble))
163                        {
164                                System.out.println(classifier.classify(negative.get(i), 0, 0) + " " + AdaBoost.classify(
165                                                data.getInstanceFeature(i + positive.size()), ensemble));
166                                System.out.println("ERROR2");
167                        }
168                }
169
170                final ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(new File("test-classifier.bin")));
171                IOUtils.write(classifier, oos);
172                oos.close();
173        }
174
175        /**
176         * Create a {@link Stage} from a trained ensemble.
177         *
178         * @param ensemble
179         *            the ensemble
180         * @return the stage
181         */
182        private Stage createStage(final List<ObjectFloatPair<StumpClassifier>> ensemble) {
183                final HaarFeatureClassifier[] trees = new HaarFeatureClassifier[ensemble.size()];
184
185                for (int i = 0; i < trees.length; i++) {
186                        final ObjectFloatPair<StumpClassifier> wc = ensemble.get(i);
187                        final StumpClassifier c = wc.first;
188                        final float alpha = wc.second;
189                        final float threshold = c.threshold;
190                        final float leftValue = c.sign > 0 ? -alpha : alpha;
191                        final HaarFeature feature = features.get(c.dimension);
192
193                        final ValueClassifier left = new ValueClassifier(leftValue);
194                        final ValueClassifier right = new ValueClassifier(-leftValue);
195
196                        trees[i] = new HaarFeatureClassifier(feature, threshold, left, right);
197                }
198
199                final Stage root = new Stage(0, trees, null, null);
200                return root;
201        }
202
203        public static void main(String[] args) throws IOException {
204                new Testing().perform();
205        }
206}