/*
    AUTOMATICALLY GENERATED BY jTemp FROM
    /Users/jsh2/Work/openimaj/target/checkout/machine-learning/clustering/src/main/jtemp/org/openimaj/ml/clustering/kmeans/#T#KMeans.jtemp
*/
/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   *  Redistributions of source code must retain the above copyright notice,
 *      this list of conditions and the following disclaimer.
 *
 *   *  Redistributions in binary form must reproduce the above copyright notice,
 *      this list of conditions and the following disclaimer in the documentation
 *      and/or other materials provided with the distribution.
 *
 *   *  Neither the name of the University of Southampton nor the names of its
 *      contributors may be used to endorse or promote products derived from this
 *      software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

package org.openimaj.ml.clustering.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

import org.openimaj.data.DataSource;
import org.openimaj.data.ShortArrayBackedDataSource;
import org.openimaj.ml.clustering.IndexClusters;
import org.openimaj.ml.clustering.SpatialClusterer;
import org.openimaj.ml.clustering.assignment.HardAssigner;
import org.openimaj.ml.clustering.assignment.hard.KDTreeShortEuclideanAssigner;
import org.openimaj.ml.clustering.assignment.hard.ExactShortAssigner;
import org.openimaj.ml.clustering.ShortCentroidsResult;
import org.openimaj.knn.ShortNearestNeighbours;
import org.openimaj.knn.ShortNearestNeighboursExact;
import org.openimaj.knn.ShortNearestNeighboursProvider;
import org.openimaj.knn.NearestNeighboursFactory;
import org.openimaj.knn.approximate.ShortNearestNeighboursKDTree;
import org.openimaj.util.pair.IntFloatPair;

/**
 * Fast, parallel implementation of the K-Means algorithm with support for
 * bigger-than-memory data. Various flavors of K-Means are supported through the
 * selection of different subclasses of {@link ShortNearestNeighbours}; for
 * example, approximate K-Means can be achieved using a
 * {@link ShortNearestNeighboursKDTree} whilst exact K-Means can be achieved
 * using a {@link ShortNearestNeighboursExact}. The specific choice of
 * nearest-neighbour object is controlled through the
 * {@link NearestNeighboursFactory} provided to the {@link KMeansConfiguration}
 * used to construct instances of this class. The choice of
 * {@link ShortNearestNeighbours} affects the speed of clustering; using
 * approximate nearest-neighbours algorithms for the K-Means can produce
 * comparable results to the exact KMeans algorithm in much shorter time.
 * The choice and configuration of {@link ShortNearestNeighbours} can also
 * control the type of distance function being used in the clustering.
 * <p>
 * The algorithm is implemented as follows: Clustering is initiated using a
 * {@link ShortKMeansInit} and is iterative. In each round, batches of
 * samples are assigned to centroids in parallel. The centroid assignment is
 * performed using the pre-configured {@link ShortNearestNeighbours} instances
 * created from the {@link KMeansConfiguration}. Once all samples are assigned
 * new centroids are calculated and the next round started. Data point pushing
 * is performed using the same techniques as center point assignment.
 * <p>
 * This implementation is able to deal with larger-than-memory datasets by
 * streaming the samples from disk using an appropriate {@link DataSource}. The
 * only requirement is that there is enough memory to hold all the centroids
 * plus working memory for the batches of samples being assigned.
 *
 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 */
public class ShortKMeans implements SpatialClusterer<ShortCentroidsResult, short[]> {
    /**
     * Unit of parallel work for one K-Means round: reads the rows
     * <code>[startRow, stopRow)</code> from the data source, finds the nearest
     * centroid of each row and adds the row to that centroid's running sum.
     * All jobs of a round share the same accumulator arrays; updates to them
     * are serialised by locking on <code>centroids_accum</code>.
     */
    private static class CentroidAssignmentJob implements Callable<Boolean> {
        private final DataSource<short[]> ds;
        private final int startRow;
        private final int stopRow;
        private final ShortNearestNeighbours nno;
        private final float[][] centroids_accum;
        private final int[] counts;

        public CentroidAssignmentJob(DataSource<short[]> ds, int startRow, int stopRow, ShortNearestNeighbours nno,
                float[][] centroids_accum, int[] counts)
        {
            this.ds = ds;
            this.startRow = startRow;
            this.stopRow = stopRow;
            this.nno = nno;
            this.centroids_accum = centroids_accum;
            this.counts = counts;
        }

