001/*
002        AUTOMATICALLY GENERATED BY jTemp FROM
003        /Users/jsh2/Work/openimaj/target/checkout/machine-learning/clustering/src/main/jtemp/org/openimaj/ml/clustering/kmeans/#T#KMeans.jtemp
004*/
005/**
006 * Copyright (c) 2011, The University of Southampton and the individual contributors.
007 * All rights reserved.
008 *
009 * Redistribution and use in source and binary forms, with or without modification,
010 * are permitted provided that the following conditions are met:
011 *
012 *   *  Redistributions of source code must retain the above copyright notice,
013 *      this list of conditions and the following disclaimer.
014 *
015 *   *  Redistributions in binary form must reproduce the above copyright notice,
016 *      this list of conditions and the following disclaimer in the documentation
017 *      and/or other materials provided with the distribution.
018 *
019 *   *  Neither the name of the University of Southampton nor the names of its
020 *      contributors may be used to endorse or promote products derived from this
021 *      software without specific prior written permission.
022 *
023 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
024 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
025 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
026 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
027 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
028 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
029 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
030 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
031 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
032 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
033 */
034 
035package org.openimaj.ml.clustering.kmeans;
036
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

import org.openimaj.data.DataSource;
import org.openimaj.data.ByteArrayBackedDataSource;
import org.openimaj.ml.clustering.IndexClusters;
import org.openimaj.ml.clustering.SpatialClusterer;
import org.openimaj.ml.clustering.assignment.HardAssigner;
import org.openimaj.ml.clustering.assignment.hard.KDTreeByteEuclideanAssigner;
import org.openimaj.ml.clustering.assignment.hard.ExactByteAssigner;
import org.openimaj.ml.clustering.ByteCentroidsResult;
import org.openimaj.knn.ByteNearestNeighbours;
import org.openimaj.knn.ByteNearestNeighboursExact;
import org.openimaj.knn.ByteNearestNeighboursProvider;
import org.openimaj.knn.NearestNeighboursFactory;
import org.openimaj.knn.approximate.ByteNearestNeighboursKDTree;
import org.openimaj.util.pair.IntFloatPair;
058
059/**
060 * Fast, parallel implementation of the K-Means algorithm with support for
061 * bigger-than-memory data. Various flavors of K-Means are supported through the
062 * selection of different subclasses of {@link ByteNearestNeighbours}; for
063 * example, approximate K-Means can be achieved using a
064 * {@link ByteNearestNeighboursKDTree} whilst exact K-Means can be achieved
 * using a {@link ByteNearestNeighboursExact}. The specific choice of
066 * nearest-neighbour object is controlled through the
067 * {@link NearestNeighboursFactory} provided to the {@link KMeansConfiguration}
068 * used to construct instances of this class. The choice of
069 * {@link ByteNearestNeighbours} affects the speed of clustering; using
 * approximate nearest-neighbours algorithms for the K-Means can produce
 * comparable results to the exact KMeans algorithm in much shorter time.
072 * The choice and configuration of {@link ByteNearestNeighbours} can also
073 * control the type of distance function being used in the clustering.
074 * <p>
075 * The algorithm is implemented as follows: Clustering is initiated using a
076 * {@link ByteKMeansInit} and is iterative. In each round, batches of
077 * samples are assigned to centroids in parallel. The centroid assignment is
078 * performed using the pre-configured {@link ByteNearestNeighbours} instances
 * created from the {@link KMeansConfiguration}. Once all samples are assigned,
 * new centroids are calculated and the next round is started. Data point pushing
 * (assigning points to the learned centroids) is performed using the same
 * techniques as centroid assignment.
082 * <p>
083 * This implementation is able to deal with larger-than-memory datasets by
084 * streaming the samples from disk using an appropriate {@link DataSource}. The
085 * only requirement is that there is enough memory to hold all the centroids
086 * plus working memory for the batches of samples being assigned.
087 * 
088 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
089 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
090 */
091 public class ByteKMeans implements SpatialClusterer<ByteCentroidsResult, byte[]> {
092        private static class CentroidAssignmentJob implements Callable<Boolean> {
093                private final DataSource<byte[]> ds;
094                private final int startRow;
095                private final int stopRow;
096                private final ByteNearestNeighbours nno;
097                private final float [][] centroids_accum;
098                private final int [] counts;
099
100                public CentroidAssignmentJob(DataSource<byte[]> ds, int startRow, int stopRow, ByteNearestNeighbours nno, float [][] centroids_accum, int [] counts) {
101                        this.ds = ds; 
102                        this.startRow = startRow;
103                        this.stopRow = stopRow;
104                        this.nno = nno;
105                        this.centroids_accum = centroids_accum;
106                        this.counts = counts;
107                }
108                
109                @Override
110                public Boolean call() {
111                        try {
112                                int D = nno.numDimensions();
113
114                                byte [][] points = new byte[stopRow-startRow][D]; 
115                                ds.getData(startRow, stopRow, points);
116
117                                int [] argmins = new int[points.length];
118                                float [] mins = new float[points.length];
119
120                                nno.searchNN(points, argmins, mins);
121
122                                synchronized(centroids_accum){
123                                        for (int i=0; i < points.length; ++i) {
124                                                int k = argmins[i];
125                                                for (int d=0; d < D; ++d) {
126                                                        centroids_accum[k][d] += points[i][d];
127                                                }
128                                                counts[k] += 1;
129                                        }
130                                }
131                        } catch(Exception e) {
132                                e.printStackTrace();
133                        }
134                        return true;
135                }
136        }
137        
138        /**
139         * Result object for ByteKMeans, extending ByteCentroidsResult and ByteNearestNeighboursProvider,
140         * as well as giving access to state information from the operation of the K-Means algorithm  
141         * (i.e. number of iterations, and convergence state).
142         *
143         * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
144         */
145        public static class Result extends ByteCentroidsResult implements ByteNearestNeighboursProvider {
146                protected ByteNearestNeighbours nn;
147                protected int iterations;
148                protected int changedCentroidCount;
149                 
150                @Override
151                public HardAssigner<byte[], float[], IntFloatPair> defaultHardAssigner() {
152                        if (nn instanceof ByteNearestNeighboursExact)
153                                return new ExactByteAssigner(this, ((ByteNearestNeighboursExact)nn).distanceComparator());
154                
155                        return new KDTreeByteEuclideanAssigner(this);
156                }
157                
158                @Override
159                public ByteNearestNeighbours getNearestNeighbours() {
160                        return nn;
161                }
162                
163                /**
164                 * Get the number of K-Means iterations that produced this result.
165                 * @return the number of iterations
166                 */
167                public int numIterations() {
168                        return iterations;
169                }
170                
171                /**
172                 * Get the number of changed centroids in the last iteration. This is 
173                 * an indicator of convergence as over time this should reduce to 0.
174                 * @return the number of changed centroids 
175                 */
176                public int numChangedCentroids() {
177                        return changedCentroidCount;
178                }
179        }
180        
181        private ByteKMeansInit init = new ByteKMeansInit.RANDOM(); 
182        private KMeansConfiguration<ByteNearestNeighbours, byte[]> conf;
183        private Random rng = new Random();
184        
185        /**
186         * Construct the clusterer with the the given configuration.
187         * 
188         * @param conf The configuration.
189         */
190        public ByteKMeans(KMeansConfiguration<ByteNearestNeighbours, byte[]> conf) {
191                this.conf = conf;
192        }
193        
194        /**
195         * A completely default {@link ByteKMeans} used primarily as a convenience function for reading.
196         */
197        protected ByteKMeans() {
198                this(new KMeansConfiguration<ByteNearestNeighbours, byte[]>());
199        }
200        
201        /**
202         * Get the current initialisation algorithm
203         *
204         * @return the init algorithm being used
205         */
206        public ByteKMeansInit getInit() {
207                return init;
208        }
209
210        /**
211         * Set the current initialisation algorithm
212         *
213         * @param init the init algorithm to be used
214         */
215        public void setInit(ByteKMeansInit init) {
216                this.init = init;
217        }
218        
219        /**
220         * Set the seed for the internal random number generator.
221         *
222         * @param seed the random seed for init random sample selection, no seed if seed < -1
223         */
224        public void seed(long seed) {
225                if(seed < 0)
226                        this.rng = new Random();
227                else
228                        this.rng = new Random(seed);
229        }
230                
231        @Override
232        public Result cluster(byte[][] data) {
233                DataSource<byte[]> ds = new ByteArrayBackedDataSource(data, rng);
234                
235                try {
236                        Result result = cluster(ds, conf.K);
237                        result.nn = conf.factory.create(result.centroids);
238                                                
239                        return result;
240                } catch (Exception e) {
241                        throw new RuntimeException(e);
242                }
243        }
244        
245        @Override
246        public int[][] performClustering(byte[][] data) {
247                ByteCentroidsResult clusters = this.cluster(data);
248                return new IndexClusters(clusters.defaultHardAssigner().assign(data)).clusters();
249        }
250        
251        /**
252         * Initiate clustering with the given data and number of clusters.
253         * Internally this method constructs the array to hold the centroids 
254         * and calls {@link #cluster(DataSource, byte [][])}.
255         *
256         * @param data data source to cluster with
257         * @param K number of clusters to find
258         * @return cluster centroids
259         */
260        protected Result cluster(DataSource<byte[]> data, int K) throws Exception {
261                int D = data.numDimensions();
262                
263                Result result = new Result();
264                result.centroids = new byte[K][D];
265        
266                init.initKMeans(data, result.centroids);
267        
268                cluster(data, result);
269
270                return result;
271        }
272        
273        /**
274         * Main clustering algorithm. A number of threads as specified are 
275         * started each containing an assignment job and a reference to
276         * the same set of ByteNearestNeighbours object (i.e. Exact or KDTree). 
277         * Each thread is added to a job pool and started in parallel. 
278         * A single accumulator is shared between all threads and locked on update.
279         * <br/>
280         * This methods expects that the initial centroids have already been set in
281         * the <code>result</code> object and as such <strong>ignores</strong> the
282         * init object. <strong>In normal operation you should call one of the other <code>cluster</code>
283         * cluster methods instead of this one.</strong> However, if you wish to resume clustering
284         * iterations from a result that you've already generated this is the method
285         * to use.
286         *
287         * @param data the data to be clustered
288         * @param result the results object to be populated
289         * @throws InterruptedException if interrupted while waiting, in
290     *         which case unfinished tasks are cancelled.
291         */
292        public void cluster(byte[][] data, Result result) throws InterruptedException {
293                DataSource<byte[]> ds = new ByteArrayBackedDataSource(data, rng);
294                
295                cluster(ds, result);
296        }
297        
298        /**
299         * Main clustering algorithm. A number of threads as specified are 
300         * started each containing an assignment job and a reference to
301         * the same set of ByteNearestNeighbours object (i.e. Exact or KDTree). 
302         * Each thread is added to a job pool and started in parallel. 
303         * A single accumulator is shared between all threads and locked on update.
304         * <br/>
305         * This methods expects that the initial centroids have already been set in
306         * the <code>result</code> object and as such <strong>ignores</strong> the
307         * init object. In normal operation you should call one of the other <code>cluster</code>
308         * cluster methods instead of this one. However, if you wish to resume clustering
309         * iterations from a result that you've already generated this is the method
310         * to use.
311         *
312         * @param data the data to be clustered
313         * @param result the results object to be populated
314         * @throws InterruptedException if interrupted while waiting, in
315     *         which case unfinished tasks are cancelled.
316         */
317        public void cluster(DataSource<byte[]> data, Result result) throws InterruptedException {
318                final byte[][] centroids = result.centroids;
319                final int K = centroids.length;
320                final int D = centroids[0].length;
321                final int N = data.size();
322                float [][] centroids_accum = new float[K][D];
323                int [] new_counts = new int[K];
324
325                ExecutorService service = conf.threadpool;
326
327                for (int i=0; i<conf.niters; i++) {
328                        result.iterations++;
329                        
330                        for (int j=0; j<K; j++) 
331                                Arrays.fill(centroids_accum[j], 0);
332                        Arrays.fill(new_counts, 0);
333
334                        ByteNearestNeighbours nno = conf.factory.create(centroids);
335                        
336                        List<CentroidAssignmentJob> jobs = new ArrayList<CentroidAssignmentJob>();
337                        for (int bl = 0; bl < N; bl += conf.blockSize) {
338                                int br = Math.min(bl + conf.blockSize, N);
339                                jobs.add(new CentroidAssignmentJob(data, bl, br, nno, centroids_accum, new_counts));
340                        }
341
342                        service.invokeAll(jobs);
343
344                        result.changedCentroidCount = 0;
345                        for (int k=0; k < K; ++k) {
346                                float ssd = 0;
347                                if (new_counts[k] == 0) {
348                                        // If there's an empty cluster we replace it with a random point.
349                                        new_counts[k] = 1;
350
351                                        byte [][] rnd = new byte[][] {centroids[k]};
352                                        data.getRandomRows(rnd);
353                                        result.changedCentroidCount++;
354                                } else {
355                                        for (int d=0; d < D; ++d) {
356                                                byte newValue = (byte)((float)roundFloat((double)centroids_accum[k][d] / (double)new_counts[k]));
357                                                
358                                                // we're going to accumulate the SSD of the old vs new centroids
359                                                // as a way of determining if this centroid has changed
360                                                float diff = newValue - centroids[k][d]; 
361                                                ssd += diff*diff;
362                                                
363                                                //update to new centroid
364                                                centroids[k][d] = newValue;
365                                        }
366                                        
367                                        if (ssd != 0)
368                                                result.changedCentroidCount++;
369                                }
370                        }
371                         
372                        if (result.changedCentroidCount == 0)
373                                break; // convergence
374                }
375        }
376        
377        protected float roundFloat(double value) { return (float) value; }
378        protected double roundDouble(double value) { return value; }
379        protected long roundLong(double value) { return (long)Math.round(value); }
380        protected int roundInt(double value) { return (int)Math.round(value); }
381        
382        @Override
383        public Result cluster(DataSource<byte[]> ds) {
384                try {
385                        Result result = cluster(ds, conf.K);
386                        result.nn = conf.factory.create(result.centroids);
387                        
388                        return result;
389                } catch (Exception e) {
390                        throw new RuntimeException(e);
391                }
392        }
393
394    /**
395         * Get the configuration
396         * 
397         * @return the configuration
398         */
399    public KMeansConfiguration<ByteNearestNeighbours, byte[]> getConfiguration() {
400        return conf;
401    }
402    
403    /**
404         * Set the configuration
405         * 
406         * @param conf
407         *            the configuration to set
408         */
409    public void setConfiguration(KMeansConfiguration<ByteNearestNeighbours, byte[]> conf) {
410        this.conf = conf;
411    }
412        
413        /**
414         * Convenience method to quickly create an exact {@link ByteKMeans}. All
415         * parameters other than the number of clusters are set
416         * at their defaults, but can be manipulated through the configuration
417         * returned by {@link #getConfiguration()}.
418         * <p>
419         * Euclidean distance is used to measure the distance between points.
420         * 
421         * @param K
422         *            the number of clusters
423         * @return a {@link ByteKMeans} instance configured for exact k-means
424         */
425        public static ByteKMeans createExact(int K) {
426                final KMeansConfiguration<ByteNearestNeighbours, byte[]> conf =
427                                new KMeansConfiguration<ByteNearestNeighbours, byte[]>(K, new ByteNearestNeighboursExact.Factory());
428
429                return new ByteKMeans(conf);
430        }
431
432        /**
433         * Convenience method to quickly create an exact {@link ByteKMeans}. All
434         * parameters other than the number of clusters and number
435         * of iterations are set at their defaults, but can be manipulated through
436         * the configuration returned by {@link #getConfiguration()}.
437         * <p>
438         * Euclidean distance is used to measure the distance between points.
439         * 
440         * @param K
441         *            the number of clusters
442         * @param niters
443         *            maximum number of iterations
444         * @return a {@link ByteKMeans} instance configured for exact k-means
445         */
446        public static ByteKMeans createExact(int K, int niters) {
447                final KMeansConfiguration<ByteNearestNeighbours, byte[]> conf =
448                                new KMeansConfiguration<ByteNearestNeighbours, byte[]>(K, new ByteNearestNeighboursExact.Factory(), niters);
449
450                return new ByteKMeans(conf);
451        }
452        
453        /**
454         * Convenience method to quickly create an approximate {@link ByteKMeans}
455         * using an ensemble of KD-Trees to perform nearest-neighbour lookup. All
456         * parameters other than the number of clusters are set
457         * at their defaults, but can be manipulated through the configuration
458         * returned by {@link #getConfiguration()}. 
459         * <p>
460         * Euclidean distance is used to measure the distance between points.
461         * 
462         * @param K
463         *            the number of clusters
464         * @return a {@link ByteKMeans} instance configured for approximate k-means 
465         *              using an ensemble of KD-Trees
466         */
467        public static ByteKMeans createKDTreeEnsemble(int K) {
468                final KMeansConfiguration<ByteNearestNeighbours, byte[]> conf =
469                                new KMeansConfiguration<ByteNearestNeighbours, byte[]>(K, new ByteNearestNeighboursKDTree.Factory());
470
471                return new ByteKMeans(conf);
472        }
473        
474        @Override
475        public String toString() {
476                return String.format("%s: {K=%d, NN=%s}", this.getClass().getSimpleName(), this.conf.K, this.conf.getNearestNeighbourFactory().getClass().getSimpleName());
477        }
478}