001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.clustering.kmeans;
031
032import java.lang.reflect.Array;
033import java.util.ArrayList;
034import java.util.Arrays;
035import java.util.List;
036import java.util.Random;
037import java.util.concurrent.Callable;
038import java.util.concurrent.ExecutorService;
039
040import org.openimaj.data.ArrayBackedDataSource;
041import org.openimaj.data.DataSource;
042import org.openimaj.feature.FeatureVector;
043import org.openimaj.knn.NearestNeighboursFactory;
044import org.openimaj.knn.ObjectNearestNeighbours;
045import org.openimaj.knn.ObjectNearestNeighboursExact;
046import org.openimaj.knn.ObjectNearestNeighboursProvider;
047import org.openimaj.ml.clustering.FeatureVectorCentroidsResult;
048import org.openimaj.ml.clustering.IndexClusters;
049import org.openimaj.ml.clustering.SpatialClusterer;
050import org.openimaj.ml.clustering.assignment.HardAssigner;
051import org.openimaj.ml.clustering.assignment.hard.ExactFeatureVectorAssigner;
052import org.openimaj.util.comparator.DistanceComparator;
053import org.openimaj.util.pair.IntFloatPair;
054
055import com.rits.cloning.Cloner;
056
057/**
058 * Fast, parallel implementation of the K-Means algorithm with support for
059 * bigger-than-memory data. Various flavors of K-Means are supported through the
060 * selection of different subclasses of {@link ObjectNearestNeighbours}; for
061 * example, exact K-Means can be achieved using an
062 * {@link ObjectNearestNeighboursExact}. The specific choice of
063 * nearest-neighbour object is controlled through the
064 * {@link NearestNeighboursFactory} provided to the {@link KMeansConfiguration}
065 * used to construct instances of this class. The choice of
066 * {@link ObjectNearestNeighbours} affects the speed of clustering; using
067 * approximate nearest-neighbours algorithms for the K-Means can produces
068 * comparable results to the exact KMeans algorithm in much shorter time. The
069 * choice and configuration of {@link ObjectNearestNeighbours} can also control
070 * the type of distance function being used in the clustering.
071 * <p>
072 * The algorithm is implemented as follows: Clustering is initiated using a
 * {@link FeatureVectorKMeansInit} and is iterative. In each round, batches of samples
074 * are assigned to centroids in parallel. The centroid assignment is performed
075 * using the pre-configured {@link ObjectNearestNeighbours} instances created
076 * from the {@link KMeansConfiguration}. Once all samples are assigned new
077 * centroids are calculated and the next round started. Data point pushing is
078 * performed using the same techniques as center point assignment.
079 * <p>
080 * This implementation is able to deal with larger-than-memory datasets by
081 * streaming the samples from disk using an appropriate {@link DataSource}. The
082 * only requirement is that there is enough memory to hold all the centroids
083 * plus working memory for the batches of samples being assigned.
084 *
085 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
086 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
087 *
088 * @param <T>
089 *            Type of object being clustered
090 */
091public class FeatureVectorKMeans<T extends FeatureVector>
092                implements
093                        SpatialClusterer<FeatureVectorCentroidsResult<T>, T>
094{
095        private static class CentroidAssignmentJob<T extends FeatureVector> implements Callable<Boolean> {
096                private final DataSource<T> ds;
097                private final int startRow;
098                private final int stopRow;
099                private final ObjectNearestNeighbours<T> nno;
100                private final double[][] centroids_accum;
101                private final int[] counts;
102
103                public CentroidAssignmentJob(DataSource<T> ds, int startRow, int stopRow, ObjectNearestNeighbours<T> nno,
104                                double[][] centroids_accum, int[] counts)
105                {
106                        this.ds = ds;
107                        this.startRow = startRow;
108                        this.stopRow = stopRow;
109                        this.nno = nno;
110                        this.centroids_accum = centroids_accum;
111                        this.counts = counts;
112                }
113
114                @Override
115                public Boolean call() {
116                        try {
117                                final int D = ds.getData(0).length();
118
119                                final T[] points = ds.createTemporaryArray(stopRow - startRow);
120                                ds.getData(startRow, stopRow, points);
121
122                                final int[] argmins = new int[points.length];
123                                final float[] mins = new float[points.length];
124
125                                nno.searchNN(points, argmins, mins);
126
127                                synchronized (centroids_accum) {
128                                        for (int i = 0; i < points.length; ++i) {
129                                                final int k = argmins[i];
130                                                final double[] vector = points[i].asDoubleVector();
131                                                for (int d = 0; d < D; ++d) {
132                                                        centroids_accum[k][d] += vector[d];
133                                                }
134                                                counts[k] += 1;
135                                        }
136                                }
137                        } catch (final Exception e) {
138                                e.printStackTrace();
139                        }
140                        return true;
141                }
142        }
143
144        /**
145         * Result object for FeatureVectorKMeans, extending
146         * FeatureVectorCentroidsResult and ObjectNearestNeighboursProvider, as well
147         * as giving access to state information from the operation of the K-Means
148         * algorithm (i.e. number of iterations, and convergence state).
149         *
150         * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
151         * @param <T>
152         *            Type of object being clustered
153         */
154        public static class Result<T extends FeatureVector> extends FeatureVectorCentroidsResult<T>
155                        implements
156                                ObjectNearestNeighboursProvider<T>
157        {
158                protected ObjectNearestNeighbours<T> nn;
159                protected int iterations;
160                protected int changedCentroidCount;
161
162                @Override
163                public ObjectNearestNeighbours<T> getNearestNeighbours() {
164                        return nn;
165                }
166
167                @Override
168                public HardAssigner<T, float[], IntFloatPair> defaultHardAssigner() {
169                        return new ExactFeatureVectorAssigner<T>(this, nn.distanceComparator());
170                }
171
172                /**
173                 * Get the number of K-Means iterations that produced this result.
174                 *
175                 * @return the number of iterations
176                 */
177                public int numIterations() {
178                        return iterations;
179                }
180
181                /**
182                 * Get the number of changed centroids in the last iteration. This is an
183                 * indicator of convergence as over time this should reduce to 0.
184                 *
185                 * @return the number of changed centroids
186                 */
187                public int numChangedCentroids() {
188                        return changedCentroidCount;
189                }
190        }
191
192        private FeatureVectorKMeansInit<T> init = new FeatureVectorKMeansInit.RANDOM<T>();
193        private KMeansConfiguration<ObjectNearestNeighbours<T>, T> conf;
194        private Random rng = new Random();
195
196        /**
197         * Construct the clusterer with the the given configuration.
198         *
199         * @param conf
200         *            The configuration.
201         */
202        public FeatureVectorKMeans(KMeansConfiguration<ObjectNearestNeighbours<T>, T> conf) {
203                this.conf = conf;
204        }
205
206        /**
207         * A completely default {@link ByteKMeans} used primarily as a convenience
208         * function for reading.
209         */
210        protected FeatureVectorKMeans() {
211                this(new KMeansConfiguration<ObjectNearestNeighbours<T>, T>());
212        }
213
214        /**
215         * Get the current initialisation algorithm
216         *
217         * @return the init algorithm being used
218         */
219        public FeatureVectorKMeansInit<T> getInit() {
220                return init;
221        }
222
223        /**
224         * Set the current initialisation algorithm
225         *
226         * @param init
227         *            the init algorithm to be used
228         */
229        public void setInit(FeatureVectorKMeansInit<T> init) {
230                this.init = init;
231        }
232
233        /**
234         * Set the seed for the internal random number generator.
235         *
236         * @param seed
237         *            the random seed for init random sample selection, no seed if
238         *            seed < -1
239         */
240        public void seed(long seed) {
241                if (seed < 0)
242                        this.rng = new Random();
243                else
244                        this.rng = new Random(seed);
245        }
246
247        /**
248         * Perform clustering on the given data.
249         *
250         * @param data
251         *            the data.
252         *
253         * @return the generated clusters.
254         */
255        public Result<T> cluster(List<T> data) {
256                @SuppressWarnings("unchecked")
257                T[] d = (T[]) Array.newInstance(data.get(0).getClass(), data.size());
258                d = data.toArray(d);
259                return cluster(d);
260        }
261
262        @Override
263        public Result<T> cluster(T[] data) {
264                final ArrayBackedDataSource<T> ds = new ArrayBackedDataSource<T>(data, rng) {
265                        @Override
266                        public int numDimensions() {
267                                return data[0].length();
268                        }
269                };
270
271                try {
272                        final Result<T> result = cluster(ds, conf.K);
273                        result.nn = conf.factory.create(result.centroids);
274
275                        return result;
276                } catch (final Exception e) {
277                        throw new RuntimeException(e);
278                }
279        }
280
281        @Override
282        public int[][] performClustering(T[] data) {
283                final FeatureVectorCentroidsResult<T> clusters = this.cluster(data);
284                return new IndexClusters(clusters.defaultHardAssigner().assign(data)).clusters();
285        }
286
287        /**
288         * Perform clustering on the given data.
289         *
290         * @param data
291         *            the data.
292         *
293         * @return the generated clusters.
294         */
295        public int[][] performClustering(List<T> data) {
296                @SuppressWarnings("unchecked")
297                T[] d = (T[]) Array.newInstance(data.get(0).getClass(), data.size());
298                d = data.toArray(d);
299
300                final FeatureVectorCentroidsResult<T> clusters = this.cluster(d);
301                return new IndexClusters(clusters.defaultHardAssigner().assign(d)).clusters();
302        }
303
304        /**
305         * Initiate clustering with the given data and number of clusters.
306         * Internally this method constructs the array to hold the centroids and
307         * calls {@link #cluster(DataSource, Object)}.
308         *
309         * @param data
310         *            data source to cluster with
311         * @param K
312         *            number of clusters to find
313         * @return cluster centroids
314         */
315        protected Result<T> cluster(DataSource<T> data, int K) throws Exception {
316                final Result<T> result = new Result<T>();
317                result.centroids = data.createTemporaryArray(K);
318
319                init.initKMeans(data, result.centroids);
320
321                cluster(data, result);
322
323                return result;
324        }
325
326        /**
327         * Main clustering algorithm. A number of threads as specified are started
328         * each containing an assignment job and a reference to the same set of
329         * ObjectNearestNeighbours object (i.e. Exact or KDTree). Each thread is
330         * added to a job pool and started in parallel. A single accumulator is
331         * shared between all threads and locked on update. <br/>
332         * This methods expects that the initial centroids have already been set in
333         * the <code>result</code> object and as such <strong>ignores</strong> the
334         * init object. <strong>In normal operation you should call one of the other
335         * <code>cluster</code> cluster methods instead of this one.</strong>
336         * However, if you wish to resume clustering iterations from a result that
337         * you've already generated this is the method to use.
338         *
339         * @param data
340         *            the data to be clustered
341         * @param result
342         *            the results object to be populated
343         * @throws InterruptedException
344         *             if interrupted while waiting, in which case unfinished tasks
345         *             are cancelled.
346         */
347        public void cluster(T[] data, Result<T> result) throws InterruptedException {
348                final ArrayBackedDataSource<T> ds = new ArrayBackedDataSource<T>(data, rng) {
349                        @Override
350                        public int numDimensions() {
351                                return data[0].length();
352                        }
353                };
354
355                cluster(ds, result);
356        }
357
358        /**
359         * Main clustering algorithm. A number of threads as specified are started
360         * each containing an assignment job and a reference to the same set of
361         * ObjectNearestNeighbours object (i.e. Exact or KDTree). Each thread is
362         * added to a job pool and started in parallel. A single accumulator is
363         * shared between all threads and locked on update. <br/>
364         * This methods expects that the initial centroids have already been set in
365         * the <code>result</code> object and as such <strong>ignores</strong> the
366         * init object. In normal operation you should call one of the other
367         * <code>cluster</code> cluster methods instead of this one. However, if you
368         * wish to resume clustering iterations from a result that you've already
369         * generated this is the method to use.
370         *
371         * @param data
372         *            the data to be clustered
373         * @param result
374         *            the results object to be populated
375         * @throws InterruptedException
376         *             if interrupted while waiting, in which case unfinished tasks
377         *             are cancelled.
378         */
379        protected void cluster(DataSource<T> data, Result<T> result) throws InterruptedException {
380                final T[] centroids = result.centroids;
381                final int K = centroids.length;
382                final int D = centroids[0].length();
383                final int N = data.size();
384                final double[][] centroids_accum = new double[K][D];
385                final int[] new_counts = new int[K];
386
387                final ExecutorService service = conf.threadpool;
388
389                for (int i = 0; i < conf.niters; i++) {
390                        result.iterations++;
391
392                        for (int j = 0; j < K; j++)
393                                Arrays.fill(centroids_accum[j], 0);
394                        Arrays.fill(new_counts, 0);
395
396                        final ObjectNearestNeighbours<T> nno = conf.factory.create(centroids);
397
398                        final List<CentroidAssignmentJob<T>> jobs = new ArrayList<CentroidAssignmentJob<T>>();
399                        for (int bl = 0; bl < N; bl += conf.blockSize) {
400                                final int br = Math.min(bl + conf.blockSize, N);
401                                jobs.add(new CentroidAssignmentJob<T>(data, bl, br, nno, centroids_accum, new_counts));
402                        }
403
404                        service.invokeAll(jobs);
405
406                        result.changedCentroidCount = 0;
407                        for (int k = 0; k < K; ++k) {
408                                double ssd = 0;
409                                if (new_counts[k] == 0) {
410                                        // If there's an empty cluster we replace it with a random
411                                        // point.
412                                        new_counts[k] = 1;
413
414                                        final T[] rnd = data.createTemporaryArray(1);
415                                        data.getRandomRows(rnd);
416
417                                        final Cloner cloner = new Cloner();
418                                        centroids[k] = cloner.deepClone(rnd[0]);
419
420                                        result.changedCentroidCount++;
421                                } else {
422                                        for (int d = 0; d < D; ++d) {
423                                                final double newValue = centroids_accum[k][d] / new_counts[k];
424
425                                                // we're going to accumulate the SSD of the old vs new
426                                                // centroids
427                                                // as a way of determining if this centroid has changed
428                                                final double diff = newValue - centroids[k].getAsDouble(d);
429                                                ssd += diff * diff;
430
431                                                // update to new centroid
432                                                centroids[k].setFromDouble(d, newValue);
433                                        }
434
435                                        if (ssd != 0)
436                                                result.changedCentroidCount++;
437                                }
438                        }
439
440                        if (result.changedCentroidCount == 0)
441                                break; // convergence
442                }
443        }
444
445        @Override
446        public FeatureVectorCentroidsResult<T> cluster(DataSource<T> ds) {
447                try {
448                        final Result<T> result = cluster(ds, conf.K);
449                        result.nn = conf.factory.create(result.centroids);
450
451                        return result;
452                } catch (final Exception e) {
453                        throw new RuntimeException(e);
454                }
455        }
456
457        /**
458         * Get the configuration
459         *
460         * @return the configuration
461         */
462        public KMeansConfiguration<ObjectNearestNeighbours<T>, T> getConfiguration() {
463                return conf;
464        }
465
466        /**
467         * Set the configuration
468         *
469         * @param conf
470         *            the configuration to set
471         */
472        public void setConfiguration(KMeansConfiguration<ObjectNearestNeighbours<T>, T> conf) {
473                this.conf = conf;
474        }
475
476        /**
477         * Convenience method to quickly create an exact {@link ByteKMeans}. All
478         * parameters other than the number of clusters are set at their defaults,
479         * but can be manipulated through the configuration returned by
480         * {@link #getConfiguration()}.
481         * <p>
482         * Euclidean distance is used to measure the distance between points.
483         *
484         * @param K
485         *            the number of clusters
486         * @param distance
487         *            the distance measure
488         * @return a {@link ByteKMeans} instance configured for exact k-means
489         */
490        public static <T extends FeatureVector> FeatureVectorKMeans<T> createExact(int K,
491                        DistanceComparator<? super T> distance)
492        {
493                final KMeansConfiguration<ObjectNearestNeighbours<T>, T> conf = new KMeansConfiguration<ObjectNearestNeighbours<T>, T>(
494                                K, new ObjectNearestNeighboursExact.Factory<T>(
495                                                distance));
496
497                return new FeatureVectorKMeans<T>(conf);
498        }
499
500        /**
501         * Convenience method to quickly create an exact {@link ByteKMeans}. All
502         * parameters other than the number of clusters and number of iterations are
503         * set at their defaults, but can be manipulated through the configuration
504         * returned by {@link #getConfiguration()}.
505         * <p>
506         * Euclidean distance is used to measure the distance between points.
507         *
508         * @param K
509         *            the number of clusters
510         * @param distance
511         *            the distance measure
512         * @param niters
513         *            maximum number of iterations
514         * @return a {@link ByteKMeans} instance configured for exact k-means
515         */
516        public static <T extends FeatureVector> FeatureVectorKMeans<T> createExact(int K,
517                        DistanceComparator<? super T> distance, int niters)
518        {
519                final KMeansConfiguration<ObjectNearestNeighbours<T>, T> conf = new KMeansConfiguration<ObjectNearestNeighbours<T>, T>(
520                                K,
521                                new ObjectNearestNeighboursExact.Factory<T>(distance),
522                                niters);
523
524                return new FeatureVectorKMeans<T>(conf);
525        }
526
527        @Override
528        public String toString() {
529                return String.format("%s: {K=%d, NN=%s}", this.getClass().getSimpleName(), this.conf.K, this.conf
530                                .getNearestNeighbourFactory().getClass().getSimpleName());
531        }
532}