View Javadoc

1   /**
2    * Copyright (c) 2011, The University of Southampton and the individual contributors.
3    * All rights reserved.
4    *
5    * Redistribution and use in source and binary forms, with or without modification,
6    * are permitted provided that the following conditions are met:
7    *
8    *   * 	Redistributions of source code must retain the above copyright notice,
9    * 	this list of conditions and the following disclaimer.
10   *
11   *   *	Redistributions in binary form must reproduce the above copyright notice,
12   * 	this list of conditions and the following disclaimer in the documentation
13   * 	and/or other materials provided with the distribution.
14   *
15   *   *	Neither the name of the University of Southampton nor the names of its
16   * 	contributors may be used to endorse or promote products derived from this
17   * 	software without specific prior written permission.
18   *
19   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21   * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22   * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
23   * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24   * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26   * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27   * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28   * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29   */
30  package org.openimaj.ml.clustering.rforest;
31  
32  import java.io.DataInput;
33  import java.io.DataOutput;
34  import java.io.IOException;
35  import java.io.PrintWriter;
36  import java.util.HashMap;
37  import java.util.LinkedList;
38  import java.util.List;
39  import java.util.Map;
40  import java.util.Map.Entry;
41  import java.util.Random;
42  import java.util.Scanner;
43  
44  import org.openimaj.citation.annotation.Reference;
45  import org.openimaj.citation.annotation.ReferenceType;
46  import org.openimaj.data.DataSource;
47  import org.openimaj.ml.clustering.IndexClusters;
48  import org.openimaj.ml.clustering.SpatialClusterer;
49  import org.openimaj.ml.clustering.SpatialClusters;
50  import org.openimaj.ml.clustering.assignment.HardAssigner;
51  import org.openimaj.util.hash.HashCodeUtil;
52  import org.openimaj.util.pair.IntFloatPair;
53  
54  /**
55   * An implementation of the RandomForest clustering algorithm proposed by <a
56   * href
57   * ="http://users.info.unicaen.fr/~jurie/papers/moosman-nowak-jurie-pami08.pdf"
58   * >Jurie et al</a>.
59   * <p>
60   * In this implementation the training phase is used to identify the limits of
61   * the data (for which a very small subset may be provided). Once this is known
62   * N decision trees are constructed each with M decisions (see
63   * {@link RandomDecisionTree}). In the clustering phase each feature projected
64   * is assigned a letter for each decision tree.
65   * 
66   * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
67   * @author Sina Samangooei (ss@ecs.soton.ac.uk)
68   */
69  @Reference(
70  		type = ReferenceType.Article,
71  		author = { "Frank Moosmann", "Eric Nowak", "Fr{\'e}d{\'e}ric Jurie" },
72  		title = "Randomized Clustering Forests for Image Classification",
73  		year = "2008",
74  		journal = "IEEE PAMI",
75  		url = "http://dx.doi.org/10.1109/TPAMI.2007.70822")
76  public class IntRandomForest
77  		implements
78  		SpatialClusters<int[]>,
79  		SpatialClusterer<IntRandomForest, int[]>,
80  		HardAssigner<int[], float[], IntFloatPair>
81  {
82  	private static final String HEADER = SpatialClusters.CLUSTER_HEADER + "RFIC";
83  	int nDecisions;
84  	int nTrees;
85  	int featureLength;
86  	List<RandomDecisionTree> trees;
87  	int[] maxVal;
88  	int[] minVal;
89  	Map<Letter, Integer> letterToInt;
90  	private int currentInt = 0;
91  	private HashMap<Word, Integer> wordToInt;
92  	private int currentWordInt = 0;
93  	private int randomSeed = -1;
94  
95  	private Word wordFromString(String str) {
96  		final String[] values = str.split("_");
97  		final Letter[] intValues = new Letter[values.length];
98  		int i = 0;
99  		for (final String s : values) {
100 			intValues[i++] = letterFromString(s);
101 		}
102 		return new Word(intValues);
103 	}
104 
105 	private Letter letterFromString(String str) {
106 		final String[] values = str.split("-");
107 		final boolean[] intValues = new boolean[values.length];
108 		int i = 0;
109 		for (int j = 0; j < values.length - 1; j++) {
110 			intValues[i++] = Boolean.parseBoolean(values[j]);
111 		}
112 		return new Letter(intValues, Integer.parseInt(values[values.length - 1]));
113 	}
114 
115 	class Word {
116 		private Letter[] letters;
117 
118 		Word(Letter[] value) {
119 			letters = value;
120 		}
121 
122 		@Override
123 		public int hashCode() {
124 			int result = HashCodeUtil.SEED;
125 			for (final Letter l : letters) {
126 				result = HashCodeUtil.hash(result, l);
127 			}
128 			return result;
129 		}
130 
131 		@Override
132 		public boolean equals(Object obj) {
133 			if (!(obj instanceof Word))
134 				return false;
135 
136 			final Word that = (Word) obj;
137 			boolean same = true;
138 			for (int i = 0; i < letters.length; i++) {
139 				same &= letters[i].equals(that.letters[i]);
140 
141 				if (!same)
142 					return false;
143 			}
144 
145 			return same;
146 		}
147 
148 		@Override
149 		public String toString() {
150 			String outString = "";
151 			for (int i = 0; i < this.letters.length; i++) {
152 				outString += "_" + letters[i];
153 			}
154 			return outString.substring(1);
155 		}
156 
157 		public int hashedWord() {
158 			if (!wordToInt.containsKey(this)) {
159 				wordToInt.put(this, currentWordInt++);
160 				if (currentWordInt == Integer.MAX_VALUE) {
161 					System.err.println("Too many words!");
162 					currentWordInt = 0;
163 				}
164 			}
165 			return wordToInt.get(this);
166 		}
167 	}
168 
169 	class Letter {
170 		boolean[] value;
171 		int treeIndex;
172 
173 		public Letter(boolean[] value, int treeIndex) {
174 			this.value = value;
175 			this.treeIndex = treeIndex;
176 		}
177 
178 		@Override
179 		public int hashCode() {
180 			int result = HashCodeUtil.SEED;
181 			result = HashCodeUtil.hash(result, this.value);
182 			result = HashCodeUtil.hash(result, this.treeIndex);
183 			return result;
184 		}
185 
186 		public int getTreeIndex() {
187 			return this.treeIndex;
188 		}
189 
190 		@Override
191 		public boolean equals(Object obj) {
192 			if (!(obj instanceof Letter))
193 				return false;
194 			final Letter that = (Letter) obj;
195 			boolean same = true;
196 			for (int i = 0; i < this.value.length; i++) {
197 				same &= this.value[i] == that.value[i];
198 				if (!same)
199 					return false;
200 			}
201 			same &= this.treeIndex == that.treeIndex;
202 			return same;
203 		}
204 
205 		@Override
206 		public String toString() {
207 			String outString = "";
208 			for (int i = 0; i < this.value.length; i++) {
209 				outString += "-" + (value[i] ? 1 : 0);
210 			}
211 			return outString.substring(1) + "-" + this.treeIndex;
212 		}
213 
214 		public int hashedLetter() {
215 			if (!letterToInt.containsKey(this)) {
216 				letterToInt.put(this, currentInt++);
217 				if (currentInt == Integer.MAX_VALUE) {
218 					System.err.println("... too many letters!");
219 					currentInt = 0;
220 				}
221 			}
222 			return letterToInt.get(this);
223 		}
224 	}
225 
226 	/**
227 	 * Makes a default random forest with 32 trees each with 32 decisions. This
228 	 * results in 10^47 potential words.
229 	 */
230 	public IntRandomForest() {
231 		this(32, 32);
232 	}
233 
234 	/**
235 	 * Makes a random forest with nTrees each with nDecisions. This will result
236 	 * in nTrees ^ nDecisions potential words
237 	 * 
238 	 * @param nTrees
239 	 *            number of trees
240 	 * @param nDecisions
241 	 *            number of decisions per tree
242 	 */
243 	public IntRandomForest(int nTrees, int nDecisions) {
244 		this.nTrees = nTrees;
245 		this.nDecisions = nDecisions;
246 		this.letterToInt = new HashMap<Letter, Integer>();
247 		this.wordToInt = new HashMap<Word, Integer>();
248 	}
249 
250 	private void initMinMax(int[][] data) {
251 		final int[] min = new int[this.featureLength];
252 		final int[] max = new int[this.featureLength];
253 		boolean isSet = false;
254 		for (int i = 0; i < data.length; i++) {
255 			for (int j = 0; j < this.featureLength; j++) {
256 				final int val = data[i][j];
257 				if (!isSet) {
258 					min[j] = val;
259 					max[j] = val;
260 				} else {
261 					if (max[j] < val)
262 						max[j] = val;
263 					else if (min[j] > val)
264 						min[j] = val;
265 				}
266 			}
267 			isSet = true;
268 		}
269 		setMinMax(min, max);
270 	}
271 
272 	/**
273 	 * The maximum and minimum values for the various dimentions against which
274 	 * random decisions will be based.
275 	 * 
276 	 * @param min
277 	 * @param max
278 	 */
279 	public void setMinMax(int[] min, int[] max) {
280 		this.minVal = min;
281 		this.maxVal = max;
282 
283 	}
284 
285 	private void initTrees() {
286 		this.trees = new LinkedList<RandomDecisionTree>();
287 		Random r = new Random();
288 		if (this.randomSeed != -1)
289 			r = new Random(this.randomSeed);
290 		for (int i = 0; i < nTrees; i++) {
291 			final RandomDecisionTree tree = new RandomDecisionTree(this.nDecisions, this.featureLength, this.minVal,
292 					this.maxVal, r);
293 			this.trees.add(tree);
294 		}
295 	}
296 
297 	@Override
298 	public IntRandomForest cluster(int[][] data) {
299 		this.featureLength = data[0].length;
300 
301 		initMinMax(data);
302 		initTrees();
303 
304 		return this;
305 	}
306 
307 	@Override
308 	public IntRandomForest cluster(DataSource<int[]> data) {
309 		final int[][] dataArr = new int[data.size()][data.numDimensions()];
310 
311 		return cluster(dataArr);
312 	}
313 
314 	@Override
315 	public int numClusters() {
316 		return this.currentInt;
317 	}
318 
319 	@Override
320 	public int numDimensions() {
321 		return featureLength;
322 	}
323 
324 	@Override
325 	public int[] assign(int[][] data) {
326 		final int[] proj = new int[data.length];
327 
328 		for (int i = 0; i < data.length; i++) {
329 			proj[i] = this.assign(data[i]);
330 		}
331 		return proj;
332 	}
333 
334 	/**
335 	 * Push each data point provided to a set of letters, i.e. a word. Each
336 	 * letter represents a set of decisions made in a single decision tree.
337 	 * 
338 	 * @param data
339 	 * @return A word per data point
340 	 */
341 	public Word[] assignLetters(int[][] data) {
342 		final Word[] pushedLetters = new Word[data.length];
343 
344 		int i = 0;
345 		for (final int[] k : data) {
346 			pushedLetters[i++] = this.assignWord(k);
347 		}
348 
349 		return pushedLetters;
350 	}
351 
352 	/**
353 	 * Push a single data point to a set of letters, return the letters as word.
354 	 * This is achieved by pushing the data point down each decision tree. This
355 	 * returns the result of the n decisions made for that tree, this is a
356 	 * single letter. As a letter is seen for the first time, it is assigned a
357 	 * number. The letters are combined in sequence to construct a word.
358 	 * 
359 	 * @param data
360 	 *            to be projected
361 	 * @return A single word containing the letters containing the decisions
362 	 *         made on each tree
363 	 */
364 	public Word assignWord(int[] data) {
365 		final Letter[] pushed = new Letter[this.nTrees];
366 		for (int i = 0; i < this.nTrees; i++) {
367 			final boolean[] justLetter = this.trees.get(i).getLetter(data);
368 			final Letter letter = new Letter(justLetter, i);
369 			letter.hashedLetter();
370 			pushed[i] = letter;
371 		}
372 		return new Word(pushed);
373 	}
374 
375 	/**
376 	 * Uses the {@link #assignWord(int[])} function to construct the word
377 	 * representing this data point. If this exact word has been seen before
378 	 * (i.e. these letters in this order) the same int is used. If not, a new
379 	 * int is assigned for this word.
380 	 * 
381 	 * @param data
382 	 *            a data point to be clustered to a word
383 	 * @return a cluster centroid from a word
384 	 */
385 	@Override
386 	public int assign(int[] data) {
387 		final Word word = this.assignWord(data);
388 		return word.hashedWord();
389 	}
390 
391 	/**
392 	 * @return The number of decision trees
393 	 */
394 	public int getNTrees() {
395 		return this.nTrees;
396 	}
397 
398 	/**
399 	 * @return the number of decisions per tree
400 	 */
401 	public int getNDecisions() {
402 		return this.nDecisions;
403 	}
404 
405 	/**
406 	 * @return the decision trees
407 	 */
408 	public List<RandomDecisionTree> getTrees() {
409 		return this.trees;
410 	}
411 
412 	@Override
413 	public boolean equals(Object r) {
414 		if (!(r instanceof IntRandomForest))
415 			return false;
416 
417 		final IntRandomForest that = (IntRandomForest) r;
418 
419 		boolean same = true;
420 
421 		same &= this.numDimensions() == that.numDimensions();
422 		same &= this.getNTrees() == that.getNTrees();
423 		same &= this.getNDecisions() == that.getNDecisions();
424 
425 		for (int i = 0; i < that.trees.size(); i++) {
426 			this.trees.get(i).equals(that.trees.get(i));
427 		}
428 
429 		for (final Entry<Letter, Integer> a : that.letterToInt.entrySet()) {
430 			same &= that.letterToInt.get(a.getKey()).equals(this.letterToInt.get(a.getKey()));
431 		}
432 
433 		for (final Entry<Word, Integer> a : that.wordToInt.entrySet()) {
434 			same &= that.wordToInt.get(a.getKey()).equals(this.wordToInt.get(a.getKey()));
435 		}
436 
437 		return same;
438 	}
439 
440 	@Override
441 	public String asciiHeader() {
442 		return "ASCII" + HEADER;
443 	}
444 
445 	@Override
446 	public byte[] binaryHeader() {
447 		return HEADER.getBytes();
448 	}
449 
450 	@Override
451 	public void readASCII(Scanner br) throws IOException {
452 		nDecisions = Integer.parseInt(br.nextLine());
453 		nTrees = Integer.parseInt(br.nextLine());
454 		this.letterToInt = new HashMap<Letter, Integer>();
455 		featureLength = Integer.parseInt(br.nextLine());
456 
457 		if (this.trees == null || this.trees.size() != nTrees) {
458 			trees = new LinkedList<RandomDecisionTree>();
459 			for (int i = 0; i < nTrees; i++)
460 				trees.add(new RandomDecisionTree().readASCII(br));
461 		} else {
462 			// We have an existing tree, try to read it!
463 			for (final RandomDecisionTree rt : trees) {
464 				rt.readASCII(br);
465 			}
466 		}
467 
468 		// Only rebuild an array of the wrong size
469 		String[] line = br.nextLine().split(" ");
470 		if (maxVal == null || maxVal.length != featureLength)
471 			maxVal = new int[featureLength];
472 		for (int i = 0; i < featureLength; i++)
473 			maxVal[i] = Integer.parseInt(line[i]);
474 		if (minVal == null || minVal.length != featureLength)
475 			minVal = new int[featureLength];
476 		line = br.nextLine().split(" ");
477 		for (int i = 0; i < featureLength; i++)
478 			minVal[i] = Integer.parseInt(line[i]);
479 		currentInt = Integer.parseInt(br.nextLine());
480 
481 		line = br.nextLine().split(" ");
482 		assert ((line.length - 1) == currentInt);
483 		for (int i = 0; i < currentInt; i++) {
484 			final String[] part = line[i].split(",");
485 
486 			letterToInt.put(letterFromString(part[0]), Integer.parseInt(part[1]));
487 		}
488 		currentWordInt = Integer.parseInt(br.nextLine());
489 		if (currentWordInt != 0) {
490 			line = br.nextLine().split(" ");
491 			assert ((line.length - 1) == currentWordInt);
492 			for (int i = 0; i < currentWordInt; i++) {
493 				final String[] part = line[i].split(",");
494 
495 				wordToInt.put(wordFromString(part[0]), Integer.parseInt(part[1]));
496 			}
497 		}
498 	}
499 
500 	@Override
501 	public void readBinary(DataInput dis) throws IOException {
502 		nDecisions = dis.readInt();
503 		nTrees = dis.readInt();
504 		this.letterToInt = new HashMap<Letter, Integer>();
505 		featureLength = dis.readInt();
506 		if (this.trees == null || this.trees.size() != nTrees) {
507 			trees = new LinkedList<RandomDecisionTree>();
508 			for (int i = 0; i < nTrees; i++)
509 				trees.add(new RandomDecisionTree().readBinary(dis));
510 		} else {
511 			// We have an existing tree, try to read it!
512 			for (final RandomDecisionTree rt : trees) {
513 				rt.readBinary(dis);
514 			}
515 		}
516 
517 		if (maxVal == null || maxVal.length != featureLength)
518 			maxVal = new int[featureLength];
519 		for (int i = 0; i < featureLength; i++)
520 			maxVal[i] = dis.readInt();
521 
522 		if (minVal == null || minVal.length != featureLength)
523 			minVal = new int[featureLength];
524 		for (int i = 0; i < featureLength; i++)
525 			minVal[i] = dis.readInt();
526 
527 		currentInt = dis.readInt();
528 		for (int i = 0; i < currentInt; i++) {
529 			final int letterLen = dis.readInt();
530 			final boolean[] stringBytes = new boolean[letterLen];
531 			// dis.read(stringBytes, 0, stringBytes .length); // Entry key
532 			for (int j = 0; j < letterLen; j++)
533 				stringBytes[j] = dis.readBoolean();
534 			letterToInt.put(new Letter(stringBytes, dis.readInt()), dis.readInt());
535 		}
536 
537 		currentWordInt = dis.readInt();
538 		if (currentWordInt != 0) {
539 			for (int i = 0; i < currentWordInt; i++) {
540 				final Letter[] letters = new Letter[dis.readInt()];
541 				// dis.read(stringBytes, 0, stringBytes .length); // Entry key
542 				for (int j = 0; j < letters.length; j++) {
543 					final int letterLen = dis.readInt();
544 					final boolean[] stringBytes = new boolean[letterLen];
545 					for (int k = 0; k < stringBytes.length; k++) {
546 						stringBytes[k] = dis.readBoolean();
547 					}
548 					letters[j] = new Letter(stringBytes, dis.readInt());
549 				}
550 				wordToInt.put(new Word(letters), dis.readInt());
551 			}
552 		}
553 	}
554 
555 	@Override
556 	public void writeASCII(PrintWriter writer) throws IOException {
557 		writer.println(this.nDecisions);
558 		writer.println(this.nTrees);
559 		writer.println(this.featureLength);
560 		for (final RandomDecisionTree tree : trees) {
561 			tree.writeASCII(writer);
562 			writer.println();
563 		}
564 		for (int i = 0; i < maxVal.length; i++)
565 			writer.print(maxVal[i] + " ");
566 		writer.println();
567 		for (int i = 0; i < minVal.length; i++)
568 			writer.print(minVal[i] + " ");
569 		writer.println();
570 		writer.println(currentInt);
571 		for (final Entry<Letter, Integer> p : letterToInt.entrySet()) {
572 			writer.print(p.getKey() + "," + p.getValue() + " ");
573 		}
574 		writer.println();
575 		writer.println(currentWordInt);
576 		for (final Entry<Word, Integer> p : wordToInt.entrySet()) {
577 			writer.print(p.getKey() + "," + p.getValue() + " ");
578 		}
579 		writer.println();
580 		writer.flush();
581 	}
582 
583 	@Override
584 	public void writeBinary(DataOutput o) throws IOException {
585 		o.writeInt(nDecisions);
586 		o.writeInt(nTrees);
587 		o.writeInt(featureLength);
588 		for (final RandomDecisionTree tree : trees) {
589 			tree.write(o);
590 		}
591 
592 		for (int i = 0; i < maxVal.length; i++)
593 			o.writeInt(maxVal[i]);
594 		for (int i = 0; i < minVal.length; i++)
595 			o.writeInt(minVal[i]);
596 
597 		o.writeInt(currentInt);
598 		for (final Entry<Letter, Integer> p : letterToInt.entrySet()) {
599 			o.writeInt(p.getKey().value.length);
600 			for (final boolean i : p.getKey().value)
601 				o.writeBoolean(i);
602 			o.writeInt(p.getKey().treeIndex);
603 			o.writeInt(p.getValue());
604 		}
605 		o.writeInt(currentWordInt);
606 		for (final Entry<Word, Integer> p : wordToInt.entrySet()) {
607 			o.writeInt(p.getKey().letters.length);
608 			for (final Letter i : p.getKey().letters) {
609 				o.writeInt(i.value.length);
610 				for (final boolean j : i.value)
611 					o.writeBoolean(j);
612 				o.writeInt(i.treeIndex);
613 			}
614 			o.writeInt(p.getValue());
615 		}
616 	}
617 
618 	/**
619 	 * @param random
620 	 *            the seed of the java {@link Random} instance used by the
621 	 *            decision trees
622 	 */
623 	public void setRandomSeed(int random) {
624 		this.randomSeed = random;
625 	}
626 
627 	Letter newLetter(boolean[] bs, int i) {
628 		return new Letter(bs, i);
629 	}
630 
631 	Word newWord(Letter[] letters) {
632 		return new Word(letters);
633 	}
634 
635 	@Override
636 	public void assignDistance(int[][] data, int[] indices, float[] distances) {
637 		throw new UnsupportedOperationException("Not implemented");
638 	}
639 
640 	@Override
641 	public IntFloatPair assignDistance(int[] data) {
642 		throw new UnsupportedOperationException("Not implemented");
643 	}
644 
645 	@Override
646 	public HardAssigner<int[], ?, ?> defaultHardAssigner() {
647 		return this;
648 	}
649 
650 	@Override
651 	public int size() {
652 		return this.currentInt;
653 	}
654 
655 	@Override
656 	public int[][] performClustering(int[][] data) {
657 		return new IndexClusters(this.cluster(data).defaultHardAssigner().assign(data)).clusters();
658 	}
659 }