001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.image.segmentation;
031
032import java.util.ArrayList;
033import java.util.List;
034
035import org.openimaj.feature.FloatFVComparator;
036import org.openimaj.image.MBFImage;
037import org.openimaj.image.colour.ColourSpace;
038import org.openimaj.image.pixel.PixelSet;
039import org.openimaj.knn.FloatNearestNeighbours;
040import org.openimaj.knn.FloatNearestNeighboursExact;
041import org.openimaj.ml.clustering.FloatCentroidsResult;
042import org.openimaj.ml.clustering.assignment.HardAssigner;
043import org.openimaj.ml.clustering.kmeans.FloatKMeans;
044import org.openimaj.ml.clustering.kmeans.KMeansConfiguration;
045
046/**
047 * Simple image segmentation from grouping colours with k-means.
048 * 
049 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
050 * 
051 */
052public class KMColourSegmenter implements Segmenter<MBFImage> {
053        private static final int DEFAULT_MAX_ITERS = 100;
054        protected ColourSpace colourSpace;
055        protected float[] scaling;
056        protected FloatKMeans kmeans;
057
058        /**
059         * Construct using the given colour space and number of segments. Euclidean
060         * distance is used, and the elements of each colour band are unscaled. Up
061         * to 100 K-Means iterations will be performed.
062         * 
063         * @param colourSpace
064         *            the colour space
065         * @param K
066         *            the number of segments
067         */
068        public KMColourSegmenter(ColourSpace colourSpace, int K) {
069                this(colourSpace, null, K, null, DEFAULT_MAX_ITERS);
070        }
071
072        /**
073         * Construct using the given colour space, number of segments, and distance
074         * measure. The elements of each colour band are unscaled. Up to 100 K-Means
075         * iterations will be performed.
076         * 
077         * @param colourSpace
078         *            the colour space
079         * @param K
080         *            the number of segments
081         * @param distance
082         *            the distance measure
083         */
084        public KMColourSegmenter(ColourSpace colourSpace, int K, FloatFVComparator distance) {
085                this(colourSpace, null, K, distance, DEFAULT_MAX_ITERS);
086        }
087
088        /**
089         * Construct using the given colour space, number of segments, and distance
090         * measure. The elements of each colour band are by the corresponding
091         * elements in the given scaling vector. Up to 100 K-Means iterations will
092         * be performed.
093         * 
094         * @param colourSpace
095         *            the colour space
096         * @param scaling
097         *            the scaling vector
098         * @param K
099         *            the number of segments
100         * @param distance
101         *            the distance measure
102         */
103        public KMColourSegmenter(ColourSpace colourSpace, float[] scaling, int K, FloatFVComparator distance) {
104                this(colourSpace, scaling, K, distance, DEFAULT_MAX_ITERS);
105        }
106
107        /**
108         * Construct using the given colour space, number of segments, and distance
109         * measure. The elements of each colour band are by the corresponding
110         * elements in the given scaling vector, and the k-means algorithm will
111         * iterate at most <code>maxIters</code> times.
112         * 
113         * @param colourSpace
114         *            the colour space
115         * @param scaling
116         *            the scaling vector
117         * @param K
118         *            the number of segments
119         * @param distance
120         *            the distance measure
121         * @param maxIters
122         *            the maximum number of iterations to perform
123         */
124        public KMColourSegmenter(ColourSpace colourSpace, float[] scaling, int K, FloatFVComparator distance, int maxIters) {
125                if (scaling != null && scaling.length < colourSpace.getNumBands())
126                        throw new IllegalArgumentException(
127                                        "Scaling vector must have the same length as the number of dimensions of the target colourspace (or more)");
128
129                this.colourSpace = colourSpace;
130                this.scaling = scaling;
131
132                final KMeansConfiguration<FloatNearestNeighbours, float[]> conf =
133                                new KMeansConfiguration<FloatNearestNeighbours, float[]>(
134                                                K,
135                                                new FloatNearestNeighboursExact.Factory(distance),
136                                                maxIters);
137
138                this.kmeans = new FloatKMeans(conf);
139        }
140
141        protected float[][] imageToVector(MBFImage image) {
142                final int height = image.getHeight();
143                final int width = image.getWidth();
144                final int bands = image.numBands();
145
146                final float[][] f = new float[height * width][bands];
147                for (int b = 0; b < bands; b++) {
148                        final float[][] band = image.getBand(b).pixels;
149                        final float w = scaling == null ? 1 : scaling[b];
150
151                        for (int y = 0; y < height; y++)
152                                for (int x = 0; x < width; x++)
153                                        f[x + y * width][b] = band[y][x] * w;
154                }
155
156                return f;
157        }
158
159        @Override
160        public List<? extends PixelSet> segment(final MBFImage image) {
161                final MBFImage input = ColourSpace.convert(image, colourSpace);
162                final float[][] imageData = imageToVector(input);
163
164                final FloatCentroidsResult result = kmeans.cluster(imageData);
165
166                final List<PixelSet> out = new ArrayList<PixelSet>(kmeans.getConfiguration().getK());
167                for (int i = 0; i < kmeans.getConfiguration().getK(); i++)
168                        out.add(new PixelSet());
169
170                final HardAssigner<float[], ?, ?> assigner = result.defaultHardAssigner();
171                final int height = image.getHeight();
172                final int width = image.getWidth();
173                for (int y = 0, i = 0; y < height; y++) {
174                        for (int x = 0; x < width; x++, i++) {
175                                final float[] pixel = imageData[i];
176                                final int centroid = assigner.assign(pixel);
177
178                                out.get(centroid).addPixel(x, y);
179                        }
180                }
181
182                return out;
183        }
184}