        /**
         * Assign this job's batch of rows to centroids and accumulate them.
         *
         * @return always <code>true</code> on normal completion
         * @throws Exception
         *             if reading the batch or the nearest-neighbour search
         *             fails. The failure is deliberately NOT swallowed here: a
         *             dropped batch would silently bias the centroids, so it is
         *             reported through the job's {@link Future} instead.
         */
        @Override
        public Boolean call() throws Exception {
            final int D = nno.numDimensions();

            final short[][] points = new short[stopRow - startRow][D];
            ds.getData(startRow, stopRow, points);

            final int[] argmins = new int[points.length];
            final float[] mins = new float[points.length];

            // The search runs outside the lock so that batches proceed in parallel;
            // only the cheap accumulation below is serialised.
            nno.searchNN(points, argmins, mins);

            synchronized (centroids_accum) {
                for (int i = 0; i < points.length; ++i) {
                    final int k = argmins[i];
                    for (int d = 0; d < D; ++d) {
                        centroids_accum[k][d] += points[i][d];
                    }
                    counts[k] += 1;
                }
            }

            return true;
        }
    }

    /**
     * Result object for ShortKMeans, extending ShortCentroidsResult and ShortNearestNeighboursProvider,
     * as well as giving access to state information from the operation of the K-Means algorithm
     * (i.e. number of iterations, and convergence state).
     *
     * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
     */
    public static class Result extends ShortCentroidsResult implements ShortNearestNeighboursProvider {
        protected ShortNearestNeighbours nn;
        protected int iterations;
        protected int changedCentroidCount;

        /**
         * Create the default assigner for these centroids: an
         * {@link ExactShortAssigner} sharing the distance comparator of the
         * nearest-neighbour object when that object is a
         * {@link ShortNearestNeighboursExact}, otherwise a
         * {@link KDTreeShortEuclideanAssigner}.
         */
        @Override
        public HardAssigner<short[], float[], IntFloatPair> defaultHardAssigner() {
            if (nn instanceof ShortNearestNeighboursExact) {
                return new ExactShortAssigner(this, ((ShortNearestNeighboursExact) nn).distanceComparator());
            }

            return new KDTreeShortEuclideanAssigner(this);
        }

        @Override
        public ShortNearestNeighbours getNearestNeighbours() {
            return nn;
        }

        /**
         * Get the number of K-Means iterations that produced this result.
         * @return the number of iterations
         */
        public int numIterations() {
            return iterations;
        }

        /**
         * Get the number of changed centroids in the last iteration. This is
         * an indicator of convergence as over time this should reduce to 0.
         * @return the number of changed centroids
         */
        public int numChangedCentroids() {
            return changedCentroidCount;
        }
    }

    private ShortKMeansInit init = new ShortKMeansInit.RANDOM();
    private KMeansConfiguration<ShortNearestNeighbours, short[]> conf;
    private Random rng = new Random();

    /**
     * Construct the clusterer with the given configuration.
     *
     * @param conf The configuration.
     */
    public ShortKMeans(KMeansConfiguration<ShortNearestNeighbours, short[]> conf) {
        this.conf = conf;
    }

    /**
     * A completely default {@link ShortKMeans} used primarily as a convenience function for reading.
     */
    protected ShortKMeans() {
        this(new KMeansConfiguration<ShortNearestNeighbours, short[]>());
    }

    /**
     * Get the current initialisation algorithm
     *
     * @return the init algorithm being used
     */
    public ShortKMeansInit getInit() {
        return init;
    }

    /**
     * Set the current initialisation algorithm
     *
     * @param init the init algorithm to be used
     */
    public void setInit(ShortKMeansInit init) {
        this.init = init;
    }

    /**
     * Set the seed for the internal random number generator. The generator is
     * handed to the {@link ShortArrayBackedDataSource} created when clustering
     * in-memory arrays.
     *
     * @param seed the random seed for init random sample selection; an unseeded
     *            generator is used if seed &lt; 0
     */
    public void seed(long seed) {
        if (seed < 0) {
            this.rng = new Random();
        } else {
            this.rng = new Random(seed);
        }
    }

    /**
     * Cluster the given in-memory data using the configured number of
     * clusters.
     *
     * @param data the data to cluster, one sample per row
     * @return the clustering result
     * @throws RuntimeException wrapping any failure raised during clustering
     */
    @Override
    public Result cluster(short[][] data) {
        final DataSource<short[]> ds = new ShortArrayBackedDataSource(data, rng);

        return clusterAndBuildNN(ds);
    }

    /**
     * Cluster the data and return, for each cluster, the indices assigned to
     * it by the result's default hard assigner.
     *
     * @param data the data to cluster, one sample per row
     * @return the index clusters
     */
    @Override
    public int[][] performClustering(short[][] data) {
        final ShortCentroidsResult clusters = this.cluster(data);
        return new IndexClusters(clusters.defaultHardAssigner().assign(data)).clusters();
    }

