001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.image.objectdetection.haar;
031
032import org.openimaj.citation.annotation.Reference;
033import org.openimaj.citation.annotation.ReferenceType;
034import org.openimaj.image.analysis.algorithm.SummedSqTiltAreaTable;
035
036/**
037 * A tree of classifier stages. In the case that the tree is degenerate and all
038 * {@link Stage}s have <code>null</code> {@link Stage#failureStage()}s, then the
039 * tree is known as a <strong>cascade</strong>.
040 * <p>
041 * The general idea is that for a given window in the image being tested
042 * (defined by an x,y position and scale), the stage tree is evaluated. If when
043 * evaluating the tree a leaf node is hit (i.e. a {@link Stage} that passes
044 * successfully, but has a <code>null</code> {@link Stage#successStage()}) then
045 * the tree is said to have passed, and indicates a potential object detection
046 * within the window. If a {@link Stage} fails to pass and has a
047 * <code>null</code> {@link Stage#failureStage()} then the tree is said to have
048 * failed, indicating the object in question was not found.
049 * <p>
050 * In order to achieve good performance, this implementation pre-computes and
051 * caches variables related to a given detection scale. This means that it is
052 * <strong>NOT safe</strong> to use a detector based on this stage
053 * implementation in a multi-threaded environment such that multiple images are
054 * being tested at a given time. It is however safe to use this implementation
055 * with a detector that multi-threads its detection across the x and y window
056 * positions for a fixed scale:
057 *
058 * <code><pre>
059 *  StageTreeClassifier cascade = ...
060 * 
061 *      for each scale {
062 *              cascade.setScale(scale);
063 * 
064 *              //the x and y search could be threaded...
065 *              for each y {
066 *                      for each x {
067 *                              cascade.matches(sat, x, y); {
068 *                      }
069 *              }
070 * }
071 * </pre></code>
072 *
073 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
074 */
075@Reference(
076                type = ReferenceType.Inproceedings,
077                author = { "Viola, P.", "Jones, M." },
078                title = "Rapid object detection using a boosted cascade of simple features",
079                year = "2001",
080                booktitle = "Computer Vision and Pattern Recognition, 2001. CVPR 2001. Proceedings of the 2001 IEEE Computer Society Conference on",
081                pages = { " I", "511 ", " I", "518 vol.1" },
082                number = "",
083                volume = "1",
084                customData = {
085                                "keywords",
086                                " AdaBoost; background regions; boosted simple feature cascade; classifiers; face detection; image processing; image representation; integral image; machine learning; object specific focus-of-attention mechanism; rapid object detection; real-time applications; statistical guarantees; visual object detection; feature extraction; image classification; image representation; learning (artificial intelligence); object detection;",
087                                "doi", "10.1109/CVPR.2001.990517",
088                                "ISSN", "1063-6919 "
089                })
090public class StageTreeClassifier {
091        /**
092         * The width of the classifier
093         */
094        int width;
095
096        /**
097         * The height of the classifier
098         */
099        int height;
100
101        /**
102         * The name of the classifier
103         */
104        String name;
105
106        /**
107         * Does the classifier contain tilted features?
108         */
109        boolean hasTiltedFeatures;
110
111        /**
112         * The root of the stage tree
113         */
114        Stage root;
115
116        // cached values for the scale being processed
117        float cachedScale; // the current scale
118        float cachedInvArea; // the inverse of the current (scaled) detection window
119        int cachedW; // the width of the current (scaled) detection window
120        int cachedH; // the height of the current (scaled) detection window
121
122        /**
123         * Construct the {@link StageTreeClassifier} with the given parameters.
124         *
125         * @param width
126         *            the width of the classifier
127         * @param height
128         *            the height of the classifier
129         * @param name
130         *            the name of the classifier
131         * @param hasTiltedFeatures
132         *            are there tilted haar-like features in the classifiers?
133         * @param root
134         *            the root of the tree of stages
135         */
136        public StageTreeClassifier(int width, int height, String name, boolean hasTiltedFeatures, Stage root) {
137                this.width = width;
138                this.height = height;
139                this.name = name;
140                this.hasTiltedFeatures = hasTiltedFeatures;
141                this.root = root;
142        }
143
144        float computeWindowVarianceNorm(SummedSqTiltAreaTable sat, int x, int y) {
145                x += Math.round(cachedScale); // shift by 1 scaled px to centre box
146                y += Math.round(cachedScale);
147
148                final float sum = sat.sum.pixels[y + cachedH][x + cachedW] + sat.sum.pixels[y][x] -
149                                sat.sum.pixels[y + cachedH][x] - sat.sum.pixels[y][x + cachedW];
150                final float sqSum = sat.sqSum.pixels[y + cachedH][x + cachedW] + sat.sqSum.pixels[y][x] -
151                                sat.sqSum.pixels[y + cachedH][x] - sat.sqSum.pixels[y][x + cachedW];
152
153                final float mean = sum * cachedInvArea;
154                float wvNorm = sqSum * cachedInvArea - mean * mean;
155                wvNorm = (float) ((wvNorm > 0) ? Math.sqrt(wvNorm) : 1);
156
157                return wvNorm;
158        }
159
160        /**
161         * Set the current detection scale. This must be called before calling
162         * {@link #classify(SummedSqTiltAreaTable, int, int)}.
163         * <p>
164         * Internally, this goes through all the stages and their individual
165         * classifiers and pre-caches information related to the current scale to
166         * avoid lots of expensive recomputation of values that don't change for a
167         * given scale.
168         *
169         * @param scale
170         *            the current scale
171         */
172        public void setScale(float scale) {
173                this.cachedScale = scale;
174
175                // following the OCV code... -2 to make a slightly smaller box within
176                // window
177                cachedW = Math.round(scale * (width - 2));
178                cachedH = Math.round(scale * (height - 2));
179                cachedInvArea = 1.0f / (cachedW * cachedH);
180
181                updateCaches(root);
182        }
183
184        /**
185         * Recursively update the caches of all the stages to reflect the current
186         * scale.
187         *
188         * @param s
189         *            the stage to update
190         */
191        private void updateCaches(Stage s) {
192                s.updateCaches(this);
193
194                if (s.successStage != null)
195                        updateCaches(s.successStage);
196                if (s.failureStage != null)
197                        updateCaches(s.failureStage);
198        }
199
200        /**
201         * Apply the classifier to the given image at the given position.
202         * Internally, this will apply each stage to the image. If all stages
203         * complete successfully a detection is indicated.
204         * <p>
205         * This method returns the number of stages passed if all stages pass; if a
206         * stage fails, then (-1 * number of successful stages) is returned. For
207         * example a value of 20 indicates the successful detection from a total of
208         * 20 stages, whilst -10 indicates an unsuccessful detection due to a
209         * failure on the 11th stage.
210         *
211         * @param sat
212         *            the summed area table(s) for the image in question. If there
213         *            are tilted features, this must include the tilted SAT.
214         * @param x
215         *            the x-ordinate of the top-left of the current window
216         * @param y
217         *            the y-ordinate of the top-left of the current window
218         * @return > 0 if a detection was made; <=0 if no detection was made. The
219         *         magnitude indicates the number of stages that passed.
220         */
221        public int classify(SummedSqTiltAreaTable sat, int x, int y) {
222                final float wvNorm = computeWindowVarianceNorm(sat, x, y);
223
224                // all stages need to match for this cascade to match
225                int matches = 0; // the number of stages that pass
226                Stage stage = root;
227                while (true) { // until success or failure
228                        if (stage.pass(sat, wvNorm, x, y)) {
229                                matches++;
230                                stage = stage.successStage;
231                                if (stage == null) {
232                                        return matches;
233                                }
234                        } else {
235                                stage = stage.failureStage;
236                                if (stage == null) {
237                                        return -matches;
238                                }
239                        }
240                }
241        }
242
243        /**
244         * Get the classifier width
245         *
246         * @return the width
247         */
248        public int getWidth() {
249                return width;
250        }
251
252        /**
253         * Get the classifier height
254         *
255         * @return the height
256         */
257        public int getHeight() {
258                return height;
259        }
260
261        /**
262         * Get the classifier name
263         *
264         * @return the name
265         */
266        public String getName() {
267                return name;
268        }
269
270        /**
271         * Does the classifier use tilted haar-like features?
272         *
273         * @return true if tilted features are used; false otherwise.
274         */
275        public boolean hasTiltedFeatures() {
276                return hasTiltedFeatures;
277        }
278
279        /**
280         * Get the root {@link Stage} of the classifier
281         *
282         * @return the root
283         */
284        public Stage getRoot() {
285                return root;
286        }
287}