/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * 	Redistributions of source code must retain the above copyright notice,
 * 	this list of conditions and the following disclaimer.
 *
 *   *	Redistributions in binary form must reproduce the above copyright notice,
 * 	this list of conditions and the following disclaimer in the documentation
 * 	and/or other materials provided with the distribution.
 *
 *   *	Neither the name of the University of Southampton nor the names of its
 * 	contributors may be used to endorse or promote products derived from this
 * 	software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.clustering.rforest;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Scanner;

import org.openimaj.citation.annotation.Reference;
import org.openimaj.citation.annotation.ReferenceType;
import org.openimaj.data.DataSource;
import org.openimaj.ml.clustering.IndexClusters;
import org.openimaj.ml.clustering.SpatialClusterer;
import org.openimaj.ml.clustering.SpatialClusters;
import org.openimaj.ml.clustering.assignment.HardAssigner;
import org.openimaj.util.hash.HashCodeUtil;
import org.openimaj.util.pair.IntFloatPair;

/**
 * An implementation of the RandomForest clustering algorithm proposed by <a
 * href
 * ="http://users.info.unicaen.fr/~jurie/papers/moosman-nowak-jurie-pami08.pdf"
 * >Jurie et al</a>.
 * <p>
 * In this implementation the training phase is used to identify the limits of
 * the data (for which a very small subset may be provided). Once this is known
 * N decision trees are constructed each with M decisions (see
 * {@link RandomDecisionTree}). In the clustering phase each feature projected
 * is assigned a letter for each decision tree.
 *
 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 */
@Reference(
		type = ReferenceType.Article,
		author = { "Frank Moosmann", "Eric Nowak", "Fr{\'e}d{\'e}ric Jurie" },
		title = "Randomized Clustering Forests for Image Classification",
		year = "2008",
		journal = "IEEE PAMI",
		url = "http://dx.doi.org/10.1109/TPAMI.2007.70822")
