001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.tools.clusterquantiser; 031 032import gnu.trove.list.array.TIntArrayList; 033 034import java.io.File; 035import java.io.FileOutputStream; 036import java.io.FileWriter; 037import java.io.IOException; 038import java.io.ObjectOutputStream; 039import java.io.PrintWriter; 040import java.util.ArrayList; 041import java.util.List; 042import java.util.Random; 043import java.util.concurrent.Callable; 044import java.util.concurrent.ExecutorService; 045import java.util.concurrent.Executors; 046 047import org.kohsuke.args4j.CmdLineException; 048import org.openimaj.data.RandomData; 049import org.openimaj.io.IOUtils; 050import org.openimaj.ml.clustering.ByteCentroidsResult; 051import org.openimaj.ml.clustering.IntCentroidsResult; 052import org.openimaj.ml.clustering.SpatialClusters; 053import org.openimaj.ml.clustering.assignment.HardAssigner; 054import org.openimaj.ml.clustering.assignment.hard.KDTreeByteEuclideanAssigner; 055import org.openimaj.ml.clustering.assignment.hard.KDTreeIntEuclideanAssigner; 056import org.openimaj.time.Timer; 057import org.openimaj.tools.clusterquantiser.ClusterType.ClusterTypeOp; 058import org.openimaj.tools.clusterquantiser.samplebatch.SampleBatch; 059import org.openimaj.util.array.ByteArrayConverter; 060import org.openimaj.util.parallel.GlobalExecutorPool.DaemonThreadFactory; 061 062/** 063 * A tool for clustering and quantising local features. 064 * 065 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 066 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 067 */ 068public class ClusterQuantiser { 069 /** 070 * create new clusters 071 * 072 * @param options 073 * @return clusters 074 * @throws Exception 075 */ 076 public static SpatialClusters<?> do_create(ClusterQuantiserOptions options) throws Exception { 077 final File treeFile = new File(options.getTreeFile()); 078 final ClusterTypeOp clusterType = options.getClusterType(); 079 080 SpatialClusters<?> cluster = null; 081 082 // perform sampling if required 083 if (options.isBatchedSampleMode()) { 084 cluster = clusterType.create(do_getSampleBatches(options)); 085 IOUtils.writeBinary(treeFile, cluster); 086 } else { 087 final byte[][] data = do_getSamples(options); 088 System.err.printf("Using %d records\n", data.length); 089 cluster = clusterType.create(data); 090 System.err.println("Writing cluster file to " + treeFile); 091 IOUtils.writeBinary(treeFile, cluster); 092 } 093 return cluster; 094 095 } 096 097 /** 098 * Get sample batches 099 * 100 * @param options 101 * @return batches 102 * @throws IOException 103 */ 104 public static List<SampleBatch> do_getSampleBatches(ClusterQuantiserOptions options) throws IOException { 105 if (options.isSamplesFileMode()) { 106 try { 107 System.err.println("Attempting to read sample batch file..."); 108 return SampleBatch.readSampleBatches(options.getSamplesFile()); 109 } catch (final Exception e) { 110 System.err.println("... Failed! "); 111 return null; 112 } 113 } 114 115 final List<SampleBatch> batches = new ArrayList<SampleBatch>(); 116 final List<File> input_files = options.getInputFiles(); 117 final FileType type = options.getFileType(); 118 final int n_input_files = input_files.size(); 119 final List<Header> headers = new ArrayList<Header>(n_input_files); 120 121 // read the headers and count the total number of features 122 System.err.printf("Reading input %8d / %8d", 0, n_input_files); 123 int totalFeatures = 0; 124 final int[] cumSum = new int[n_input_files + 1]; 125 for (int i = 0; i < n_input_files; i++) { 126 final Header h = type.readHeader(input_files.get(i)); 127 128 totalFeatures += h.nfeatures; 129 cumSum[i + 1] = totalFeatures; 130 headers.add(h); 131 132 System.err.printf("\r%8d / %8d", i + 1, n_input_files); 133 } 134 135 System.err.println(); 136 final int samples = options.getSamples(); 137 if (samples <= 0 || samples > totalFeatures) { 138 System.err.printf( 139 "Samples requested %d larger than total samples %d...\n", 140 samples, totalFeatures); 141 142 for (int i = 0; i < n_input_files; i++) { 143 if (cumSum[i + 1] - cumSum[i] == 0) 144 continue; 145 final SampleBatch sb = new SampleBatch(type, input_files.get(i), 146 cumSum[i], cumSum[i + 1]); 147 batches.add(sb); 148 System.err.printf("\rConstructing sample batches %8d / %8d", i, 149 n_input_files); 150 } 151 System.err.println(); 152 System.err.println("Done..."); 153 } else { 154 System.err.println("Shuffling and sampling ..."); 155 // generate sample unique random numbers between 0 and totalFeatures 156 int[] rndIndices = null; 157 if (options.getRandomSeed() == -1) 158 rndIndices = RandomData.getUniqueRandomInts(samples, 0, 159 totalFeatures); 160 else 161 rndIndices = RandomData.getUniqueRandomInts(samples, 0, 162 totalFeatures, new Random(options.getRandomSeed())); 163 System.err.println("Done! Extracting features required"); 164 final TIntArrayList intraFileIndices = new TIntArrayList(); 165 for (int j = 0, s = 0; j < n_input_files; j++) { 166 intraFileIndices.clear(); 167 168 // go through samples and find ones belonging to this doc 169 for (int i = 0; i < samples; i++) { 170 final int idx = rndIndices[i]; 171 172 if (idx >= cumSum[j] && idx < cumSum[j + 1]) { 173 intraFileIndices.add(idx - cumSum[j]); 174 } 175 } 176 177 if (intraFileIndices.size() > 0) { 178 final SampleBatch sb = new SampleBatch(type, input_files.get(j), 179 s, s + intraFileIndices.size(), 180 intraFileIndices.toArray()); 181 batches.add(sb); 182 s += intraFileIndices.size(); 183 System.err.printf("\r%8d / %8d", s, samples); 184 } 185 186 } 187 System.err.println(); 188 } 189 if (batches.size() > 0 && options.getSamplesFile() != null) { 190 System.err.println("Writing samples file..."); 191 SampleBatch.writeSampleBatches(batches, options.getSamplesFile()); 192 } 193 return batches; 194 } 195 196 /** 197 * Get samples 198 * 199 * @param options 200 * @return samples 201 * @throws IOException 202 */ 203 public static byte[][] do_getSamples(ClusterQuantiserOptions options) 204 throws IOException 205 { 206 207 byte[][] data = null; 208 if (options.isSamplesFileMode()) { 209 data = options.getSampleKeypoints(); 210 } else { 211 final List<File> input_files = options.getInputFiles(); 212 final FileType type = options.getFileType(); 213 final int n_input_files = input_files.size(); 214 final List<Header> headers = new ArrayList<Header>(n_input_files); 215 216 // read the headers and count the total number of features 217 System.err.printf("Reading input %8d / %8d", 0, n_input_files); 218 int totalFeatures = 0; 219 final int[] cumSum = new int[n_input_files + 1]; 220 for (int i = 0; i < n_input_files; i++) { 221 final Header h = type.readHeader(input_files.get(i)); 222 223 totalFeatures += h.nfeatures; 224 cumSum[i + 1] = totalFeatures; 225 headers.add(h); 226 227 System.err.printf("\r%8d / %8d", i + 1, n_input_files); 228 } 229 230 System.err.println(); 231 final int samples = options.getSamples(); 232 if (samples <= 0 || samples > totalFeatures) { 233 System.err 234 .printf("Samples requested %d larger than total samples %d...\n", 235 samples, totalFeatures); 236 // no sampled requested or more samples requested than features 237 // exist 238 // so use all features 239 data = new byte[totalFeatures][]; 240 241 for (int i = 0, j = 0; i < n_input_files; i++) { 242 final byte[][] fd = type.readFeatures(input_files.get(i)); 243 244 for (int k = 0; k < fd.length; k++) { 245 data[j + k] = fd[k]; 246 System.err.printf("\r%8d / %8d", j, totalFeatures); 247 } 248 j += fd.length; 249 } 250 } else { 251 System.err.println("Shuffling and sampling ..."); 252 253 data = new byte[samples][]; 254 255 // generate sample unique random numbers between 0 and 256 // totalFeatures 257 int[] rndIndices = null; 258 if (options.getRandomSeed() == -1) 259 rndIndices = RandomData.getUniqueRandomInts(samples, 0, 260 totalFeatures); 261 else 262 rndIndices = RandomData.getUniqueRandomInts(samples, 0, 263 totalFeatures, new Random(options.getRandomSeed())); 264 System.err.println("Done! Extracting features required"); 265 final TIntArrayList intraFileIndices = new TIntArrayList(); 266 for (int j = 0, s = 0; j < n_input_files; j++) { 267 intraFileIndices.clear(); 268 269 // go through samples and find ones belonging to this doc 270 for (int i = 0; i < samples; i++) { 271 final int idx = rndIndices[i]; 272 273 if (idx >= cumSum[j] && idx < cumSum[j + 1]) { 274 intraFileIndices.add(idx - cumSum[j]); 275 } 276 } 277 278 if (intraFileIndices.size() > 0) { 279 final byte[][] f = type.readFeatures(input_files.get(j), 280 intraFileIndices.toArray()); 281 for (int i = 0; i < intraFileIndices.size(); i++, s++) { 282 data[s] = f[i]; 283 System.err 284 .printf("\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b%8d / %8d", 285 s + 1, samples); 286 } 287 } 288 289 } 290 System.err.println(); 291 } 292 if (data != null && options.getSamplesFile() != null) { 293 System.err.println("Writing samples file..."); 294 final FileOutputStream fos = new FileOutputStream( 295 options.getSamplesFile()); 296 final ObjectOutputStream dos = new ObjectOutputStream(fos); 297 dos.writeObject(data); 298 dos.close(); 299 } 300 } 301 return data; 302 } 303 304 /** 305 * Print info about clusters 306 * 307 * @param options 308 * @throws IOException 309 */ 310 public static void do_info(AbstractClusterQuantiserOptions options) 311 throws IOException 312 { 313 final SpatialClusters<?> cluster = IOUtils.read(new File(options.getTreeFile()), options.getClusterClass()); 314 System.out.println("Cluster loaded..."); 315 System.out.println(cluster); 316 } 317 318 /** 319 * Quantise features 320 * 321 * @param cqo 322 * @throws IOException 323 * @throws InterruptedException 324 */ 325 public static void do_quant(ClusterQuantiserOptions cqo) throws IOException, InterruptedException { 326 final ExecutorService es = Executors.newFixedThreadPool(cqo.getConcurrency(), new DaemonThreadFactory()); 327 328 final List<QuantizerJob> jobs = QuantizerJob.getJobs(cqo); 329 330 System.out.format("Using %d processors\n", cqo.getConcurrency()); 331 es.invokeAll(jobs); 332 es.shutdown(); 333 } 334 335 static class QuantizerJob implements Callable<Boolean> { 336 SpatialClusters<?> tree; 337 HardAssigner<?, ?, ?> assigner; 338 339 List<File> inputFiles; 340 // FileType fileType; 341 // String extension; 342 // private ClusterType ctype; 343 // private File outputFile; 344 private ClusterQuantiserOptions cqo; 345 private String commonRoot; 346 347 static int count = 0; 348 static int total; 349 350 static synchronized void incr() { 351 count++; 352 System.err.printf( 353 "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b%8d / %8d", count, 354 total); 355 } 356 357 protected QuantizerJob(ClusterQuantiserOptions cqo, SpatialClusters<?> tree, HardAssigner<?, ?, ?> assigner) 358 throws IOException 359 { 360 this.cqo = cqo; 361 this.tree = tree; 362 this.inputFiles = cqo.getInputFiles(); 363 this.commonRoot = cqo.getInputFileCommonRoot(); 364 this.assigner = assigner; 365 } 366 367 protected QuantizerJob(ClusterQuantiserOptions cqo, 368 List<File> inputFiles, SpatialClusters<?> clusters, HardAssigner<?, ?, ?> assigner) throws IOException 369 { 370 this.cqo = cqo; 371 this.tree = clusters; 372 this.inputFiles = inputFiles; 373 this.commonRoot = cqo.getInputFileCommonRoot(); 374 this.assigner = assigner; 375 } 376 377 public static List<QuantizerJob> getJobs(ClusterQuantiserOptions cqo) 378 throws IOException 379 { 380 381 final List<QuantizerJob> jobs = new ArrayList<QuantizerJob>( 382 cqo.getConcurrency()); 383 final int size = cqo.getInputFiles().size() / cqo.getConcurrency(); 384 385 final SpatialClusters<?> clusters = IOUtils.read(new File(cqo.getTreeFile()), cqo.getClusterClass()); 386 387 HardAssigner<?, ?, ?> assigner; 388 if (!cqo.exactQuant) { 389 assigner = clusters.defaultHardAssigner(); 390 } else { 391 if (clusters instanceof ByteCentroidsResult) 392 assigner = new KDTreeByteEuclideanAssigner((ByteCentroidsResult) clusters); 393 else 394 assigner = new KDTreeIntEuclideanAssigner((IntCentroidsResult) clusters); 395 } 396 397 QuantizerJob.count = 0; 398 QuantizerJob.total = cqo.getInputFiles().size(); 399 for (int i = 0; i < cqo.getConcurrency() - 1; i++) { 400 final QuantizerJob job = new QuantizerJob(cqo, cqo.getInputFiles().subList(i * size, (i + 1) * size), 401 clusters, assigner); 402 jobs.add(job); 403 } 404 // add remaining 405 final QuantizerJob job = new QuantizerJob(cqo, 406 cqo.getInputFiles().subList((cqo.getConcurrency() - 1) * size, 407 cqo.getInputFiles().size()), clusters, assigner); 408 jobs.add(job); 409 410 return jobs; 411 } 412 413 @SuppressWarnings("unchecked") 414 @Override 415 public Boolean call() throws Exception { 416 for (int i = 0; i < inputFiles.size(); i++) { 417 try { 418 File outFile = new File(inputFiles.get(i) 419 + cqo.getExtension()); 420 if (cqo.getOutputFile() != null) 421 outFile = new File(cqo.getOutputFile().getAbsolutePath() // /output 422 + File.separator // / 423 + outFile.getAbsolutePath().substring(this.commonRoot.length())); // /filename.out 424 if (outFile.exists() && outFile.getTotalSpace() > 0) { 425 incr(); 426 continue; 427 } 428 final FeatureFile input = cqo.getFileType().read(inputFiles.get(i)); 429 PrintWriter pw = null; 430 // Make the parent directory if you need to 431 if (!outFile.getParentFile().exists()) { 432 if (!outFile.getParentFile().mkdirs()) 433 throw new IOException("couldn't make output directory: " + outFile.getParentFile()); 434 } 435 final Timer t = Timer.timer(); 436 try { 437 pw = new PrintWriter(new FileWriter(outFile)); 438 pw.format("%d\n%d\n", input.size(), 439 tree.numClusters()); 440 // int [] clusters = new int[input.size()]; 441 for (final FeatureFileFeature fff : input) { 442 int cluster = -1; 443 if (tree.getClass().getName().contains("Byte")) 444 cluster = ((HardAssigner<byte[], ?, ?>) assigner).assign(fff.data); 445 else 446 cluster = ((HardAssigner<int[], ?, ?>) assigner).assign(ByteArrayConverter 447 .byteToInt(fff.data)); 448 pw.format("%s %d\n", fff.location.trim(), cluster); 449 } 450 } catch (final IOException e) { 451 e.printStackTrace(); 452 throw new Error(e); // IO error when writing - die. 453 } finally { 454 if (pw != null) { 455 pw.flush(); 456 pw.close(); 457 input.close(); 458 } 459 460 } 461 t.stop(); 462 if (cqo.printTiming()) { 463 System.out.println("Took: " + t.duration()); 464 } 465 466 } catch (final Exception e) { 467 // Error processing an individual file; print error then 468 // continue 469 e.printStackTrace(); 470 System.err.println("Error processing file:" 471 + inputFiles.get(i)); 472 System.err 473 .println("(Exception was " + e.getMessage() + ")"); 474 } 475 476 // System.err.printf("\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b%8d / %8d", 477 // i+1, input_files.size()); 478 incr(); 479 } 480 // System.out.println(); 481 return true; 482 } 483 } 484 485 /** 486 * Prepare options 487 * 488 * @param args 489 * @return prepared options 490 * @throws InterruptedException 491 * @throws CmdLineException 492 */ 493 public static ClusterQuantiserOptions mainOptions(String[] args) 494 throws InterruptedException, CmdLineException 495 { 496 final ClusterQuantiserOptions options = new ClusterQuantiserOptions(args); 497 options.prepare(); 498 499 return options; 500 } 501 502 /** 503 * The main method 504 * 505 * @param args 506 * @throws Exception 507 */ 508 public static void main(String[] args) throws Exception { 509 try { 510 final ClusterQuantiserOptions options = mainOptions(args); 511 512 final List<File> inputFiles = options.getInputFiles(); 513 514 if (options.getVerbosity() >= 0 && !options.isInfoMode()) 515 System.err 516 .printf("We have %d input files\n", inputFiles.size()); 517 518 if (options.isCreateMode()) { 519 do_create(options); 520 } else if (options.isInfoMode()) { 521 do_info(options); 522 } else if (options.isQuantMode()) { 523 do_quant(options); 524 } else if (options.getSamplesFile() != null 525 && inputFiles.size() > 0) 526 { 527 if (options.isBatchedSampleMode()) { 528 do_getSampleBatches(options); 529 } else { 530 do_getSamples(options); 531 } 532 } 533 } catch (final CmdLineException cmdline) { 534 System.err.print(cmdline); 535 } catch (final IOException e) { 536 System.err.println(e.getMessage()); 537 } 538 } 539}