001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.image.segmentation;
031
import gnu.trove.map.hash.TObjectFloatHashMap;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.openimaj.citation.annotation.Reference;
import org.openimaj.citation.annotation.ReferenceType;
import org.openimaj.image.FImage;
import org.openimaj.image.Image;
import org.openimaj.image.MBFImage;
import org.openimaj.image.pixel.ConnectedComponent;
import org.openimaj.image.pixel.Pixel;
import org.openimaj.image.processing.convolution.FGaussianConvolve;
import org.openimaj.image.processor.SinglebandImageProcessor;
import org.openimaj.math.graph.SimpleWeightedEdge;
import org.openimaj.util.set.DisjointSetForest;
050
051/**
052 * Implementation of the segmentation algorithm described in:
053 * Efficient Graph-Based Image Segmentation
054 * Pedro F. Felzenszwalb and Daniel P. Huttenlocher
055 * International Journal of Computer Vision, 59(2) September 2004.
056 * 
057 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
058 * @param <I> Type of {@link Image}
059 */
060@Reference(
061                type = ReferenceType.Article,
062                author = {"Felzenszwalb, Pedro F.", "Huttenlocher, Daniel P."},
063                title = "Efficient Graph-Based Image Segmentation",
064                journal = "Int. J. Comput. Vision",
065                volume = "59",
066                number = "2",
067                month = "September",
068                year = "2004",
069                pages = {"167","181"},
070                url = "http://dx.doi.org/10.1023/B:VISI.0000022288.19776.77",
071                publisher = "Kluwer Academic Publishers"
072)
073public class FelzenszwalbHuttenlocherSegmenter<I extends Image<?,I> & SinglebandImageProcessor.Processable<Float, FImage, I>> implements Segmenter<I> {
074        protected float sigma = 0.5f;
075        protected float k = 500f / 255f;
076        protected int minSize = 50;
077
078        /**
079         * Default constructor
080         */
081        public FelzenszwalbHuttenlocherSegmenter() {}
082
083        /**
084         * Construct with the given parameters
085         * @param sigma amount of blurring
086         * @param k threshold
087         * @param minSize minimum allowed component size
088         */
089        public FelzenszwalbHuttenlocherSegmenter(float sigma, float k, int minSize) {
090                this.sigma = sigma;
091                this.k = k;
092                this.minSize = minSize;
093        }
094
095        @Override
096        public List<ConnectedComponent> segment(I image) {
097                if (((Object)image) instanceof MBFImage) {
098                        return segmentImage((MBFImage)((Object)image));
099                } else {
100                        return segmentImage(new MBFImage((FImage)((Object)image)));
101                }
102        }
103
104        private float diff(MBFImage image, Pixel p1, Pixel p2) {
105                float sum = 0;
106
107                for (FImage band : image.bands) {
108                        float d = band.pixels[p1.y][p1.x] - band.pixels[p2.y][p2.x];
109                        sum += d*d;
110                }
111
112                return (float) Math.sqrt(sum);
113        }
114
115        protected List<ConnectedComponent> segmentImage(MBFImage im) {
116                int width = im.getWidth();
117                int height = im.getHeight();
118
119                MBFImage smooth = im.process(new FGaussianConvolve(sigma));
120
121                // build graph
122                List<SimpleWeightedEdge<Pixel>> edges = new ArrayList<SimpleWeightedEdge<Pixel>>();
123                for (int y = 0; y < height; y++) {
124                        for (int x = 0; x < width; x++) {
125                                if (x < width-1) {
126                                        SimpleWeightedEdge<Pixel> p = new SimpleWeightedEdge<Pixel>();
127                                        p.from = new Pixel(x, y);
128                                        p.to = new Pixel(x+1, y);
129                                        p.weight = diff(smooth, p.from, p.to);
130                                        edges.add(p);
131                                }
132
133                                if (y < height-1) {
134                                        SimpleWeightedEdge<Pixel> p = new SimpleWeightedEdge<Pixel>();
135                                        p.from = new Pixel(x, y);
136                                        p.to = new Pixel(x, y+1);
137                                        p.weight = diff(smooth, p.from, p.to);
138                                        edges.add(p);
139                                }
140
141                                if ((x < width-1) && (y < height-1)) {
142                                        SimpleWeightedEdge<Pixel> p = new SimpleWeightedEdge<Pixel>();
143                                        p.from = new Pixel(x, y);
144                                        p.to = new Pixel(x+1, y+1);
145                                        p.weight = diff(smooth, p.from, p.to);
146                                        edges.add(p);
147                                }
148
149                                if ((x < width-1) && (y > 0)) {
150                                        SimpleWeightedEdge<Pixel> p = new SimpleWeightedEdge<Pixel>();
151                                        p.from = new Pixel(x, y);
152                                        p.to = new Pixel(x+1, y-1);
153                                        p.weight = diff(smooth, p.from, p.to);
154                                        edges.add(p);
155                                }
156                        }
157                }
158
159                // segment
160                DisjointSetForest<Pixel> u = segmentGraph(width*height, edges);
161
162
163                // post process small components
164                for (int i = 0; i < edges.size(); i++) {
165                        Pixel a = u.find(edges.get(i).from);
166                        Pixel b = u.find(edges.get(i).to);
167
168                        if ((a != b) && ((u.size(a) < minSize) || (u.size(b) < minSize)))
169                                u.union(a, b);
170                }
171
172                Set<Set<Pixel>> subsets = u.getSubsets();
173                List<ConnectedComponent> ccs = new ArrayList<ConnectedComponent>();
174                for (Set<Pixel> sp : subsets) ccs.add(new ConnectedComponent(sp));
175
176                return ccs;
177        }
178
179        protected DisjointSetForest<Pixel> segmentGraph(int numVertices, List<SimpleWeightedEdge<Pixel>> edges) { 
180                // sort edges by weight
181                Collections.sort(edges, SimpleWeightedEdge.ASCENDING_COMPARATOR);
182
183                // make a disjoint-set forest
184                DisjointSetForest<Pixel> u = new DisjointSetForest<Pixel>(numVertices);
185
186                for (SimpleWeightedEdge<Pixel> edge : edges) {
187                        u.add(edge.from);
188                        u.add(edge.to);
189                }
190
191                // init thresholds
192                TObjectFloatHashMap<Pixel> threshold = new TObjectFloatHashMap<Pixel>();
193                for (Pixel p : u) {
194                        threshold.put(p, k);
195                }
196
197                // for each edge, in non-decreasing weight order...
198                for (int i = 0; i < edges.size(); i++) {
199                        SimpleWeightedEdge<Pixel> pedge = edges.get(i);
200
201                        // components connected by this edge
202                        Pixel a = u.find(pedge.from);
203                        Pixel b = u.find(pedge.to);
204                        if (a != b) {
205                                if ((pedge.weight <= threshold.get(a)) && (pedge.weight <= threshold.get(b))) {
206                                        a = u.union(a, b);
207                                        threshold.put(a, pedge.weight + (k / u.size(a)));
208                                }
209                        }
210                }
211
212                return u;
213        }
214}