public class IntRandomForest
		implements
		SpatialClusters<int[]>,
		SpatialClusterer<IntRandomForest, int[]>,
		HardAssigner<int[], float[], IntFloatPair>
{
	private static final String HEADER = SpatialClusters.CLUSTER_HEADER + "RFIC";

	int nDecisions;
	int nTrees;
	int featureLength;
	List<RandomDecisionTree> trees;
	int[] maxVal;
	int[] minVal;

	/** Index assigned to each distinct letter, in order of first appearance. */
	Map<Letter, Integer> letterToInt;
	/** Next index to hand out to a previously unseen letter. */
	private int currentInt = 0;
	/** Index assigned to each distinct word, in order of first appearance. */
	private HashMap<Word, Integer> wordToInt;
	/** Next index to hand out to a previously unseen word. */
	private int currentWordInt = 0;
	/** Seed for tree construction; -1 means "use an unseeded Random". */
	private int randomSeed = -1;

	/**
	 * Parse a word from the format produced by {@link Word#toString()}:
	 * letters separated by underscores.
	 *
	 * @param str
	 *            the serialised word
	 * @return the parsed word
	 */
	private Word wordFromString(String str) {
		// A word with no letters is written as the empty string; splitting it
		// would yield a single empty token that cannot be parsed as a letter.
		if (str.length() == 0) {
			return new Word(new Letter[0]);
		}

		final String[] tokens = str.split("_");
		final Letter[] parsed = new Letter[tokens.length];
		for (int i = 0; i < tokens.length; i++) {
			parsed[i] = letterFromString(tokens[i]);
		}
		return new Word(parsed);
	}

	/**
	 * Parse a letter from the format produced by {@link Letter#toString()}:
	 * the decision flags written as 1/0 followed by the tree index, all
	 * separated by hyphens.
	 *
	 * @param str
	 *            the serialised letter
	 * @return the parsed letter
	 */
	private Letter letterFromString(String str) {
		final String[] tokens = str.split("-");

		// The final token is the tree index, so there is one fewer flag than
		// there are tokens.
		final int nFlags = tokens.length - 1;
		final boolean[] flags = new boolean[nFlags];
		for (int j = 0; j < nFlags; j++) {
			// Letter.toString() writes flags as "1"/"0"; "true"/"false" is
			// still accepted for leniency.
			flags[j] = "1".equals(tokens[j]) || Boolean.parseBoolean(tokens[j]);
		}
		return new Letter(flags, Integer.parseInt(tokens[nFlags]));
	}

	/**
	 * A sequence of letters, one per decision tree. Words are the unit that
	 * {@link IntRandomForest#assign(int[])} maps to a cluster index.
	 */
	class Word {
		private Letter[] letters;

		Word(Letter[] value) {
			letters = value;
		}

		@Override
		public int hashCode() {
			int result = HashCodeUtil.SEED;
			for (final Letter l : letters) {
				result = HashCodeUtil.hash(result, l);
			}
			return result;
		}

		@Override
		public boolean equals(Object obj) {
			if (!(obj instanceof Word)) {
				return false;
			}
			// Arrays.equals compares lengths first, so words of different
			// sizes are simply unequal rather than an out-of-bounds error.
			return Arrays.equals(this.letters, ((Word) obj).letters);
		}

		/**
		 * @return the letters of this word joined with underscores
		 */
		@Override
		public String toString() {
			final StringBuilder out = new StringBuilder();
			for (int i = 0; i < this.letters.length; i++) {
				if (i > 0) {
					out.append('_');
				}
				out.append(letters[i]);
			}
			return out.toString();
		}

		/**
		 * Look up the index of this word, assigning the next free index if
		 * the word has not been seen before.
		 *
		 * @return the index associated with this word
		 */
		public int hashedWord() {
			final Integer known = wordToInt.get(this);
			if (known != null) {
				return known.intValue();
			}

			final int assigned = currentWordInt++;
			wordToInt.put(this, assigned);
			if (currentWordInt == Integer.MAX_VALUE) {
				System.err.println("Too many words!");
				currentWordInt = 0;
			}
			return assigned;
		}
	}

	/**
	 * The set of decisions made by a single decision tree for one data point,
	 * together with the index of the tree that made them.
	 */
	class Letter {
		boolean[] value;
		int treeIndex;

		/**
		 * @param value
		 *            the decisions made by the tree
		 * @param treeIndex
		 *            the index of the tree within the forest
		 */
		public Letter(boolean[] value, int treeIndex) {
			this.value = value;
			this.treeIndex = treeIndex;
		}

		@Override
		public int hashCode() {
			int result = HashCodeUtil.SEED;
			result = HashCodeUtil.hash(result, this.value);
			result = HashCodeUtil.hash(result, this.treeIndex);
			return result;
		}

		/**
		 * @return the index of the tree that produced this letter
		 */
		public int getTreeIndex() {
			return this.treeIndex;
		}

		@Override
		public boolean equals(Object obj) {
			if (!(obj instanceof Letter)) {
				return false;
			}
			final Letter that = (Letter) obj;
			// Arrays.equals also rejects arrays of differing length.
			return this.treeIndex == that.treeIndex && Arrays.equals(this.value, that.value);
		}

		/**
		 * @return the decisions written as 1/0 followed by the tree index,
		 *         all separated by hyphens
		 */
		@Override
		public String toString() {
			final StringBuilder out = new StringBuilder();
			for (final boolean flag : this.value) {
				out.append(flag ? 1 : 0).append('-');
			}
			return out.append(this.treeIndex).toString();
		}

		/**
		 * Look up the index of this letter, assigning the next free index if
		 * the letter has not been seen before.
		 *
		 * @return the index associated with this letter
		 */
		public int hashedLetter() {
			final Integer known = letterToInt.get(this);
			if (known != null) {
				return known.intValue();
			}

			final int assigned = currentInt++;
			letterToInt.put(this, assigned);
			if (currentInt == Integer.MAX_VALUE) {
				System.err.println("... too many letters!");
				currentInt = 0;
			}
			return assigned;
		}
	}

	/**
	 * Makes a default random forest with 32 trees each with 32 decisions.
	 */
	public IntRandomForest() {
		this(32, 32);
	}

	/**
	 * Makes a random forest with nTrees each with nDecisions. Every data
	 * point is described by one letter per tree.
	 *
	 * @param nTrees
	 *            number of trees
	 * @param nDecisions
	 *            number of decisions per tree
	 */
	public IntRandomForest(int nTrees, int nDecisions) {
		this.nTrees = nTrees;
		this.nDecisions = nDecisions;
		this.letterToInt = new HashMap<Letter, Integer>();
		this.wordToInt = new HashMap<Word, Integer>();
	}

	/**
	 * Compute the per-dimension minimum and maximum of the data and store
	 * them through {@link #setMinMax(int[], int[])}.
	 *
	 * @param data
	 *            the training data
	 */
	private void initMinMax(int[][] data) {
		final int[] min = new int[this.featureLength];
		final int[] max = new int[this.featureLength];

		for (int row = 0; row < data.length; row++) {
			for (int j = 0; j < this.featureLength; j++) {
				final int val = data[row][j];
				// The first row seeds both bounds.
				if (row == 0 || val < min[j]) {
					min[j] = val;
				}
				if (row == 0 || val > max[j]) {
					max[j] = val;
				}
			}
		}
		setMinMax(min, max);
	}

	/**
	 * The maximum and minimum values for the various dimensions against which
	 * random decisions will be based.
	 *
	 * @param min
	 *            the minimum value of each dimension
	 * @param max
	 *            the maximum value of each dimension
	 */
	public void setMinMax(int[] min, int[] max) {
		this.minVal = min;
		this.maxVal = max;
	}

	/**
	 * Build {@link #nTrees} random decision trees from the current limits. A
	 * seeded {@link Random} is used when a seed other than -1 has been set.
	 */
	private void initTrees() {
		final Random r = this.randomSeed != -1 ? new Random(this.randomSeed) : new Random();

		// Trees are accessed by index in assignWord, so use a random-access list.
		this.trees = new ArrayList<RandomDecisionTree>(nTrees);
		for (int i = 0; i < nTrees; i++) {
			this.trees.add(new RandomDecisionTree(this.nDecisions, this.featureLength, this.minVal,
					this.maxVal, r));
		}
	}

	/**
	 * Train the forest: record the feature length, find the limits of the
	 * data and construct the decision trees.
	 *
	 * @param data
	 *            the training data; must contain at least one row
	 * @return this forest
	 * @throws IllegalArgumentException
	 *             if data is null or empty
	 */
	@Override
	public IntRandomForest cluster(int[][] data) {
		if (data == null || data.length == 0) {
			throw new IllegalArgumentException("Cannot train a random forest on an empty data set");
		}

		this.featureLength = data[0].length;

		initMinMax(data);
		initTrees();

		return this;
	}

	/**
	 * Train the forest from a data source by loading all of its rows into
	 * memory and delegating to {@link #cluster(int[][])}.
	 *
	 * @param data
	 *            the data source
	 * @return this forest
	 */
	@Override
	public IntRandomForest cluster(DataSource<int[]> data) {
		final int[][] dataArr = new int[data.size()][data.numDimensions()];

		// Without this copy the forest would be trained on an all-zero array.
		// NOTE(review): relies on DataSource#getData(startRow, stopRow, output)
		// filling rows [startRow, stopRow) - confirm against the DataSource API.
		data.getData(0, data.size(), dataArr);

		return cluster(dataArr);
	}

	/**
	 * @return the number of distinct letters that have been assigned an index
	 *         so far. NOTE(review): {@link #assign(int[])} returns word
	 *         indices, so confirm whether the word count was intended here.
	 */
	@Override
	public int numClusters() {
		return this.currentInt;
	}

	@Override
	public int numDimensions() {
		return featureLength;
	}

	/**
	 * Assign each data point to a word index using {@link #assign(int[])}.
	 *
	 * @param data
	 *            the data points
	 * @return one word index per data point
	 */
	@Override
	public int[] assign(int[][] data) {
		final int[] proj = new int[data.length];
		for (int i = 0; i < data.length; i++) {
			proj[i] = this.assign(data[i]);
		}
		return proj;
	}

	/**
	 * Push each data point provided to a set of letters, i.e. a word. Each
	 * letter represents a set of decisions made in a single decision tree.
	 *
	 * @param data
	 *            the data points
	 * @return A word per data point
	 */
	public Word[] assignLetters(int[][] data) {
		final Word[] pushedLetters = new Word[data.length];
		for (int i = 0; i < data.length; i++) {
			pushedLetters[i] = this.assignWord(data[i]);
		}
		return pushedLetters;
	}

	/**
	 * Push a single data point to a set of letters, return the letters as word.
	 * This is achieved by pushing the data point down each decision tree. This
	 * returns the result of the n decisions made for that tree, this is a
	 * single letter. As a letter is seen for the first time, it is assigned a
	 * number. The letters are combined in sequence to construct a word.
	 *
	 * @param data
	 *            to be projected
	 * @return A single word containing the letters containing the decisions
	 *         made on each tree
	 * @throws IllegalStateException
	 *             if the forest has neither been trained nor loaded
	 */
	public Word assignWord(int[] data) {
		if (this.trees == null) {
			throw new IllegalStateException("The forest must be trained or loaded before assignment");
		}

		final Letter[] pushed = new Letter[this.nTrees];
		for (int i = 0; i < this.nTrees; i++) {
			final Letter letter = new Letter(this.trees.get(i).getLetter(data), i);
			// Called for its side effect of registering unseen letters.
			letter.hashedLetter();
			pushed[i] = letter;
		}
		return new Word(pushed);
	}

	/**
	 * Uses the {@link #assignWord(int[])} function to construct the word
	 * representing this data point. If this exact word has been seen before
	 * (i.e. these letters in this order) the same int is used. If not, a new
	 * int is assigned for this word.
	 *
	 * @param data
	 *            a data point to be clustered to a word
	 * @return a cluster centroid from a word
	 */
	@Override
	public int assign(int[] data) {
		return this.assignWord(data).hashedWord();
	}

	/**
	 * @return The number of decision trees
	 */
	public int getNTrees() {
		return this.nTrees;
	}

	/**
	 * @return the number of decisions per tree
	 */
	public int getNDecisions() {
		return this.nDecisions;
	}

	/**
	 * @return the decision trees
	 */
	public List<RandomDecisionTree> getTrees() {
		return this.trees;
	}

	/**
	 * Two forests are equal when they have the same dimensionality, tree and
	 * decision counts, equal trees, and identical letter and word index maps.
	 * <p>
	 * NOTE(review): tree comparison relies on {@link RandomDecisionTree}
	 * providing a value-based equals - confirm.
	 */
	@Override
	public boolean equals(Object r) {
		if (r == this) {
			return true;
		}
		if (!(r instanceof IntRandomForest)) {
			return false;
		}

		final IntRandomForest that = (IntRandomForest) r;

		if (this.numDimensions() != that.numDimensions()
				|| this.getNTrees() != that.getNTrees()
				|| this.getNDecisions() != that.getNDecisions())
		{
			return false;
		}

		// List.equals checks the size and then each pair of trees in order.
		final boolean sameTrees = this.trees == null ? that.trees == null : this.trees.equals(that.trees);
		if (!sameTrees) {
			return false;
		}

		// Map.equals compares in both directions, keeping equals symmetric.
		return this.letterToInt.equals(that.letterToInt) && this.wordToInt.equals(that.wordToInt);
	}

	/**
	 * Hash code consistent with {@link #equals(Object)}: it is derived only
	 * from values that equal forests are required to share.
	 */
	@Override
	public int hashCode() {
		int result = 17;
		result = 31 * result + nTrees;
		result = 31 * result + nDecisions;
		result = 31 * result + featureLength;
		return result;
	}

	@Override
	public String asciiHeader() {
		return "ASCII" + HEADER;
	}

	@Override
	public byte[] binaryHeader() {
		return HEADER.getBytes();
	}

	/**
	 * Split a line of space separated entries and check that at least the
	 * expected number are present.
	 *
	 * @param row
	 *            the line to split
	 * @param expected
	 *            the minimum number of entries required
	 * @param what
	 *            description of the entries, used in the error message
	 * @return the entries
	 * @throws IOException
	 *             if there are too few entries
	 */
	private static String[] splitEntries(String row, int expected, String what) throws IOException {
		// Lines are written with a trailing space; String.split drops the
		// resulting trailing empty token.
		final String[] entries = row.split(" ");
		if (entries.length < expected) {
			throw new IOException("Expected " + expected + " " + what + " but found " + entries.length);
		}
		return entries;
	}

	/**
	 * Parse a line of space separated ints, reusing the target array when it
	 * already has the required length.
	 */
	private static int[] parseIntRow(String row, int length, int[] target) throws IOException {
		final String[] tokens = splitEntries(row, length, "values");
		final int[] out = (target == null || target.length != length) ? new int[length] : target;
		for (int i = 0; i < length; i++) {
			out[i] = Integer.parseInt(tokens[i]);
		}
		return out;
	}

	/**
	 * Read the forest from the format produced by
	 * {@link #writeASCII(PrintWriter)}, replacing any letter and word indices
	 * currently held.
	 */
	@Override
	public void readASCII(Scanner br) throws IOException {
		nDecisions = Integer.parseInt(br.nextLine());
		nTrees = Integer.parseInt(br.nextLine());
		// Both lookup tables are rebuilt from the stream so that no entries
		// from a previous state survive the read.
		this.letterToInt = new HashMap<Letter, Integer>();
		this.wordToInt = new HashMap<Word, Integer>();
		featureLength = Integer.parseInt(br.nextLine());

		if (this.trees == null || this.trees.size() != nTrees) {
			trees = new ArrayList<RandomDecisionTree>(nTrees);
			for (int i = 0; i < nTrees; i++) {
				trees.add(new RandomDecisionTree().readASCII(br));
			}
		} else {
			// We have an existing tree, try to read it!
			for (final RandomDecisionTree rt : trees) {
				rt.readASCII(br);
			}
		}

		// Only rebuild an array of the wrong size
		maxVal = parseIntRow(br.nextLine(), featureLength, maxVal);
		minVal = parseIntRow(br.nextLine(), featureLength, minVal);

		currentInt = Integer.parseInt(br.nextLine());
		final String[] letterEntries = splitEntries(br.nextLine(), currentInt, "letters");
		for (int i = 0; i < currentInt; i++) {
			final String[] part = letterEntries[i].split(",");
			letterToInt.put(letterFromString(part[0]), Integer.parseInt(part[1]));
		}

		currentWordInt = Integer.parseInt(br.nextLine());
		if (currentWordInt != 0) {
			final String[] wordEntries = splitEntries(br.nextLine(), currentWordInt, "words");
			for (int i = 0; i < currentWordInt; i++) {
				final String[] part = wordEntries[i].split(",");
				wordToInt.put(wordFromString(part[0]), Integer.parseInt(part[1]));
			}
		}
	}

	/**
	 * Read length ints, reusing the target array when it already has the
	 * required length.
	 */
	private static int[] readIntArray(DataInput dis, int length, int[] target) throws IOException {
		final int[] out = (target == null || target.length != length) ? new int[length] : target;
		for (int i = 0; i < length; i++) {
			out[i] = dis.readInt();
		}
		return out;
	}

	/**
	 * Read a letter: the number of decisions, the decisions, then the tree
	 * index.
	 */
	private Letter readLetter(DataInput dis) throws IOException {
		final boolean[] flags = new boolean[dis.readInt()];
		for (int j = 0; j < flags.length; j++) {
			flags[j] = dis.readBoolean();
		}
		return new Letter(flags, dis.readInt());
	}

	/**
	 * Write a letter: the number of decisions, the decisions, then the tree
	 * index.
	 */
	private void writeLetter(DataOutput o, Letter letter) throws IOException {
		o.writeInt(letter.value.length);
		for (final boolean flag : letter.value) {
			o.writeBoolean(flag);
		}
		o.writeInt(letter.treeIndex);
	}

	/**
	 * Read the forest from the format produced by
	 * {@link #writeBinary(DataOutput)}, replacing any letter and word indices
	 * currently held.
	 */
	@Override
	public void readBinary(DataInput dis) throws IOException {
		nDecisions = dis.readInt();
		nTrees = dis.readInt();
		// Both lookup tables are rebuilt from the stream so that no entries
		// from a previous state survive the read.
		this.letterToInt = new HashMap<Letter, Integer>();
		this.wordToInt = new HashMap<Word, Integer>();
		featureLength = dis.readInt();

		if (this.trees == null || this.trees.size() != nTrees) {
			trees = new ArrayList<RandomDecisionTree>(nTrees);
			for (int i = 0; i < nTrees; i++) {
				trees.add(new RandomDecisionTree().readBinary(dis));
			}
		} else {
			// We have an existing tree, try to read it!
			for (final RandomDecisionTree rt : trees) {
				rt.readBinary(dis);
			}
		}

		maxVal = readIntArray(dis, featureLength, maxVal);
		minVal = readIntArray(dis, featureLength, minVal);

		currentInt = dis.readInt();
		for (int i = 0; i < currentInt; i++) {
			final Letter letter = readLetter(dis);
			letterToInt.put(letter, dis.readInt());
		}

		currentWordInt = dis.readInt();
		for (int i = 0; i < currentWordInt; i++) {
			final Letter[] letters = new Letter[dis.readInt()];
			for (int j = 0; j < letters.length; j++) {
				letters[j] = readLetter(dis);
			}
			wordToInt.put(new Word(letters), dis.readInt());
		}
	}

	/**
	 * Write the forest as text: the three size parameters, the trees, the
	 * maximum and minimum limits, then the letter and word indices.
	 */
	@Override
	public void writeASCII(PrintWriter writer) throws IOException {
		writer.println(this.nDecisions);
		writer.println(this.nTrees);
		writer.println(this.featureLength);
		for (final RandomDecisionTree tree : trees) {
			tree.writeASCII(writer);
			writer.println();
		}

		for (final int v : maxVal) {
			writer.print(v + " ");
		}
		writer.println();
		for (final int v : minVal) {
			writer.print(v + " ");
		}
		writer.println();

		writer.println(currentInt);
		for (final Entry<Letter, Integer> p : letterToInt.entrySet()) {
			writer.print(p.getKey() + "," + p.getValue() + " ");
		}
		writer.println();

		writer.println(currentWordInt);
		for (final Entry<Word, Integer> p : wordToInt.entrySet()) {
			writer.print(p.getKey() + "," + p.getValue() + " ");
		}
		writer.println();
		writer.flush();
	}

	/**
	 * Write the forest in binary: the three size parameters, the trees, the
	 * maximum and minimum limits, then the letter and word indices.
	 */
	@Override
	public void writeBinary(DataOutput o) throws IOException {
		o.writeInt(nDecisions);
		o.writeInt(nTrees);
		o.writeInt(featureLength);
		for (final RandomDecisionTree tree : trees) {
			tree.write(o);
		}

		for (final int v : maxVal) {
			o.writeInt(v);
		}
		for (final int v : minVal) {
			o.writeInt(v);
		}

		o.writeInt(currentInt);
		for (final Entry<Letter, Integer> p : letterToInt.entrySet()) {
			writeLetter(o, p.getKey());
			o.writeInt(p.getValue());
		}

		o.writeInt(currentWordInt);
		for (final Entry<Word, Integer> p : wordToInt.entrySet()) {
			final Letter[] letters = p.getKey().letters;
			o.writeInt(letters.length);
			for (final Letter letter : letters) {
				writeLetter(o, letter);
			}
			o.writeInt(p.getValue());
		}
	}

	/**
	 * @param random
	 *            the seed of the java {@link Random} instance used by the
	 *            decision trees. A value of -1 leaves the generator unseeded.
	 */
	public void setRandomSeed(int random) {
		this.randomSeed = random;
	}

	/**
	 * Create a letter bound to this forest.
	 *
	 * @param bs
	 *            the decisions
	 * @param i
	 *            the tree index
	 * @return the new letter
	 */
	Letter newLetter(boolean[] bs, int i) {
		return new Letter(bs, i);
	}

	/**
	 * Create a word bound to this forest.
	 *
	 * @param letters
	 *            the letters of the word
	 * @return the new word
	 */
	Word newWord(Letter[] letters) {
		return new Word(letters);
	}

	/**
	 * Not supported by this clusterer.
	 *
	 * @throws UnsupportedOperationException
	 *             always
	 */
	@Override
	public void assignDistance(int[][] data, int[] indices, float[] distances) {
		throw new UnsupportedOperationException("Not implemented");
	}

	/**
	 * Not supported by this clusterer.
	 *
	 * @throws UnsupportedOperationException
	 *             always
	 */
	@Override
	public IntFloatPair assignDistance(int[] data) {
		throw new UnsupportedOperationException("Not implemented");
	}

	/**
	 * @return this forest, which acts as its own hard assigner
	 */
	@Override
	public HardAssigner<int[], ?, ?> defaultHardAssigner() {
		return this;
	}

	/**
	 * @return the same value as {@link #numClusters()}
	 */
	@Override
	public int size() {
		return this.currentInt;
	}

	/**
	 * Train on the data, assign every point to a word index and group the
	 * point indices by that assignment.
	 *
	 * @param data
	 *            the data to cluster
	 * @return the clusters as arrays of data indices
	 */
	@Override
	public int[][] performClustering(int[][] data) {
		final int[] assignments = this.cluster(data).defaultHardAssigner().assign(data);
		return new IndexClusters(assignments).clusters();
	}
}