    /**
     * Initiate clustering with the given data and number of clusters.
     * Internally this method constructs the array to hold the centroids,
     * initialises it with the current {@link ShortKMeansInit} and calls
     * {@link #cluster(DataSource, Result)}.
     *
     * @param data data source to cluster with
     * @param K number of clusters to find
     * @return cluster centroids
     * @throws Exception if initialisation or clustering fails
     */
    protected Result cluster(DataSource<short[]> data, int K) throws Exception {
        final int D = data.numDimensions();

        final Result result = new Result();
        result.centroids = new short[K][D];

        init.initKMeans(data, result.centroids);

        cluster(data, result);

        return result;
    }

    /**
     * Main clustering algorithm. A number of threads as specified are
     * started each containing an assignment job and a reference to
     * the same set of ShortNearestNeighbours object (i.e. Exact or KDTree).
     * Each thread is added to a job pool and started in parallel.
     * A single accumulator is shared between all threads and locked on update.
     * <br/>
     * This method expects that the initial centroids have already been set in
     * the <code>result</code> object and as such <strong>ignores</strong> the
     * init object. <strong>In normal operation you should call one of the other
     * <code>cluster</code> methods instead of this one.</strong> However, if you
     * wish to resume clustering iterations from a result that you've already
     * generated this is the method to use.
     *
     * @param data the data to be clustered
     * @param result the results object to be populated
     * @throws InterruptedException if interrupted while waiting, in
     *         which case unfinished tasks are cancelled.
     * @throws RuntimeException if any assignment job fails
     */
    public void cluster(short[][] data, Result result) throws InterruptedException {
        final DataSource<short[]> ds = new ShortArrayBackedDataSource(data, rng);

        cluster(ds, result);
    }

    /**
     * Main clustering algorithm. A number of threads as specified are
     * started each containing an assignment job and a reference to
     * the same set of ShortNearestNeighbours object (i.e. Exact or KDTree).
     * Each thread is added to a job pool and started in parallel.
     * A single accumulator is shared between all threads and locked on update.
     * <br/>
     * This method expects that the initial centroids have already been set in
     * the <code>result</code> object and as such <strong>ignores</strong> the
     * init object. In normal operation you should call one of the other
     * <code>cluster</code> methods instead of this one. However, if you wish to
     * resume clustering iterations from a result that you've already generated
     * this is the method to use.
     *
     * @param data the data to be clustered
     * @param result the results object to be populated
     * @throws InterruptedException if interrupted while waiting, in
     *         which case unfinished tasks are cancelled.
     * @throws RuntimeException if any assignment job fails; the centroids in
     *         <code>result</code> are left as they were at the end of the last
     *         successfully completed iteration
     */
    public void cluster(DataSource<short[]> data, Result result) throws InterruptedException {
        final short[][] centroids = result.centroids;
        final int K = centroids.length;
        final int D = centroids[0].length;
        final int N = data.size();
        final float[][] centroids_accum = new float[K][D];
        final int[] new_counts = new int[K];

        final ExecutorService service = conf.threadpool;

        for (int i = 0; i < conf.niters; i++) {
            result.iterations++;

            // reset the shared accumulators for this round
            for (int j = 0; j < K; j++) {
                Arrays.fill(centroids_accum[j], 0);
            }
            Arrays.fill(new_counts, 0);

            // the nearest-neighbour structure is rebuilt over the current centroids each round
            final ShortNearestNeighbours nno = conf.factory.create(centroids);

            // one job per block of at most conf.blockSize rows
            final List<CentroidAssignmentJob> jobs = new ArrayList<CentroidAssignmentJob>();
            for (int bl = 0; bl < N; bl += conf.blockSize) {
                final int br = Math.min(bl + conf.blockSize, N);
                jobs.add(new CentroidAssignmentJob(data, bl, br, nno, centroids_accum, new_counts));
            }

            // invokeAll blocks until every job has finished; fail before touching
            // the centroids if any batch could not be accumulated
            rethrowJobFailures(service.invokeAll(jobs));

            result.changedCentroidCount = 0;
            for (int k = 0; k < K; ++k) {
                if (new_counts[k] == 0) {
                    // If there's an empty cluster we replace it with a random point.
                    new_counts[k] = 1;

                    // rnd aliases centroids[k], so the random row is written straight
                    // into the centroid
                    final short[][] rnd = new short[][] { centroids[k] };
                    data.getRandomRows(rnd);
                    result.changedCentroidCount++;
                } else {
                    float ssd = 0;
                    for (int d = 0; d < D; ++d) {
                        // NOTE: roundFloat is a plain cast, so the conversion to short
                        // truncates the mean rather than rounding it to nearest
                        final short newValue = (short) ((float) roundFloat((double) centroids_accum[k][d]
                                / (double) new_counts[k]));

                        // we're going to accumulate the SSD of the old vs new centroids
                        // as a way of determining if this centroid has changed
                        final float diff = newValue - centroids[k][d];
                        ssd += diff * diff;

                        // update to new centroid
                        centroids[k][d] = newValue;
                    }

                    if (ssd != 0) {
                        result.changedCentroidCount++;
                    }
                }
            }

            if (result.changedCentroidCount == 0) {
                break; // convergence
            }
        }
    }

