001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.image.objectdetection.haar.training; 031 032import java.io.File; 033import java.io.FileOutputStream; 034import java.io.IOException; 035import java.io.ObjectOutputStream; 036import java.util.ArrayList; 037import java.util.List; 038 039import org.openimaj.image.FImage; 040import org.openimaj.image.ImageUtilities; 041import org.openimaj.image.analysis.algorithm.SummedSqTiltAreaTable; 042import org.openimaj.image.objectdetection.haar.HaarFeature; 043import org.openimaj.image.objectdetection.haar.HaarFeatureClassifier; 044import org.openimaj.image.objectdetection.haar.Stage; 045import org.openimaj.image.objectdetection.haar.StageTreeClassifier; 046import org.openimaj.image.objectdetection.haar.ValueClassifier; 047import org.openimaj.io.IOUtils; 048import org.openimaj.ml.classification.StumpClassifier; 049import org.openimaj.ml.classification.boosting.AdaBoost; 050import org.openimaj.util.pair.ObjectFloatPair; 051 052public class Testing { 053 List<HaarFeature> features; 054 List<SummedSqTiltAreaTable> positive = new ArrayList<SummedSqTiltAreaTable>(); 055 List<SummedSqTiltAreaTable> negative = new ArrayList<SummedSqTiltAreaTable>(); 056 057 void createFeatures(int width, int height) { 058 features = HaarFeatureType.generateFeatures(width, height, HaarFeatureType.CORE); 059 060 final float invArea = 1f / ((width - 2) * (height - 2)); 061 for (final HaarFeature f : features) { 062 f.setScale(1, invArea); 063 } 064 } 065 066 // void loadPositive(boolean tilted) throws IOException { 067 // final String base = "/Users/jsh2/Data/att_faces/s%d/%d.pgm"; 068 // 069 // for (int j = 1; j <= 40; j++) { 070 // for (int i = 1; i <= 10; i++) { 071 // final File file = new File(String.format(base, j, i)); 072 // 073 // FImage img = ImageUtilities.readF(file); 074 // img = img.extractCenter(50, 50); 075 // img = ResizeProcessor.resample(img, 19, 19); 076 // positive.add(new SummedSqTiltAreaTable(img, tilted)); 077 // } 078 // } 079 // } 080 // 081 // void loadNegative(boolean tilted) throws IOException { 082 // final File dir = new File( 083 // "/Volumes/Raid/face_databases/haartraining/tutorial-haartraining.googlecode.com/svn/trunk/data/negatives/"); 084 // 085 // for (final File f : dir.listFiles()) { 086 // if (f.getName().endsWith(".jpg")) { 087 // FImage img = ImageUtilities.readF(f); 088 // 089 // final int minwh = Math.min(img.width, img.height); 090 // 091 // img = img.extractCenter(minwh, minwh); 092 // img = ResizeProcessor.resample(img, 19, 19); 093 // 094 // negative.add(new SummedSqTiltAreaTable(img, tilted)); 095 // } 096 // } 097 // } 098 099 void loadImage(File image, List<SummedSqTiltAreaTable> sats, boolean 100 tilted) throws IOException 101 { 102 final FImage img = ImageUtilities.readF(image); 103 104 sats.add(new SummedSqTiltAreaTable(img, false)); 105 } 106 107 void loadPositive(boolean tilted) throws IOException { 108 for (final File file : new File("/Users/jsh2/Data/cbcl-faces/train/face").listFiles()) { 109 if (file.getName().endsWith(".pgm")) { 110 loadImage(file, positive, tilted); 111 } 112 } 113 } 114 115 void loadNegative(boolean tilted) throws IOException { 116 for (final File file : new File("/Users/jsh2/Data/cbcl-faces/train/non-face").listFiles()) { 117 if (file.getName().endsWith(".pgm")) { 118 loadImage(file, negative, tilted); 119 } 120 } 121 } 122 123 void perform() throws IOException { 124 System.out.println("Creating feature set"); 125 createFeatures(19, 19); 126 127 System.out.println("Loading positive images and computing SATs"); 128 loadPositive(false); 129 130 System.out.println("Loading negative images and computing SATs"); 131 loadNegative(false); 132 133 System.out.println("+ve: " + positive.size()); 134 System.out.println("-ve: " + negative.size()); 135 System.out.println("features: " + features.size()); 136 137 System.out.println("Computing cached feature sets"); 138 final CachedTrainingData data = new CachedTrainingData(positive, negative, features); 139 140 System.out.println("Starting Training"); 141 final AdaBoost boost = new AdaBoost(); 142 final List<ObjectFloatPair<StumpClassifier>> ensemble = boost.learn(data, 500); 143 144 System.out.println("Training complete. Ensemble has " + ensemble.size() + " classifiers."); 145 146 for (float threshold = 3; threshold >= -3; threshold -= 0.25f) { 147 System.out.println("Threshold = " + threshold); 148 boost.printClassificationQuality(data, ensemble, threshold); 149 } 150 151 final Stage root = createStage(ensemble); 152 final StageTreeClassifier classifier = new StageTreeClassifier(19, 19, "test cascade", false, root); 153 classifier.setScale(1); 154 155 for (int i = 0; i < positive.size(); i++) { 156 if ((classifier.classify(positive.get(i), 0, 0) == 1) != AdaBoost.classify(data.getInstanceFeature(i), 157 ensemble)) 158 System.out.println("ERROR"); 159 } 160 for (int i = 0; i < negative.size(); i++) { 161 if ((classifier.classify(negative.get(i), 0, 0) == 1) != AdaBoost.classify( 162 data.getInstanceFeature(i + positive.size()), ensemble)) 163 { 164 System.out.println(classifier.classify(negative.get(i), 0, 0) + " " + AdaBoost.classify( 165 data.getInstanceFeature(i + positive.size()), ensemble)); 166 System.out.println("ERROR2"); 167 } 168 } 169 170 final ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(new File("test-classifier.bin"))); 171 IOUtils.write(classifier, oos); 172 oos.close(); 173 } 174 175 /** 176 * Create a {@link Stage} from a trained ensemble. 177 * 178 * @param ensemble 179 * the ensemble 180 * @return the stage 181 */ 182 private Stage createStage(final List<ObjectFloatPair<StumpClassifier>> ensemble) { 183 final HaarFeatureClassifier[] trees = new HaarFeatureClassifier[ensemble.size()]; 184 185 for (int i = 0; i < trees.length; i++) { 186 final ObjectFloatPair<StumpClassifier> wc = ensemble.get(i); 187 final StumpClassifier c = wc.first; 188 final float alpha = wc.second; 189 final float threshold = c.threshold; 190 final float leftValue = c.sign > 0 ? -alpha : alpha; 191 final HaarFeature feature = features.get(c.dimension); 192 193 final ValueClassifier left = new ValueClassifier(leftValue); 194 final ValueClassifier right = new ValueClassifier(-leftValue); 195 196 trees[i] = new HaarFeatureClassifier(feature, threshold, left, right); 197 } 198 199 final Stage root = new Stage(0, trees, null, null); 200 return root; 201 } 202 203 public static void main(String[] args) throws IOException { 204 new Testing().perform(); 205 } 206}