001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.tools.clusterquantiser; 031 032import java.io.BufferedInputStream; 033import java.io.File; 034import java.util.HashMap; 035import java.util.List; 036import java.util.Map; 037import java.util.concurrent.ExecutorService; 038import java.util.concurrent.Executors; 039 040import org.kohsuke.args4j.CmdLineOptionsProvider; 041import org.kohsuke.args4j.Option; 042import org.kohsuke.args4j.ProxyOptionHandler; 043import org.openimaj.io.IOUtils; 044import org.openimaj.knn.ByteNearestNeighbours; 045import org.openimaj.knn.ByteNearestNeighboursExact; 046import org.openimaj.knn.IntNearestNeighbours; 047import org.openimaj.knn.IntNearestNeighboursExact; 048import org.openimaj.knn.NearestNeighboursFactory; 049import org.openimaj.knn.approximate.ByteNearestNeighboursKDTree; 050import org.openimaj.knn.approximate.IntNearestNeighboursKDTree; 051import org.openimaj.ml.clustering.ByteCentroidsResult; 052import org.openimaj.ml.clustering.IntCentroidsResult; 053import org.openimaj.ml.clustering.SpatialClusterer; 054import org.openimaj.ml.clustering.SpatialClusters; 055import org.openimaj.ml.clustering.kmeans.ByteKMeans; 056import org.openimaj.ml.clustering.kmeans.HierarchicalByteKMeans; 057import org.openimaj.ml.clustering.kmeans.HierarchicalByteKMeansResult; 058import org.openimaj.ml.clustering.kmeans.HierarchicalIntKMeans; 059import org.openimaj.ml.clustering.kmeans.HierarchicalIntKMeansResult; 060import org.openimaj.ml.clustering.kmeans.IntKMeans; 061import org.openimaj.ml.clustering.kmeans.KMeansConfiguration; 062import org.openimaj.ml.clustering.random.RandomByteClusterer; 063import org.openimaj.ml.clustering.random.RandomIntClusterer; 064import org.openimaj.ml.clustering.random.RandomSetByteClusterer; 065import org.openimaj.ml.clustering.random.RandomSetIntClusterer; 066import org.openimaj.ml.clustering.rforest.IntRandomForest; 067import org.openimaj.tools.clusterquantiser.fastkmeans.ByteKMeansInitialisers; 068import org.openimaj.tools.clusterquantiser.samplebatch.SampleBatch; 069import org.openimaj.tools.clusterquantiser.samplebatch.SampleBatchByteDataSource; 070import org.openimaj.tools.clusterquantiser.samplebatch.SampleBatchIntDataSource; 071import org.openimaj.util.array.ByteArrayConverter; 072import org.openimaj.util.parallel.GlobalExecutorPool.DaemonThreadFactory; 073 074/** 075 * Different clustering algorithms 076 * 077 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 078 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 079 */ 080public enum ClusterType implements CmdLineOptionsProvider { 081 /** 082 * Randomly sampled centroids (with replacement; the same centroid might be 083 * picked multiple times) 084 */ 085 RANDOM { 086 @Override 087 public ClusterTypeOp getOptions() { 088 return new RandomOp(); 089 } 090 }, 091 /** 092 * Randomly sampled centroids (without replacement; a centroid can only be 093 * picked once) 094 */ 095 RANDOMSET { 096 @Override 097 public ClusterTypeOp getOptions() { 098 return new RandomSetOp(); 099 } 100 }, 101 /** 102 * Fast (possibly approximate) batched K-Means 103 */ 104 FASTMBKMEANS { 105 @Override 106 public ClusterTypeOp getOptions() { 107 return new FastMBKMeansOp(); 108 } 109 }, 110 /** 111 * Fast (possibly approximate) K-Means 112 */ 113 FASTKMEANS { 114 @Override 115 public ClusterTypeOp getOptions() { 116 return new FastKMeansOp(); 117 } 118 }, 119 /** 120 * Hierarchical K-Means 121 */ 122 HKMEANS { 123 @Override 124 public ClusterTypeOp getOptions() { 125 return new HKMeansOp(); 126 } 127 }, 128 /** 129 * Random forest 130 */ 131 RFOREST { 132 @Override 133 public ClusterTypeOp getOptions() { 134 return new RForestOp(); 135 } 136 }; 137 138 /** 139 * Options for each {@link ClusterType}. 140 * 141 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 142 */ 143 public static abstract class ClusterTypeOp { 144 /** 145 * The precision of the clusters 146 */ 147 @Option( 148 name = "--precision", 149 aliases = "-p", 150 required = false, 151 usage = "Specify the cluster percision if supported") 152 public Precision precision = Precision.BYTE; 153 154 /** 155 * Create clusters from data 156 * 157 * @param data 158 * @return clusters 159 * @throws Exception 160 */ 161 public abstract SpatialClusters<?> create(byte[][] data) throws Exception; 162 163 /** 164 * Create clusters from data 165 * 166 * @param batches 167 * @return clusters 168 * @throws Exception 169 */ 170 public SpatialClusters<?> create(List<SampleBatch> batches) throws Exception { 171 return null; 172 } 173 174 /** 175 * @return options 176 */ 177 public Map<String, String> getOptionsMap() { 178 return new HashMap<String, String>(); 179 } 180 181 /** 182 * Set options 183 * 184 * @param options 185 */ 186 public void setOptionsMap(Map<String, String> options) { 187 188 } 189 190 /** 191 * @return java class representing clusters 192 */ 193 public abstract Class<? extends SpatialClusters<?>> getClusterClass(); 194 } 195 196 private static class RForestOp extends ClusterTypeOp { 197 @Option( 198 name = "--decisions", 199 aliases = "-d", 200 required = true, 201 usage = "Specify number of random decisions to be made per tree.", 202 metaVar = "NUMBER") 203 private int decisions = 32; 204 205 @Option( 206 name = "--number-of-trees", 207 aliases = "-nt", 208 required = true, 209 usage = "Specify number of random trees", 210 metaVar = "NUMBER") 211 private int ntrees = 32; 212 213 @Override 214 public SpatialClusters<int[]> create(byte[][] data) { 215 final IntRandomForest rf = new IntRandomForest(ntrees, decisions); 216 rf.cluster(ByteArrayConverter.byteToInt(data)); 217 return rf; 218 } 219 220 @Override 221 public Class<? extends SpatialClusters<?>> getClusterClass() { 222 return IntRandomForest.class; 223 } 224 } 225 226 private static class HKMeansOp extends ClusterTypeOp { 227 @Option( 228 name = "--depth", 229 aliases = "-d", 230 required = true, 231 usage = "Specify depth of tree in create mode.", 232 metaVar = "NUMBER") 233 private int depth = 6; 234 235 @Option( 236 name = "--clusters", 237 aliases = "-k", 238 required = true, 239 usage = "Specify number of clusters per level.", 240 metaVar = "NUMBER") 241 private int K = 10; 242 243 @Option( 244 name = "--enable-approximate", 245 aliases = "-ea", 246 required = false, 247 usage = "Enable the approximate k-means mode") 248 private boolean exactMode = false; 249 250 @Override 251 public SpatialClusters<?> create(byte[][] data) { 252 if (this.precision == Precision.BYTE) { 253 final KMeansConfiguration<ByteNearestNeighbours, byte[]> kmc = new KMeansConfiguration<ByteNearestNeighbours, byte[]>(); 254 255 if (exactMode) { 256 kmc.setNearestNeighbourFactory(new ByteNearestNeighboursExact.Factory()); 257 } else { 258 kmc.setNearestNeighbourFactory(new ByteNearestNeighboursKDTree.Factory()); 259 } 260 261 final HierarchicalByteKMeans tree = new HierarchicalByteKMeans(kmc, data[0].length, K, depth); 262 263 System.err.printf("Building vocabulary tree\n"); 264 return tree.cluster(data); 265 } else { 266 final KMeansConfiguration<IntNearestNeighbours, int[]> kmc = new KMeansConfiguration<IntNearestNeighbours, int[]>(); 267 268 if (exactMode) { 269 kmc.setNearestNeighbourFactory(new IntNearestNeighboursExact.Factory()); 270 } else { 271 kmc.setNearestNeighbourFactory(new IntNearestNeighboursKDTree.Factory()); 272 } 273 274 final HierarchicalIntKMeans tree = new HierarchicalIntKMeans(kmc, data[0].length, K, depth); 275 276 System.err.printf("Building vocabulary tree\n"); 277 return tree.cluster(ByteArrayConverter.byteToInt(data)); 278 } 279 } 280 281 @Override 282 public Class<? extends SpatialClusters<?>> getClusterClass() { 283 if (this.precision == Precision.BYTE) 284 return HierarchicalByteKMeansResult.class; 285 else 286 return HierarchicalIntKMeansResult.class; 287 288 } 289 } 290 291 private static class FastKMeansOp extends ClusterTypeOp { 292 @Option( 293 name = "--clusters", 294 aliases = "-k", 295 required = true, 296 usage = "Specify number of clusters per level.", 297 metaVar = "NUMBER") 298 private int K = 10; 299 300 @Option( 301 name = "--iterations", 302 aliases = "-itr", 303 required = false, 304 usage = "Specify number of iterations.", 305 metaVar = "NUMBER") 306 private int I = 30; 307 308 @Option( 309 name = "--batch-size", 310 aliases = "-b", 311 required = false, 312 usage = "Specify size of each batch for each iteration.", 313 metaVar = "NUMBER") 314 private int B = 50000; 315 316 @Option( 317 name = "--num-checks", 318 aliases = "-nc", 319 required = false, 320 usage = "Specify number of checks for each kd-tree.", 321 metaVar = "NUMBER") 322 private int NC = 768; 323 324 @Option( 325 name = "--num-trees", 326 aliases = "-nt", 327 required = false, 328 usage = "Specify number of kd-trees.", 329 metaVar = "NUMBER") 330 private int NT = 8; 331 332 @Option( 333 name = "--exact-nn", 334 aliases = "-ex", 335 required = false, 336 usage = "Specify whether to use exact nearest neighbours.", 337 metaVar = "BOOLEAN") 338 private boolean E = false; 339 340 @Option( 341 name = "--fastkmeans-threads", 342 aliases = "-jj", 343 required = false, 344 usage = "Specify the number of threads to use to train centroids.", 345 metaVar = "NUMBER") 346 private int jj = Runtime.getRuntime().availableProcessors(); 347 348 @Option( 349 name = "--cluster-random-seed", 350 aliases = "-crs", 351 required = false, 352 usage = "Specify a seed for the random data selection.", 353 metaVar = "NUMBER") 354 private long seed = -1; 355 356 @Option( 357 name = "--cluster-init", 358 aliases = "-cin", 359 required = false, 360 usage = "Specify the type of file to be read.", 361 handler = ProxyOptionHandler.class) 362 public ByteKMeansInitialisers clusterInit = ByteKMeansInitialisers.RANDOM; 363 public ByteKMeansInitialisers.Options clusterInitOp; 364 365 private KMeansConfiguration<ByteNearestNeighbours, byte[]> confByte(int ndims) { 366 NearestNeighboursFactory<? extends ByteNearestNeighbours, byte[]> assigner; 367 final ExecutorService pool = Executors.newFixedThreadPool(jj, new DaemonThreadFactory()); 368 369 if (E) { 370 assigner = new ByteNearestNeighboursExact.Factory(); 371 } else { 372 assigner = new ByteNearestNeighboursKDTree.Factory(NT, NC); 373 } 374 375 final KMeansConfiguration<ByteNearestNeighbours, byte[]> conf = new KMeansConfiguration<ByteNearestNeighbours, byte[]>( 376 K, assigner, I, B, pool); 377 378 return conf; 379 } 380 381 private KMeansConfiguration<IntNearestNeighbours, int[]> confInt(int ndims) { 382 NearestNeighboursFactory<? extends IntNearestNeighbours, int[]> assigner; 383 final ExecutorService pool = Executors.newFixedThreadPool(jj, new DaemonThreadFactory()); 384 385 if (E) { 386 assigner = new IntNearestNeighboursExact.Factory(); 387 } else { 388 assigner = new IntNearestNeighboursKDTree.Factory(NT, NC); 389 } 390 391 final KMeansConfiguration<IntNearestNeighbours, int[]> conf = new KMeansConfiguration<IntNearestNeighbours, int[]>( 392 K, assigner, I, B, pool); 393 394 return conf; 395 } 396 397 @Override 398 public SpatialClusters<?> create(List<SampleBatch> batches) throws Exception { 399 System.err.println("Constructing a FASTKMEANS cluster"); 400 SpatialClusterer<?, ?> c = null; 401 402 System.err.println("Constructing a fastkmeans worker: "); 403 if (this.precision == Precision.BYTE) { 404 final SampleBatchByteDataSource ds = new SampleBatchByteDataSource(batches); 405 ds.setSeed(seed); 406 407 c = new ByteKMeans(confByte(ds.numDimensions())); 408 ((ByteKMeans) c).seed(seed); 409 clusterInitOp.setClusterInit((ByteKMeans) c); 410 411 return ((ByteKMeans) c).cluster(ds); 412 } else { 413 final SampleBatchIntDataSource ds = new SampleBatchIntDataSource(batches); 414 ds.setSeed(seed); 415 416 c = new IntKMeans(confInt(ds.numDimensions())); 417 ((IntKMeans) c).seed(seed); 418 419 return ((IntKMeans) c).cluster(ds); 420 } 421 } 422 423 @Override 424 public SpatialClusters<?> create(byte[][] data) throws Exception { 425 SpatialClusterer<?, ?> c = null; 426 if (this.precision == Precision.BYTE) { 427 c = new ByteKMeans(confByte(data[0].length)); 428 ((ByteKMeans) c).seed(seed); 429 430 if (clusterInitOp == null) 431 clusterInitOp = clusterInit.getOptions(); 432 433 clusterInitOp.setClusterInit((ByteKMeans) c); 434 return ((ByteKMeans) c).cluster(data); 435 } else { 436 c = new IntKMeans(confInt(data[0].length)); 437 ((IntKMeans) c).seed(seed); 438 return ((IntKMeans) c).cluster(ByteArrayConverter.byteToInt(data)); 439 } 440 } 441 442 @Override 443 public Class<? extends SpatialClusters<?>> getClusterClass() { 444 if (this.precision == Precision.BYTE) 445 return ByteCentroidsResult.class; 446 else 447 return IntCentroidsResult.class; 448 } 449 } 450 451 private static class FastMBKMeansOp extends ClusterTypeOp { 452 @Option( 453 name = "--clusters", 454 aliases = "-k", 455 required = true, 456 usage = "Specify number of clusters per level.", 457 metaVar = "NUMBER") 458 private int K = 10; 459 460 @Option( 461 name = "--iterations", 462 aliases = "-itr", 463 required = false, 464 usage = "Specify number of iterations.", 465 metaVar = "NUMBER") 466 private int I = 30; 467 468 @Option( 469 name = "--mini-batch-size", 470 aliases = "-mb", 471 required = false, 472 usage = "Specify size of each mini-batch for each iteration.", 473 metaVar = "NUMBER") 474 private int M = 10000; 475 476 @Option( 477 name = "--batch-size", 478 aliases = "-b", 479 required = false, 480 usage = "Specify size of each batch for each iteration.", 481 metaVar = "NUMBER") 482 private int B = 50000; 483 484 @Option( 485 name = "--num-checks", 486 aliases = "-nc", 487 required = false, 488 usage = "Specify number of checks for each kd-tree.", 489 metaVar = "NUMBER") 490 private int NC = 768; 491 492 @Option( 493 name = "--num-trees", 494 aliases = "-nt", 495 required = false, 496 usage = "Specify number of kd-trees.", 497 metaVar = "NUMBER") 498 private int NT = 8; 499 500 @Option( 501 name = "--exact-nn", 502 aliases = "-ex", 503 required = false, 504 usage = "Specify whether to use exact nearest neighbours.", 505 metaVar = "NUMBER") 506 private boolean E = false; 507 508 @Option( 509 name = "--fastkmeans-threads", 510 aliases = "-jj", 511 required = false, 512 usage = "Specify the number of threads to use to train centroids.", 513 metaVar = "NUMBER") 514 private int jj = Runtime.getRuntime().availableProcessors(); 515 516 private KMeansConfiguration<IntNearestNeighbours, int[]> confInt(int ndims) { 517 NearestNeighboursFactory<? extends IntNearestNeighbours, int[]> assigner; 518 final ExecutorService pool = Executors.newFixedThreadPool(jj, new DaemonThreadFactory()); 519 520 if (E) { 521 assigner = new IntNearestNeighboursExact.Factory(); 522 } else { 523 assigner = new IntNearestNeighboursKDTree.Factory(NT, NC); 524 } 525 526 final KMeansConfiguration<IntNearestNeighbours, int[]> conf = new KMeansConfiguration<IntNearestNeighbours, int[]>( 527 K, assigner, I, B, pool); 528 529 return conf; 530 } 531 532 @Override 533 public SpatialClusters<int[]> create(byte[][] data) { 534 IntKMeans c = null; 535 c = new IntKMeans(confInt(data[0].length)); 536 return c.cluster(ByteArrayConverter.byteToInt(data)); 537 } 538 539 @Override 540 public Class<? extends SpatialClusters<?>> getClusterClass() { 541 return IntCentroidsResult.class; 542 } 543 } 544 545 private static class RandomSetOp extends ClusterTypeOp { 546 @Option( 547 name = "--clusters", 548 aliases = "-k", 549 required = false, 550 usage = "Specify number of clusters per level.", 551 metaVar = "NUMBER") 552 private int K = -1; 553 554 @Option( 555 name = "--cluster-random-seed", 556 aliases = "-crs", 557 required = false, 558 usage = "Specify a seed for the random data selection.", 559 metaVar = "NUMBER") 560 private int seed = -1; 561 562 @Override 563 public SpatialClusters<?> create(byte[][] data) { 564 if (this.precision == Precision.BYTE) { 565 RandomSetByteClusterer c = null; 566 c = new RandomSetByteClusterer(data[0].length, K); 567 if (seed >= 0) 568 c.setSeed(seed); 569 570 System.err.printf("Building BYTE vocabulary tree\n"); 571 return c.cluster(data); 572 } else { 573 RandomSetIntClusterer c = null; 574 c = new RandomSetIntClusterer(data[0].length, K); 575 if (seed >= 0) 576 c.setSeed(seed); 577 578 System.err.printf("Building INT vocabulary tree\n"); 579 return c.cluster(ByteArrayConverter.byteToInt(data)); 580 } 581 } 582 583 @Override 584 public SpatialClusters<?> create(List<SampleBatch> batches) throws Exception { 585 586 if (this.precision == Precision.BYTE) { 587 final SampleBatchByteDataSource ds = new SampleBatchByteDataSource(batches); 588 ds.setSeed(seed); 589 590 RandomSetByteClusterer c = null; 591 c = new RandomSetByteClusterer(ds.numDimensions(), K); 592 593 return c.cluster(ds); 594 } else { 595 final SampleBatchIntDataSource ds = new SampleBatchIntDataSource(batches); 596 ds.setSeed(seed); 597 598 RandomSetIntClusterer c = null; 599 c = new RandomSetIntClusterer(ds.numDimensions(), K); 600 601 return c.cluster(ds); 602 } 603 } 604 605 @Override 606 public Class<? extends SpatialClusters<?>> getClusterClass() { 607 if (this.precision == Precision.BYTE) 608 return ByteCentroidsResult.class; 609 else 610 return IntCentroidsResult.class; 611 } 612 } 613 614 private static class RandomOp extends ClusterTypeOp { 615 @Option( 616 name = "--clusters", 617 aliases = "-k", 618 required = false, 619 usage = "Specify number of clusters per level.", 620 metaVar = "NUMBER") 621 private int K = -1; 622 623 @Option( 624 name = "--cluster-random-seed", 625 aliases = "-crs", 626 required = false, 627 usage = "Specify a seed for the random data selection.", 628 metaVar = "NUMBER") 629 private int seed = -1; 630 631 @Override 632 public SpatialClusters<?> create(byte[][] data) { 633 if (this.precision == Precision.BYTE) { 634 RandomByteClusterer c = null; 635 c = new RandomByteClusterer(data[0].length, K); 636 if (seed >= 0) 637 c.setSeed(seed); 638 639 System.err.printf("Building BYTE vocabulary tree\n"); 640 return c.cluster(data); 641 } else { 642 RandomIntClusterer c = null; 643 c = new RandomIntClusterer(data[0].length, K); 644 if (seed >= 0) 645 c.setSeed(seed); 646 647 System.err.printf("Building INT vocabulary tree\n"); 648 return c.cluster(ByteArrayConverter.byteToInt(data)); 649 } 650 651 } 652 653 @Override 654 public Class<? extends SpatialClusters<?>> getClusterClass() { 655 if (this.precision == Precision.BYTE) 656 return ByteCentroidsResult.class; 657 else 658 return IntCentroidsResult.class; 659 } 660 } 661 662 /** 663 * Guess the type of the clusters based on the file header 664 * 665 * @param oldout 666 * @return guessed type 667 */ 668 public static ClusterTypeOp sniffClusterType(File oldout) { 669 for (final ClusterType c : ClusterType.values()) { 670 for (final Precision p : Precision.values()) { 671 final ClusterTypeOp opts = (ClusterTypeOp) c.getOptions(); 672 opts.precision = p; 673 674 try { 675 if (IOUtils.readable(oldout, opts.getClusterClass())) 676 return opts; 677 } catch (final Exception e) { 678 679 } 680 } 681 } 682 683 return null; 684 } 685 686 /** 687 * Guess the type of the clusters based on the file header 688 * 689 * @param oldout 690 * @return guessed type 691 */ 692 public static ClusterTypeOp sniffClusterType(BufferedInputStream oldout) { 693 for (final ClusterType c : ClusterType.values()) { 694 for (final Precision p : Precision.values()) { 695 final ClusterTypeOp opts = (ClusterTypeOp) c.getOptions(); 696 opts.precision = p; 697 698 try { 699 if (IOUtils.readable(oldout, opts.getClusterClass())) { 700 return opts; 701 } 702 } catch (final Exception e) { 703 e.printStackTrace(); 704 } 705 } 706 } 707 708 return null; 709 } 710}