    /**
     * Inspect the futures of a completed round and surface the first job
     * failure, keeping the original exception as the cause chain.
     */
    private static void rethrowJobFailures(List<Future<Boolean>> futures) throws InterruptedException {
        for (final Future<Boolean> future : futures) {
            try {
                future.get();
            } catch (final ExecutionException e) {
                throw new RuntimeException("Centroid assignment job failed", e);
            }
        }
    }

    /**
     * Run a full clustering over the data source with the configured K and
     * attach a nearest-neighbour object built over the final centroids.
     * Shared by the array and data-source entry points.
     */
    private Result clusterAndBuildNN(DataSource<short[]> ds) {
        try {
            final Result result = cluster(ds, conf.K);
            result.nn = conf.factory.create(result.centroids);

            return result;
        } catch (final InterruptedException e) {
            // restore the interrupt flag so callers further up can still observe it
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (final Exception e) {
            throw new RuntimeException(e);
        }
    }

    protected float roundFloat(double value) { return (float) value; }
    protected double roundDouble(double value) { return value; }
    protected long roundLong(double value) { return (long) Math.round(value); }
    protected int roundInt(double value) { return (int) Math.round(value); }

    /**
     * Cluster the data provided by the given data source using the configured
     * number of clusters.
     *
     * @param ds the data source to cluster
     * @return the clustering result
     * @throws RuntimeException wrapping any failure raised during clustering
     */
    @Override
    public Result cluster(DataSource<short[]> ds) {
        return clusterAndBuildNN(ds);
    }

    /**
     * Get the configuration
     *
     * @return the configuration
     */
    public KMeansConfiguration<ShortNearestNeighbours, short[]> getConfiguration() {
        return conf;
    }

    /**
     * Set the configuration
     *
     * @param conf
     *            the configuration to set
     */
    public void setConfiguration(KMeansConfiguration<ShortNearestNeighbours, short[]> conf) {
        this.conf = conf;
    }

    /**
     * Convenience method to quickly create an exact {@link ShortKMeans}. All
     * parameters other than the number of clusters are set
     * at their defaults, but can be manipulated through the configuration
     * returned by {@link #getConfiguration()}.
     * <p>
     * Euclidean distance is used to measure the distance between points.
     *
     * @param K
     *            the number of clusters
     * @return a {@link ShortKMeans} instance configured for exact k-means
     */
    public static ShortKMeans createExact(int K) {
        final KMeansConfiguration<ShortNearestNeighbours, short[]> conf =
                new KMeansConfiguration<ShortNearestNeighbours, short[]>(K, new ShortNearestNeighboursExact.Factory());

        return new ShortKMeans(conf);
    }

    /**
     * Convenience method to quickly create an exact {@link ShortKMeans}. All
     * parameters other than the number of clusters and number
     * of iterations are set at their defaults, but can be manipulated through
     * the configuration returned by {@link #getConfiguration()}.
     * <p>
     * Euclidean distance is used to measure the distance between points.
     *
     * @param K
     *            the number of clusters
     * @param niters
     *            maximum number of iterations
     * @return a {@link ShortKMeans} instance configured for exact k-means
     */
    public static ShortKMeans createExact(int K, int niters) {
        final KMeansConfiguration<ShortNearestNeighbours, short[]> conf =
                new KMeansConfiguration<ShortNearestNeighbours, short[]>(K, new ShortNearestNeighboursExact.Factory(), niters);

        return new ShortKMeans(conf);
    }

    /**
     * Convenience method to quickly create an approximate {@link ShortKMeans}
     * using an ensemble of KD-Trees to perform nearest-neighbour lookup. All
     * parameters other than the number of clusters are set
     * at their defaults, but can be manipulated through the configuration
     * returned by {@link #getConfiguration()}.
     * <p>
     * Euclidean distance is used to measure the distance between points.
     *
     * @param K
     *            the number of clusters
     * @return a {@link ShortKMeans} instance configured for approximate k-means
     *         using an ensemble of KD-Trees
     */
    public static ShortKMeans createKDTreeEnsemble(int K) {
        final KMeansConfiguration<ShortNearestNeighbours, short[]> conf =
                new KMeansConfiguration<ShortNearestNeighbours, short[]>(K, new ShortNearestNeighboursKDTree.Factory());

        return new ShortKMeans(conf);
    }

    @Override
    public String toString() {
        return String.format("%s: {K=%d, NN=%s}", this.getClass().getSimpleName(), this.conf.K, this.conf.getNearestNeighbourFactory().getClass().getSimpleName());
    }
}