001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.image.objectdetection.haar;
031
032import java.io.IOException;
033import java.io.InputStream;
034import java.util.ArrayDeque;
035import java.util.ArrayList;
036import java.util.Arrays;
037import java.util.Deque;
038import java.util.List;
039
040import org.xmlpull.v1.XmlPullParser;
041import org.xmlpull.v1.XmlPullParserException;
042import org.xmlpull.v1.XmlPullParserFactory;
043
044/**
045 * Support for reading OpenCV Haar Cascade XML files. Currently only supports
046 * the old-style format.
047 * 
048 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
049 */
050public class OCVHaarLoader {
051        private static final float ICV_STAGE_THRESHOLD_BIAS = 0.0001f;
052
053        private static final String NEXT_NODE = "next";
054        private static final String PARENT_NODE = "parent";
055        private static final String STAGE_THRESHOLD_NODE = "stage_threshold";
056        private static final String ANONYMOUS_NODE = "_";
057        private static final String RIGHT_NODE_NODE = "right_node";
058        private static final String RIGHT_VAL_NODE = "right_val";
059        private static final String LEFT_NODE_NODE = "left_node";
060        private static final String LEFT_VAL_NODE = "left_val";
061        private static final String THRESHOLD_NODE = "threshold";
062        private static final String TILTED_NODE = "tilted";
063        private static final String RECTS_NODE = "rects";
064        private static final String FEATURE_NODE = "feature";
065        private static final String TREES_NODE = "trees";
066        private static final String STAGES_NODE = "stages";
067        private static final String SIZE_NODE = "size";
068        private static final String OCV_STORAGE_NODE = "opencv_storage";
069
070        static class TreeNode {
071                HaarFeature feature;
072                float threshold;
073                float left_val;
074                float right_val;
075                int left_node = -1;
076                int right_node = -1;
077        }
078
079        static class StageNode {
080                private int parent = -1;
081                private int next = -1;
082                private float threshold;
083                private List<List<TreeNode>> trees = new ArrayList<List<TreeNode>>();
084        }
085
086        static class OCVHaarClassifierNode {
087                int width;
088                int height;
089                String name;
090                boolean hasTiltedFeatures = false;
091                List<StageNode> stages = new ArrayList<StageNode>();
092        }
093
094        /**
095         * Read using an XML Pull Parser. This requires the exact format of the xml
096         * is consistent (i.e. element order is consistent). Checks are made at each
097         * node to ensure that we're reading the correct data.
098         * 
099         * @param in
100         *            the InputStream to consume
101         * @return the parsed cascade
102         * @throws IOException
103         */
104        static OCVHaarClassifierNode readXPP(InputStream in) throws IOException {
105                try {
106                        final XmlPullParserFactory factory = XmlPullParserFactory.newInstance();
107                        final XmlPullParser reader = factory.newPullParser();
108
109                        reader.setInput(in, null);
110
111                        reader.nextTag(); // opencv_storage
112                        checkNode(reader, OCV_STORAGE_NODE);
113
114                        reader.nextTag(); // haarcascade_{type}
115                        if (!"opencv-haar-classifier".equals(reader.getAttributeValue(null, "type_id"))) {
116                                throw new IOException("Unsupported format: " + reader.getAttributeValue(null, "type_id"));
117                        }
118
119                        final OCVHaarClassifierNode root = new OCVHaarClassifierNode();
120                        root.name = reader.getName();
121
122                        reader.nextTag(); // <size>
123                        checkNode(reader, SIZE_NODE);
124
125                        final String sizeStr = reader.nextText();
126                        final String[] widthHeight = sizeStr.trim().split(" ");
127                        if (widthHeight.length != 2) {
128                                throw new IOException("expecting 'w h' for size element, got: " + sizeStr);
129                        }
130
131                        root.width = Integer.parseInt(widthHeight[0]);
132                        root.height = Integer.parseInt(widthHeight[1]);
133
134                        reader.nextTag(); // <stages>
135                        checkNode(reader, STAGES_NODE);
136
137                        // parse stage tags
138                        while (reader.nextTag() == XmlPullParser.START_TAG) { // <_>
139                                checkNode(reader, ANONYMOUS_NODE);
140
141                                final StageNode currentStage = new StageNode();
142                                root.stages.add(currentStage);
143
144                                reader.nextTag(); // <trees>
145                                checkNode(reader, TREES_NODE);
146
147                                while (reader.nextTag() == XmlPullParser.START_TAG) { // <_>
148                                        checkNode(reader, ANONYMOUS_NODE);
149
150                                        final List<TreeNode> currentTree = new ArrayList<TreeNode>();
151                                        currentStage.trees.add(currentTree);
152
153                                        while (reader.nextTag() == XmlPullParser.START_TAG) { // <_>
154                                                checkNode(reader, ANONYMOUS_NODE);
155
156                                                final List<WeightedRectangle> regions = new ArrayList<WeightedRectangle>(3);
157
158                                                reader.nextTag(); // <feature>
159                                                checkNode(reader, FEATURE_NODE);
160
161                                                reader.nextTag(); // <rects>
162                                                checkNode(reader, RECTS_NODE);
163
164                                                while (reader.nextTag() == XmlPullParser.START_TAG) { // <_>
165                                                        checkNode(reader, ANONYMOUS_NODE);
166                                                        regions.add(WeightedRectangle.parse(reader.nextText()));
167                                                }
168
169                                                reader.nextTag(); // <tilted>
170                                                checkNode(reader, TILTED_NODE);
171                                                final boolean tilted = "1".equals(reader.nextText());
172
173                                                if (tilted)
174                                                        root.hasTiltedFeatures = true;
175
176                                                reader.nextTag(); // </feature>
177                                                checkNode(reader, FEATURE_NODE);
178
179                                                final HaarFeature currentFeature = HaarFeature.create(regions, tilted);
180
181                                                reader.nextTag(); // <threshold>
182                                                checkNode(reader, THRESHOLD_NODE);
183                                                final float threshold = (float) Double.parseDouble(reader.nextText());
184
185                                                final TreeNode treeNode = new TreeNode();
186                                                treeNode.threshold = threshold;
187                                                treeNode.feature = currentFeature;
188
189                                                reader.nextTag(); // <left_val> || <left_node>
190                                                checkNode(reader, LEFT_VAL_NODE, LEFT_NODE_NODE);
191                                                final String leftText = reader.nextText();
192                                                if ("left_val".equals(reader.getName())) {
193                                                        treeNode.left_val = Float.parseFloat(leftText);
194                                                } else {
195                                                        // find leftIndexed classifier
196                                                        treeNode.left_node = Integer.parseInt(leftText);
197                                                }
198                                                reader.nextTag(); // <right_val> || <right_node>
199                                                checkNode(reader, RIGHT_VAL_NODE, RIGHT_NODE_NODE);
200                                                final String rightText = reader.nextText();
201                                                if ("right_val".equals(reader.getName())) {
202                                                        treeNode.right_val = Float.parseFloat(rightText);
203                                                } else {
204                                                        // find right indexed classifier (put off the lookup
205                                                        // until later)
206                                                        treeNode.right_node = Integer.parseInt(rightText);
207                                                }
208
209                                                reader.nextTag(); // </_>
210                                                checkNode(reader, ANONYMOUS_NODE);
211                                                currentTree.add(treeNode);
212                                        }
213                                }
214
215                                reader.nextTag(); // <stage_threshold>
216                                checkNode(reader, STAGE_THRESHOLD_NODE);
217                                currentStage.threshold = Float.parseFloat(reader.nextText()) - ICV_STAGE_THRESHOLD_BIAS;
218
219                                reader.nextTag(); // <parent>
220                                checkNode(reader, PARENT_NODE);
221                                currentStage.parent = Integer.parseInt(reader.nextText());
222
223                                reader.nextTag(); // <next>
224                                checkNode(reader, NEXT_NODE);
225                                currentStage.next = Integer.parseInt(reader.nextText());
226
227                                reader.nextTag(); // </_>
228                                checkNode(reader, ANONYMOUS_NODE);
229                        }
230
231                        return root;
232                } catch (final XmlPullParserException ex) {
233                        throw new IOException(ex);
234                }
235        }
236
237        /**
238         * Read the cascade from an OpenCV xml serialisation. Currently this only
239         * supports the old-style cascade xml.
240         * 
241         * @param is
242         *            the stream to read from
243         * @return the cascade object
244         * @throws IOException
245         */
246        public static StageTreeClassifier read(InputStream is) throws IOException {
247                final OCVHaarClassifierNode root = readXPP(is);
248
249                return buildCascade(root);
250        }
251
252        private static StageTreeClassifier buildCascade(OCVHaarClassifierNode root) throws IOException {
253                return new StageTreeClassifier(root.width, root.height, root.name, root.hasTiltedFeatures,
254                                buildStages(root.stages));
255        }
256
257        private static Stage buildStages(List<StageNode> stageNodes) throws IOException {
258                final Stage[] stages = new Stage[stageNodes.size()];
259                for (int i = 0; i < stages.length; i++) {
260                        final StageNode node = stageNodes.get(i);
261
262                        stages[i] = new Stage(node.threshold, buildClassifiers(node.trees), null, null);
263                }
264
265                Stage root = null;
266                boolean isCascade = true;
267                for (int i = 0; i < stages.length; i++) {
268                        final StageNode node = stageNodes.get(i);
269
270                        if (node.parent == -1 && node.next == -1) {
271                                if (root == null) {
272                                        root = stages[i];
273                                } else {
274                                        throw new IOException("Inconsistent cascade/tree: multiple roots found");
275                                }
276                        }
277
278                        if (node.parent != -1) {
279                                // if it's a tree, multiple nodes might have the same parent,
280                                // but the first one we see should set the successStage
281                                if (stages[node.parent].successStage == null) {
282                                        stages[node.parent].successStage = stages[i];
283                                }
284                        }
285
286                        if (node.next != -1) {
287                                isCascade = false; // it's a tree
288                                stages[i].failureStage = stages[node.next];
289                        }
290                }
291
292                if (!isCascade) {
293                        optimiseTree(root);
294                }
295
296                return root;
297        }
298
299        /**
300         * Any failure along a success branch after a node that has a failure node
301         * should result in that failure nodes branch being executed. In order to
302         * simplify the iteration through the tree, we link all failure nodes of
303         * children of the success branch to the appropriate failure node. For
304         * example,
305         * 
306         * <pre>
307         *      3 -> 4 -> 5 -> 7 -> 9
308         *            |
309         *            | (failure) 
310         *           \ /
311         *            6 -> 8 -> 10
312         * </pre>
313         * 
314         * becomes:
315         * 
316         * <pre>
317         *      3 -> 4 -> 5 -> 7 -> 9
318         *            |    |     |
319         *            +----/-----/
320         *            |
321         *           \ /
322         *            6 -> 8 -> 10
323         * </pre>
324         * 
325         * Note: implementation based on Matt Nathan's Java port of the OpenCV Haar
326         * code.
327         * 
328         * @param root
329         *            the root of the tree
330         */
331        private static void optimiseTree(Stage root) {
332                final Deque<Stage> stack = new ArrayDeque<Stage>();
333                stack.push(root);
334
335                Stage failureStage = null;
336                while (!stack.isEmpty()) {
337                        final Stage stage = stack.pop();
338
339                        if (stage.failureStage == null) {
340                                // child of failure branch
341                                stage.failureStage = failureStage;
342
343                                if (stage.successStage != null) {
344                                        stack.push(stage.successStage);
345                                }
346                        } else if (stage.failureStage != failureStage) {
347                                // new failure branch
348                                stack.push(stage);
349
350                                failureStage = stage.failureStage;
351
352                                if (stage.successStage != null) {
353                                        stack.push(stage.successStage);
354                                }
355                        } else {
356                                // old failure branch
357                                stack.push(stage.failureStage);
358
359                                failureStage = null;
360                        }
361                }
362        }
363
364        private static Classifier[] buildClassifiers(final List<List<TreeNode>> trees) {
365                final Classifier[] classifiers = new Classifier[trees.size()];
366
367                for (int i = 0; i < classifiers.length; i++) {
368                        classifiers[i] = buildClassifier(trees.get(i));
369                }
370
371                return classifiers;
372        }
373
374        private static Classifier buildClassifier(final List<TreeNode> tree) {
375                return buildClassifier(tree, tree.get(0));
376        }
377
378        private static Classifier buildClassifier(final List<TreeNode> tree, TreeNode current) {
379                final HaarFeatureClassifier fc = new HaarFeatureClassifier(current.feature, current.threshold, null, null);
380
381                if (current.left_node == -1) {
382                        fc.left = new ValueClassifier(current.left_val);
383                } else {
384                        fc.left = buildClassifier(tree, tree.get(current.left_node));
385                }
386
387                if (current.right_node == -1) {
388                        fc.right = new ValueClassifier(current.right_val);
389                } else {
390                        fc.right = buildClassifier(tree, tree.get(current.right_node));
391                }
392
393                return fc;
394        }
395
396        private static void checkNode(XmlPullParser reader, String... expected) throws IOException {
397                for (final String e : expected)
398                        if (e.equals(reader.getName()))
399                                return;
400
401                throw new IOException("Unexpected tag: " + reader.getName() + " (expected: " + Arrays.toString(expected) + ")");
402        }
403}