001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.image.objectdetection.hog;
031
032import java.io.File;
033import java.io.IOException;
034import java.util.AbstractList;
035import java.util.ArrayList;
036import java.util.Arrays;
037import java.util.List;
038
039import org.openimaj.data.RandomData;
040import org.openimaj.data.dataset.GroupedDataset;
041import org.openimaj.data.dataset.ListBackedDataset;
042import org.openimaj.data.dataset.ListDataset;
043import org.openimaj.data.dataset.MapBackedDataset;
044import org.openimaj.feature.DatasetExtractors;
045import org.openimaj.feature.DoubleFV;
046import org.openimaj.feature.FeatureExtractor;
047import org.openimaj.feature.IdentityFeatureExtractor;
048import org.openimaj.image.FImage;
049import org.openimaj.image.ImageUtilities;
050import org.openimaj.image.feature.dense.gradient.HOG;
051import org.openimaj.image.feature.dense.gradient.binning.FlexibleHOGStrategy;
052import org.openimaj.image.objectdetection.datasets.INRIAPersonDataset;
053import org.openimaj.image.processing.convolution.FImageGradients;
054import org.openimaj.io.IOUtils;
055import org.openimaj.math.geometry.shape.Rectangle;
056import org.openimaj.math.statistics.distribution.Histogram;
057import org.openimaj.ml.annotation.linear.LiblinearAnnotator;
058import org.openimaj.ml.annotation.linear.LiblinearAnnotator.Mode;
059import org.openimaj.util.list.AcceptingListView;
060import org.openimaj.util.list.ConcatenatedList;
061import org.openimaj.util.pair.IntObjectPair;
062
063import de.bwaldvogel.liblinear.SolverType;
064
065public class Training {
066        static class Extractor implements FeatureExtractor<DoubleFV, FImage> {
067                HOGClassifier hogClassifier;
068
069                Extractor(HOGClassifier hogClassifier) {
070                        this.hogClassifier = hogClassifier;
071                }
072
073                @Override
074                public DoubleFV extractFeature(FImage image) {
075                        final int offsetX = (image.width - 64) / 2;
076                        final int offsetY = (image.height - 128) / 2;
077                        hogClassifier.hogExtractor.analyseImage(image);
078
079                        final Histogram f = hogClassifier.hogExtractor.getFeatureVector(new Rectangle(offsetX,
080                                        offsetY, 64, 128));
081
082                        return f;
083                }
084        }
085
086        public static void main(String[] args) throws IOException {
087                final HOGClassifier hogClassifier = new HOGClassifier();
088                hogClassifier.width = 64;
089                hogClassifier.height = 128;
090
091                final FlexibleHOGStrategy strategy = new FlexibleHOGStrategy(8, 16, 2);
092                hogClassifier.hogExtractor = new HOG(9, false, FImageGradients.Mode.Unsigned, strategy);
093
094                final GroupedDataset<Boolean, ListDataset<FImage>, FImage> trainingImages = INRIAPersonDataset.getTrainingData();
095                final GroupedDataset<Boolean, ListDataset<DoubleFV>, DoubleFV> trainingData = DatasetExtractors
096                                .createLazyFeatureDataset(trainingImages, new Extractor(hogClassifier));
097
098                LiblinearAnnotator<DoubleFV, Boolean> ann = new LiblinearAnnotator<DoubleFV, Boolean>(
099                                new IdentityFeatureExtractor<DoubleFV>(), Mode.MULTICLASS, SolverType.L2R_L2LOSS_SVC, 0.01, 0.01, 1, true);
100                ann.train(trainingData);
101                hogClassifier.classifier = ann;
102
103                IOUtils.writeToFile(hogClassifier, new File("initial-classifier.dat"));
104
105                final HOGDetector detector = new HOGDetector(hogClassifier, 1.2f);
106
107                final ListDataset<FImage> negImages =
108                                INRIAPersonDataset.getNegativeTrainingImages(ImageUtilities.FIMAGE_READER);
109                final List<IntObjectPair<Rectangle>> extraNegatives = new
110                                ArrayList<IntObjectPair<Rectangle>>();
111                for (int i = 0; i < negImages.numInstances(); i++) {
112                        final FImage image = negImages.get(i);
113
114                        final List<Rectangle> rects = detector.detect(image);
115                        if (rects != null) {
116                                for (final Rectangle r : rects) {
117                                        extraNegatives.add(new IntObjectPair<Rectangle>(i, r));
118                                }
119                        }
120                }
121
122                List<FImage> hardExamples = new AbstractList<FImage>() {
123
124                        int lastImageId = -1;
125                        FImage lastImage;
126
127                        @Override
128                        public FImage get(int index) {
129                                final IntObjectPair<Rectangle> p = extraNegatives.get(index);
130
131                                if (p.first != lastImageId) {
132                                        lastImageId = p.first;
133                                        lastImage = negImages.get(p.first);
134                                }
135
136                                return lastImage.extractROI(p.second);
137                        }
138
139                        @Override
140                        public int size() {
141                                return extraNegatives.size();
142                        }
143                };
144
145                final int[] indices = RandomData.getUniqueRandomInts(2000, 0,
146                                hardExamples.size());
147                Arrays.sort(indices);
148                hardExamples = new AcceptingListView<FImage>(hardExamples, indices);
149
150                final List<FImage> extendedNegatives = new
151                                ConcatenatedList<FImage>(trainingImages.get(false), hardExamples);
152                final GroupedDataset<Boolean, ListDataset<FImage>, FImage> extendedTrainingImages = new MapBackedDataset<Boolean,
153                                ListDataset<FImage>, FImage>();
154                extendedTrainingImages.put(true, trainingImages.get(true));
155                extendedTrainingImages.put(false, new
156                                ListBackedDataset<FImage>(extendedNegatives));
157
158                final GroupedDataset<Boolean, ListDataset<DoubleFV>, DoubleFV> extendedTrainingData = DatasetExtractors
159                                .createLazyFeatureDataset(extendedTrainingImages, new
160                                                Extractor(hogClassifier));
161
162                ann = new LiblinearAnnotator<DoubleFV, Boolean>(
163                                new IdentityFeatureExtractor<DoubleFV>(), Mode.MULTICLASS,
164                                SolverType.L2R_L2LOSS_SVC, 0.01, 0.01, 1, true);
165                ann.train(extendedTrainingData);
166                hogClassifier.classifier = ann;
167
168                int c = 0, p = 0;
169                for (final FImage i : INRIAPersonDataset.getPositiveTrainingImages(ImageUtilities.FIMAGE_READER)) {
170                        hogClassifier.prepare(i);
171
172                        final int offsetX = (i.width - 64) / 2;
173                        final int offsetY = (i.height - 128) / 2;
174
175                        p += hogClassifier.classify(new Rectangle(offsetX, offsetY, 64, 128)) > 0.5 ? 1 : 0;
176                        c++;
177                }
178                System.out.println(p + "/" + c);
179
180                IOUtils.writeToFile(hogClassifier, new File("final-classifier.dat"));
181        }
182}