001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.ml.clustering.kmeans; 031 032import java.lang.reflect.Array; 033import java.util.ArrayList; 034import java.util.Arrays; 035import java.util.List; 036import java.util.Random; 037import java.util.concurrent.Callable; 038import java.util.concurrent.ExecutorService; 039 040import org.openimaj.data.ArrayBackedDataSource; 041import org.openimaj.data.DataSource; 042import org.openimaj.feature.FeatureVector; 043import org.openimaj.knn.NearestNeighboursFactory; 044import org.openimaj.knn.ObjectNearestNeighbours; 045import org.openimaj.knn.ObjectNearestNeighboursExact; 046import org.openimaj.knn.ObjectNearestNeighboursProvider; 047import org.openimaj.ml.clustering.FeatureVectorCentroidsResult; 048import org.openimaj.ml.clustering.IndexClusters; 049import org.openimaj.ml.clustering.SpatialClusterer; 050import org.openimaj.ml.clustering.assignment.HardAssigner; 051import org.openimaj.ml.clustering.assignment.hard.ExactFeatureVectorAssigner; 052import org.openimaj.util.comparator.DistanceComparator; 053import org.openimaj.util.pair.IntFloatPair; 054 055import com.rits.cloning.Cloner; 056 057/** 058 * Fast, parallel implementation of the K-Means algorithm with support for 059 * bigger-than-memory data. Various flavors of K-Means are supported through the 060 * selection of different subclasses of {@link ObjectNearestNeighbours}; for 061 * example, exact K-Means can be achieved using an 062 * {@link ObjectNearestNeighboursExact}. The specific choice of 063 * nearest-neighbour object is controlled through the 064 * {@link NearestNeighboursFactory} provided to the {@link KMeansConfiguration} 065 * used to construct instances of this class. The choice of 066 * {@link ObjectNearestNeighbours} affects the speed of clustering; using 067 * approximate nearest-neighbours algorithms for the K-Means can produces 068 * comparable results to the exact KMeans algorithm in much shorter time. The 069 * choice and configuration of {@link ObjectNearestNeighbours} can also control 070 * the type of distance function being used in the clustering. 071 * <p> 072 * The algorithm is implemented as follows: Clustering is initiated using a 073 * {@link ByteKMeansInit} and is iterative. In each round, batches of samples 074 * are assigned to centroids in parallel. The centroid assignment is performed 075 * using the pre-configured {@link ObjectNearestNeighbours} instances created 076 * from the {@link KMeansConfiguration}. Once all samples are assigned new 077 * centroids are calculated and the next round started. Data point pushing is 078 * performed using the same techniques as center point assignment. 079 * <p> 080 * This implementation is able to deal with larger-than-memory datasets by 081 * streaming the samples from disk using an appropriate {@link DataSource}. The 082 * only requirement is that there is enough memory to hold all the centroids 083 * plus working memory for the batches of samples being assigned. 084 * 085 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 086 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 087 * 088 * @param <T> 089 * Type of object being clustered 090 */ 091public class FeatureVectorKMeans<T extends FeatureVector> 092 implements 093 SpatialClusterer<FeatureVectorCentroidsResult<T>, T> 094{ 095 private static class CentroidAssignmentJob<T extends FeatureVector> implements Callable<Boolean> { 096 private final DataSource<T> ds; 097 private final int startRow; 098 private final int stopRow; 099 private final ObjectNearestNeighbours<T> nno; 100 private final double[][] centroids_accum; 101 private final int[] counts; 102 103 public CentroidAssignmentJob(DataSource<T> ds, int startRow, int stopRow, ObjectNearestNeighbours<T> nno, 104 double[][] centroids_accum, int[] counts) 105 { 106 this.ds = ds; 107 this.startRow = startRow; 108 this.stopRow = stopRow; 109 this.nno = nno; 110 this.centroids_accum = centroids_accum; 111 this.counts = counts; 112 } 113 114 @Override 115 public Boolean call() { 116 try { 117 final int D = ds.getData(0).length(); 118 119 final T[] points = ds.createTemporaryArray(stopRow - startRow); 120 ds.getData(startRow, stopRow, points); 121 122 final int[] argmins = new int[points.length]; 123 final float[] mins = new float[points.length]; 124 125 nno.searchNN(points, argmins, mins); 126 127 synchronized (centroids_accum) { 128 for (int i = 0; i < points.length; ++i) { 129 final int k = argmins[i]; 130 final double[] vector = points[i].asDoubleVector(); 131 for (int d = 0; d < D; ++d) { 132 centroids_accum[k][d] += vector[d]; 133 } 134 counts[k] += 1; 135 } 136 } 137 } catch (final Exception e) { 138 e.printStackTrace(); 139 } 140 return true; 141 } 142 } 143 144 /** 145 * Result object for FeatureVectorKMeans, extending 146 * FeatureVectorCentroidsResult and ObjectNearestNeighboursProvider, as well 147 * as giving access to state information from the operation of the K-Means 148 * algorithm (i.e. number of iterations, and convergence state). 149 * 150 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 151 * @param <T> 152 * Type of object being clustered 153 */ 154 public static class Result<T extends FeatureVector> extends FeatureVectorCentroidsResult<T> 155 implements 156 ObjectNearestNeighboursProvider<T> 157 { 158 protected ObjectNearestNeighbours<T> nn; 159 protected int iterations; 160 protected int changedCentroidCount; 161 162 @Override 163 public ObjectNearestNeighbours<T> getNearestNeighbours() { 164 return nn; 165 } 166 167 @Override 168 public HardAssigner<T, float[], IntFloatPair> defaultHardAssigner() { 169 return new ExactFeatureVectorAssigner<T>(this, nn.distanceComparator()); 170 } 171 172 /** 173 * Get the number of K-Means iterations that produced this result. 174 * 175 * @return the number of iterations 176 */ 177 public int numIterations() { 178 return iterations; 179 } 180 181 /** 182 * Get the number of changed centroids in the last iteration. This is an 183 * indicator of convergence as over time this should reduce to 0. 184 * 185 * @return the number of changed centroids 186 */ 187 public int numChangedCentroids() { 188 return changedCentroidCount; 189 } 190 } 191 192 private FeatureVectorKMeansInit<T> init = new FeatureVectorKMeansInit.RANDOM<T>(); 193 private KMeansConfiguration<ObjectNearestNeighbours<T>, T> conf; 194 private Random rng = new Random(); 195 196 /** 197 * Construct the clusterer with the the given configuration. 198 * 199 * @param conf 200 * The configuration. 201 */ 202 public FeatureVectorKMeans(KMeansConfiguration<ObjectNearestNeighbours<T>, T> conf) { 203 this.conf = conf; 204 } 205 206 /** 207 * A completely default {@link ByteKMeans} used primarily as a convenience 208 * function for reading. 209 */ 210 protected FeatureVectorKMeans() { 211 this(new KMeansConfiguration<ObjectNearestNeighbours<T>, T>()); 212 } 213 214 /** 215 * Get the current initialisation algorithm 216 * 217 * @return the init algorithm being used 218 */ 219 public FeatureVectorKMeansInit<T> getInit() { 220 return init; 221 } 222 223 /** 224 * Set the current initialisation algorithm 225 * 226 * @param init 227 * the init algorithm to be used 228 */ 229 public void setInit(FeatureVectorKMeansInit<T> init) { 230 this.init = init; 231 } 232 233 /** 234 * Set the seed for the internal random number generator. 235 * 236 * @param seed 237 * the random seed for init random sample selection, no seed if 238 * seed < -1 239 */ 240 public void seed(long seed) { 241 if (seed < 0) 242 this.rng = new Random(); 243 else 244 this.rng = new Random(seed); 245 } 246 247 /** 248 * Perform clustering on the given data. 249 * 250 * @param data 251 * the data. 252 * 253 * @return the generated clusters. 254 */ 255 public Result<T> cluster(List<T> data) { 256 @SuppressWarnings("unchecked") 257 T[] d = (T[]) Array.newInstance(data.get(0).getClass(), data.size()); 258 d = data.toArray(d); 259 return cluster(d); 260 } 261 262 @Override 263 public Result<T> cluster(T[] data) { 264 final ArrayBackedDataSource<T> ds = new ArrayBackedDataSource<T>(data, rng) { 265 @Override 266 public int numDimensions() { 267 return data[0].length(); 268 } 269 }; 270 271 try { 272 final Result<T> result = cluster(ds, conf.K); 273 result.nn = conf.factory.create(result.centroids); 274 275 return result; 276 } catch (final Exception e) { 277 throw new RuntimeException(e); 278 } 279 } 280 281 @Override 282 public int[][] performClustering(T[] data) { 283 final FeatureVectorCentroidsResult<T> clusters = this.cluster(data); 284 return new IndexClusters(clusters.defaultHardAssigner().assign(data)).clusters(); 285 } 286 287 /** 288 * Perform clustering on the given data. 289 * 290 * @param data 291 * the data. 292 * 293 * @return the generated clusters. 294 */ 295 public int[][] performClustering(List<T> data) { 296 @SuppressWarnings("unchecked") 297 T[] d = (T[]) Array.newInstance(data.get(0).getClass(), data.size()); 298 d = data.toArray(d); 299 300 final FeatureVectorCentroidsResult<T> clusters = this.cluster(d); 301 return new IndexClusters(clusters.defaultHardAssigner().assign(d)).clusters(); 302 } 303 304 /** 305 * Initiate clustering with the given data and number of clusters. 306 * Internally this method constructs the array to hold the centroids and 307 * calls {@link #cluster(DataSource, Object)}. 308 * 309 * @param data 310 * data source to cluster with 311 * @param K 312 * number of clusters to find 313 * @return cluster centroids 314 */ 315 protected Result<T> cluster(DataSource<T> data, int K) throws Exception { 316 final Result<T> result = new Result<T>(); 317 result.centroids = data.createTemporaryArray(K); 318 319 init.initKMeans(data, result.centroids); 320 321 cluster(data, result); 322 323 return result; 324 } 325 326 /** 327 * Main clustering algorithm. A number of threads as specified are started 328 * each containing an assignment job and a reference to the same set of 329 * ObjectNearestNeighbours object (i.e. Exact or KDTree). Each thread is 330 * added to a job pool and started in parallel. A single accumulator is 331 * shared between all threads and locked on update. <br/> 332 * This methods expects that the initial centroids have already been set in 333 * the <code>result</code> object and as such <strong>ignores</strong> the 334 * init object. <strong>In normal operation you should call one of the other 335 * <code>cluster</code> cluster methods instead of this one.</strong> 336 * However, if you wish to resume clustering iterations from a result that 337 * you've already generated this is the method to use. 338 * 339 * @param data 340 * the data to be clustered 341 * @param result 342 * the results object to be populated 343 * @throws InterruptedException 344 * if interrupted while waiting, in which case unfinished tasks 345 * are cancelled. 346 */ 347 public void cluster(T[] data, Result<T> result) throws InterruptedException { 348 final ArrayBackedDataSource<T> ds = new ArrayBackedDataSource<T>(data, rng) { 349 @Override 350 public int numDimensions() { 351 return data[0].length(); 352 } 353 }; 354 355 cluster(ds, result); 356 } 357 358 /** 359 * Main clustering algorithm. A number of threads as specified are started 360 * each containing an assignment job and a reference to the same set of 361 * ObjectNearestNeighbours object (i.e. Exact or KDTree). Each thread is 362 * added to a job pool and started in parallel. A single accumulator is 363 * shared between all threads and locked on update. <br/> 364 * This methods expects that the initial centroids have already been set in 365 * the <code>result</code> object and as such <strong>ignores</strong> the 366 * init object. In normal operation you should call one of the other 367 * <code>cluster</code> cluster methods instead of this one. However, if you 368 * wish to resume clustering iterations from a result that you've already 369 * generated this is the method to use. 370 * 371 * @param data 372 * the data to be clustered 373 * @param result 374 * the results object to be populated 375 * @throws InterruptedException 376 * if interrupted while waiting, in which case unfinished tasks 377 * are cancelled. 378 */ 379 protected void cluster(DataSource<T> data, Result<T> result) throws InterruptedException { 380 final T[] centroids = result.centroids; 381 final int K = centroids.length; 382 final int D = centroids[0].length(); 383 final int N = data.size(); 384 final double[][] centroids_accum = new double[K][D]; 385 final int[] new_counts = new int[K]; 386 387 final ExecutorService service = conf.threadpool; 388 389 for (int i = 0; i < conf.niters; i++) { 390 result.iterations++; 391 392 for (int j = 0; j < K; j++) 393 Arrays.fill(centroids_accum[j], 0); 394 Arrays.fill(new_counts, 0); 395 396 final ObjectNearestNeighbours<T> nno = conf.factory.create(centroids); 397 398 final List<CentroidAssignmentJob<T>> jobs = new ArrayList<CentroidAssignmentJob<T>>(); 399 for (int bl = 0; bl < N; bl += conf.blockSize) { 400 final int br = Math.min(bl + conf.blockSize, N); 401 jobs.add(new CentroidAssignmentJob<T>(data, bl, br, nno, centroids_accum, new_counts)); 402 } 403 404 service.invokeAll(jobs); 405 406 result.changedCentroidCount = 0; 407 for (int k = 0; k < K; ++k) { 408 double ssd = 0; 409 if (new_counts[k] == 0) { 410 // If there's an empty cluster we replace it with a random 411 // point. 412 new_counts[k] = 1; 413 414 final T[] rnd = data.createTemporaryArray(1); 415 data.getRandomRows(rnd); 416 417 final Cloner cloner = new Cloner(); 418 centroids[k] = cloner.deepClone(rnd[0]); 419 420 result.changedCentroidCount++; 421 } else { 422 for (int d = 0; d < D; ++d) { 423 final double newValue = centroids_accum[k][d] / new_counts[k]; 424 425 // we're going to accumulate the SSD of the old vs new 426 // centroids 427 // as a way of determining if this centroid has changed 428 final double diff = newValue - centroids[k].getAsDouble(d); 429 ssd += diff * diff; 430 431 // update to new centroid 432 centroids[k].setFromDouble(d, newValue); 433 } 434 435 if (ssd != 0) 436 result.changedCentroidCount++; 437 } 438 } 439 440 if (result.changedCentroidCount == 0) 441 break; // convergence 442 } 443 } 444 445 @Override 446 public FeatureVectorCentroidsResult<T> cluster(DataSource<T> ds) { 447 try { 448 final Result<T> result = cluster(ds, conf.K); 449 result.nn = conf.factory.create(result.centroids); 450 451 return result; 452 } catch (final Exception e) { 453 throw new RuntimeException(e); 454 } 455 } 456 457 /** 458 * Get the configuration 459 * 460 * @return the configuration 461 */ 462 public KMeansConfiguration<ObjectNearestNeighbours<T>, T> getConfiguration() { 463 return conf; 464 } 465 466 /** 467 * Set the configuration 468 * 469 * @param conf 470 * the configuration to set 471 */ 472 public void setConfiguration(KMeansConfiguration<ObjectNearestNeighbours<T>, T> conf) { 473 this.conf = conf; 474 } 475 476 /** 477 * Convenience method to quickly create an exact {@link ByteKMeans}. All 478 * parameters other than the number of clusters are set at their defaults, 479 * but can be manipulated through the configuration returned by 480 * {@link #getConfiguration()}. 481 * <p> 482 * Euclidean distance is used to measure the distance between points. 483 * 484 * @param K 485 * the number of clusters 486 * @param distance 487 * the distance measure 488 * @return a {@link ByteKMeans} instance configured for exact k-means 489 */ 490 public static <T extends FeatureVector> FeatureVectorKMeans<T> createExact(int K, 491 DistanceComparator<? super T> distance) 492 { 493 final KMeansConfiguration<ObjectNearestNeighbours<T>, T> conf = new KMeansConfiguration<ObjectNearestNeighbours<T>, T>( 494 K, new ObjectNearestNeighboursExact.Factory<T>( 495 distance)); 496 497 return new FeatureVectorKMeans<T>(conf); 498 } 499 500 /** 501 * Convenience method to quickly create an exact {@link ByteKMeans}. All 502 * parameters other than the number of clusters and number of iterations are 503 * set at their defaults, but can be manipulated through the configuration 504 * returned by {@link #getConfiguration()}. 505 * <p> 506 * Euclidean distance is used to measure the distance between points. 507 * 508 * @param K 509 * the number of clusters 510 * @param distance 511 * the distance measure 512 * @param niters 513 * maximum number of iterations 514 * @return a {@link ByteKMeans} instance configured for exact k-means 515 */ 516 public static <T extends FeatureVector> FeatureVectorKMeans<T> createExact(int K, 517 DistanceComparator<? super T> distance, int niters) 518 { 519 final KMeansConfiguration<ObjectNearestNeighbours<T>, T> conf = new KMeansConfiguration<ObjectNearestNeighbours<T>, T>( 520 K, 521 new ObjectNearestNeighboursExact.Factory<T>(distance), 522 niters); 523 524 return new FeatureVectorKMeans<T>(conf); 525 } 526 527 @Override 528 public String toString() { 529 return String.format("%s: {K=%d, NN=%s}", this.getClass().getSimpleName(), this.conf.K, this.conf 530 .getNearestNeighbourFactory().getClass().getSimpleName()); 531 } 532}