1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package org.openimaj.ml.clustering.rforest;
31
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Scanner;

import org.openimaj.citation.annotation.Reference;
import org.openimaj.citation.annotation.ReferenceType;
import org.openimaj.data.DataSource;
import org.openimaj.ml.clustering.IndexClusters;
import org.openimaj.ml.clustering.SpatialClusterer;
import org.openimaj.ml.clustering.SpatialClusters;
import org.openimaj.ml.clustering.assignment.HardAssigner;
import org.openimaj.util.hash.HashCodeUtil;
import org.openimaj.util.pair.IntFloatPair;
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69 @Reference(
70 type = ReferenceType.Article,
71 author = { "Frank Moosmann", "Eric Nowak", "Fr{\'e}d{\'e}ric Jurie" },
72 title = "Randomized Clustering Forests for Image Classification",
73 year = "2008",
74 journal = "IEEE PAMI",
75 url = "http://dx.doi.org/10.1109/TPAMI.2007.70822")
76 public class IntRandomForest
77 implements
78 SpatialClusters<int[]>,
79 SpatialClusterer<IntRandomForest, int[]>,
80 HardAssigner<int[], float[], IntFloatPair>
81 {
82 private static final String HEADER = SpatialClusters.CLUSTER_HEADER + "RFIC";
83 int nDecisions;
84 int nTrees;
85 int featureLength;
86 List<RandomDecisionTree> trees;
87 int[] maxVal;
88 int[] minVal;
89 Map<Letter, Integer> letterToInt;
90 private int currentInt = 0;
91 private HashMap<Word, Integer> wordToInt;
92 private int currentWordInt = 0;
93 private int randomSeed = -1;
94
95 private Word wordFromString(String str) {
96 final String[] values = str.split("_");
97 final Letter[] intValues = new Letter[values.length];
98 int i = 0;
99 for (final String s : values) {
100 intValues[i++] = letterFromString(s);
101 }
102 return new Word(intValues);
103 }
104
105 private Letter letterFromString(String str) {
106 final String[] values = str.split("-");
107 final boolean[] intValues = new boolean[values.length];
108 int i = 0;
109 for (int j = 0; j < values.length - 1; j++) {
110 intValues[i++] = Boolean.parseBoolean(values[j]);
111 }
112 return new Letter(intValues, Integer.parseInt(values[values.length - 1]));
113 }
114
115 class Word {
116 private Letter[] letters;
117
118 Word(Letter[] value) {
119 letters = value;
120 }
121
122 @Override
123 public int hashCode() {
124 int result = HashCodeUtil.SEED;
125 for (final Letter l : letters) {
126 result = HashCodeUtil.hash(result, l);
127 }
128 return result;
129 }
130
131 @Override
132 public boolean equals(Object obj) {
133 if (!(obj instanceof Word))
134 return false;
135
136 final Word that = (Word) obj;
137 boolean same = true;
138 for (int i = 0; i < letters.length; i++) {
139 same &= letters[i].equals(that.letters[i]);
140
141 if (!same)
142 return false;
143 }
144
145 return same;
146 }
147
148 @Override
149 public String toString() {
150 String outString = "";
151 for (int i = 0; i < this.letters.length; i++) {
152 outString += "_" + letters[i];
153 }
154 return outString.substring(1);
155 }
156
157 public int hashedWord() {
158 if (!wordToInt.containsKey(this)) {
159 wordToInt.put(this, currentWordInt++);
160 if (currentWordInt == Integer.MAX_VALUE) {
161 System.err.println("Too many words!");
162 currentWordInt = 0;
163 }
164 }
165 return wordToInt.get(this);
166 }
167 }
168
169 class Letter {
170 boolean[] value;
171 int treeIndex;
172
173 public Letter(boolean[] value, int treeIndex) {
174 this.value = value;
175 this.treeIndex = treeIndex;
176 }
177
178 @Override
179 public int hashCode() {
180 int result = HashCodeUtil.SEED;
181 result = HashCodeUtil.hash(result, this.value);
182 result = HashCodeUtil.hash(result, this.treeIndex);
183 return result;
184 }
185
186 public int getTreeIndex() {
187 return this.treeIndex;
188 }
189
190 @Override
191 public boolean equals(Object obj) {
192 if (!(obj instanceof Letter))
193 return false;
194 final Letter that = (Letter) obj;
195 boolean same = true;
196 for (int i = 0; i < this.value.length; i++) {
197 same &= this.value[i] == that.value[i];
198 if (!same)
199 return false;
200 }
201 same &= this.treeIndex == that.treeIndex;
202 return same;
203 }
204
205 @Override
206 public String toString() {
207 String outString = "";
208 for (int i = 0; i < this.value.length; i++) {
209 outString += "-" + (value[i] ? 1 : 0);
210 }
211 return outString.substring(1) + "-" + this.treeIndex;
212 }
213
214 public int hashedLetter() {
215 if (!letterToInt.containsKey(this)) {
216 letterToInt.put(this, currentInt++);
217 if (currentInt == Integer.MAX_VALUE) {
218 System.err.println("... too many letters!");
219 currentInt = 0;
220 }
221 }
222 return letterToInt.get(this);
223 }
224 }
225
226
227
228
229
/**
 * Construct a forest with 32 trees and 32 decisions per tree.
 */
public IntRandomForest() {
    this(32, 32);
}
233
234
235
236
237
238
239
240
241
242
/**
 * Construct a forest with the given shape and empty letter and word
 * vocabularies. No trees are created here.
 *
 * @param nTrees
 *            number of trees in the forest
 * @param nDecisions
 *            number of decisions handed to each tree
 */
public IntRandomForest(int nTrees, int nDecisions) {
    this.wordToInt = new HashMap<Word, Integer>();
    this.letterToInt = new HashMap<Letter, Integer>();
    this.nDecisions = nDecisions;
    this.nTrees = nTrees;
}
249
250 private void initMinMax(int[][] data) {
251 final int[] min = new int[this.featureLength];
252 final int[] max = new int[this.featureLength];
253 boolean isSet = false;
254 for (int i = 0; i < data.length; i++) {
255 for (int j = 0; j < this.featureLength; j++) {
256 final int val = data[i][j];
257 if (!isSet) {
258 min[j] = val;
259 max[j] = val;
260 } else {
261 if (max[j] < val)
262 max[j] = val;
263 else if (min[j] > val)
264 min[j] = val;
265 }
266 }
267 isSet = true;
268 }
269 setMinMax(min, max);
270 }
271
272
273
274
275
276
277
278
279 public void setMinMax(int[] min, int[] max) {
280 this.minVal = min;
281 this.maxVal = max;
282
283 }
284
285 private void initTrees() {
286 this.trees = new LinkedList<RandomDecisionTree>();
287 Random r = new Random();
288 if (this.randomSeed != -1)
289 r = new Random(this.randomSeed);
290 for (int i = 0; i < nTrees; i++) {
291 final RandomDecisionTree tree = new RandomDecisionTree(this.nDecisions, this.featureLength, this.minVal,
292 this.maxVal, r);
293 this.trees.add(tree);
294 }
295 }
296
297 @Override
298 public IntRandomForest cluster(int[][] data) {
299 this.featureLength = data[0].length;
300
301 initMinMax(data);
302 initTrees();
303
304 return this;
305 }
306
307 @Override
308 public IntRandomForest cluster(DataSource<int[]> data) {
309 final int[][] dataArr = new int[data.size()][data.numDimensions()];
310
311 return cluster(dataArr);
312 }
313
314 @Override
315 public int numClusters() {
316 return this.currentInt;
317 }
318
319 @Override
320 public int numDimensions() {
321 return featureLength;
322 }
323
324 @Override
325 public int[] assign(int[][] data) {
326 final int[] proj = new int[data.length];
327
328 for (int i = 0; i < data.length; i++) {
329 proj[i] = this.assign(data[i]);
330 }
331 return proj;
332 }
333
334
335
336
337
338
339
340
341 public Word[] assignLetters(int[][] data) {
342 final Word[] pushedLetters = new Word[data.length];
343
344 int i = 0;
345 for (final int[] k : data) {
346 pushedLetters[i++] = this.assignWord(k);
347 }
348
349 return pushedLetters;
350 }
351
352
353
354
355
356
357
358
359
360
361
362
363
364 public Word assignWord(int[] data) {
365 final Letter[] pushed = new Letter[this.nTrees];
366 for (int i = 0; i < this.nTrees; i++) {
367 final boolean[] justLetter = this.trees.get(i).getLetter(data);
368 final Letter letter = new Letter(justLetter, i);
369 letter.hashedLetter();
370 pushed[i] = letter;
371 }
372 return new Word(pushed);
373 }
374
375
376
377
378
379
380
381
382
383
384
385 @Override
386 public int assign(int[] data) {
387 final Word word = this.assignWord(data);
388 return word.hashedWord();
389 }
390
391
392
393
394 public int getNTrees() {
395 return this.nTrees;
396 }
397
398
399
400
401 public int getNDecisions() {
402 return this.nDecisions;
403 }
404
405
406
407
408 public List<RandomDecisionTree> getTrees() {
409 return this.trees;
410 }
411
412 @Override
413 public boolean equals(Object r) {
414 if (!(r instanceof IntRandomForest))
415 return false;
416
417 final IntRandomForest that = (IntRandomForest) r;
418
419 boolean same = true;
420
421 same &= this.numDimensions() == that.numDimensions();
422 same &= this.getNTrees() == that.getNTrees();
423 same &= this.getNDecisions() == that.getNDecisions();
424
425 for (int i = 0; i < that.trees.size(); i++) {
426 this.trees.get(i).equals(that.trees.get(i));
427 }
428
429 for (final Entry<Letter, Integer> a : that.letterToInt.entrySet()) {
430 same &= that.letterToInt.get(a.getKey()).equals(this.letterToInt.get(a.getKey()));
431 }
432
433 for (final Entry<Word, Integer> a : that.wordToInt.entrySet()) {
434 same &= that.wordToInt.get(a.getKey()).equals(this.wordToInt.get(a.getKey()));
435 }
436
437 return same;
438 }
439
440 @Override
441 public String asciiHeader() {
442 return "ASCII" + HEADER;
443 }
444
445 @Override
446 public byte[] binaryHeader() {
447 return HEADER.getBytes();
448 }
449
450 @Override
451 public void readASCII(Scanner br) throws IOException {
452 nDecisions = Integer.parseInt(br.nextLine());
453 nTrees = Integer.parseInt(br.nextLine());
454 this.letterToInt = new HashMap<Letter, Integer>();
455 featureLength = Integer.parseInt(br.nextLine());
456
457 if (this.trees == null || this.trees.size() != nTrees) {
458 trees = new LinkedList<RandomDecisionTree>();
459 for (int i = 0; i < nTrees; i++)
460 trees.add(new RandomDecisionTree().readASCII(br));
461 } else {
462
463 for (final RandomDecisionTree rt : trees) {
464 rt.readASCII(br);
465 }
466 }
467
468
469 String[] line = br.nextLine().split(" ");
470 if (maxVal == null || maxVal.length != featureLength)
471 maxVal = new int[featureLength];
472 for (int i = 0; i < featureLength; i++)
473 maxVal[i] = Integer.parseInt(line[i]);
474 if (minVal == null || minVal.length != featureLength)
475 minVal = new int[featureLength];
476 line = br.nextLine().split(" ");
477 for (int i = 0; i < featureLength; i++)
478 minVal[i] = Integer.parseInt(line[i]);
479 currentInt = Integer.parseInt(br.nextLine());
480
481 line = br.nextLine().split(" ");
482 assert ((line.length - 1) == currentInt);
483 for (int i = 0; i < currentInt; i++) {
484 final String[] part = line[i].split(",");
485
486 letterToInt.put(letterFromString(part[0]), Integer.parseInt(part[1]));
487 }
488 currentWordInt = Integer.parseInt(br.nextLine());
489 if (currentWordInt != 0) {
490 line = br.nextLine().split(" ");
491 assert ((line.length - 1) == currentWordInt);
492 for (int i = 0; i < currentWordInt; i++) {
493 final String[] part = line[i].split(",");
494
495 wordToInt.put(wordFromString(part[0]), Integer.parseInt(part[1]));
496 }
497 }
498 }
499
500 @Override
501 public void readBinary(DataInput dis) throws IOException {
502 nDecisions = dis.readInt();
503 nTrees = dis.readInt();
504 this.letterToInt = new HashMap<Letter, Integer>();
505 featureLength = dis.readInt();
506 if (this.trees == null || this.trees.size() != nTrees) {
507 trees = new LinkedList<RandomDecisionTree>();
508 for (int i = 0; i < nTrees; i++)
509 trees.add(new RandomDecisionTree().readBinary(dis));
510 } else {
511
512 for (final RandomDecisionTree rt : trees) {
513 rt.readBinary(dis);
514 }
515 }
516
517 if (maxVal == null || maxVal.length != featureLength)
518 maxVal = new int[featureLength];
519 for (int i = 0; i < featureLength; i++)
520 maxVal[i] = dis.readInt();
521
522 if (minVal == null || minVal.length != featureLength)
523 minVal = new int[featureLength];
524 for (int i = 0; i < featureLength; i++)
525 minVal[i] = dis.readInt();
526
527 currentInt = dis.readInt();
528 for (int i = 0; i < currentInt; i++) {
529 final int letterLen = dis.readInt();
530 final boolean[] stringBytes = new boolean[letterLen];
531
532 for (int j = 0; j < letterLen; j++)
533 stringBytes[j] = dis.readBoolean();
534 letterToInt.put(new Letter(stringBytes, dis.readInt()), dis.readInt());
535 }
536
537 currentWordInt = dis.readInt();
538 if (currentWordInt != 0) {
539 for (int i = 0; i < currentWordInt; i++) {
540 final Letter[] letters = new Letter[dis.readInt()];
541
542 for (int j = 0; j < letters.length; j++) {
543 final int letterLen = dis.readInt();
544 final boolean[] stringBytes = new boolean[letterLen];
545 for (int k = 0; k < stringBytes.length; k++) {
546 stringBytes[k] = dis.readBoolean();
547 }
548 letters[j] = new Letter(stringBytes, dis.readInt());
549 }
550 wordToInt.put(new Word(letters), dis.readInt());
551 }
552 }
553 }
554
555 @Override
556 public void writeASCII(PrintWriter writer) throws IOException {
557 writer.println(this.nDecisions);
558 writer.println(this.nTrees);
559 writer.println(this.featureLength);
560 for (final RandomDecisionTree tree : trees) {
561 tree.writeASCII(writer);
562 writer.println();
563 }
564 for (int i = 0; i < maxVal.length; i++)
565 writer.print(maxVal[i] + " ");
566 writer.println();
567 for (int i = 0; i < minVal.length; i++)
568 writer.print(minVal[i] + " ");
569 writer.println();
570 writer.println(currentInt);
571 for (final Entry<Letter, Integer> p : letterToInt.entrySet()) {
572 writer.print(p.getKey() + "," + p.getValue() + " ");
573 }
574 writer.println();
575 writer.println(currentWordInt);
576 for (final Entry<Word, Integer> p : wordToInt.entrySet()) {
577 writer.print(p.getKey() + "," + p.getValue() + " ");
578 }
579 writer.println();
580 writer.flush();
581 }
582
583 @Override
584 public void writeBinary(DataOutput o) throws IOException {
585 o.writeInt(nDecisions);
586 o.writeInt(nTrees);
587 o.writeInt(featureLength);
588 for (final RandomDecisionTree tree : trees) {
589 tree.write(o);
590 }
591
592 for (int i = 0; i < maxVal.length; i++)
593 o.writeInt(maxVal[i]);
594 for (int i = 0; i < minVal.length; i++)
595 o.writeInt(minVal[i]);
596
597 o.writeInt(currentInt);
598 for (final Entry<Letter, Integer> p : letterToInt.entrySet()) {
599 o.writeInt(p.getKey().value.length);
600 for (final boolean i : p.getKey().value)
601 o.writeBoolean(i);
602 o.writeInt(p.getKey().treeIndex);
603 o.writeInt(p.getValue());
604 }
605 o.writeInt(currentWordInt);
606 for (final Entry<Word, Integer> p : wordToInt.entrySet()) {
607 o.writeInt(p.getKey().letters.length);
608 for (final Letter i : p.getKey().letters) {
609 o.writeInt(i.value.length);
610 for (final boolean j : i.value)
611 o.writeBoolean(j);
612 o.writeInt(i.treeIndex);
613 }
614 o.writeInt(p.getValue());
615 }
616 }
617
618
619
620
621
622
623 public void setRandomSeed(int random) {
624 this.randomSeed = random;
625 }
626
627 Letter newLetter(boolean[] bs, int i) {
628 return new Letter(bs, i);
629 }
630
631 Word newWord(Letter[] letters) {
632 return new Word(letters);
633 }
634
635 @Override
636 public void assignDistance(int[][] data, int[] indices, float[] distances) {
637 throw new UnsupportedOperationException("Not implemented");
638 }
639
640 @Override
641 public IntFloatPair assignDistance(int[] data) {
642 throw new UnsupportedOperationException("Not implemented");
643 }
644
645 @Override
646 public HardAssigner<int[], ?, ?> defaultHardAssigner() {
647 return this;
648 }
649
650 @Override
651 public int size() {
652 return this.currentInt;
653 }
654
655 @Override
656 public int[][] performClustering(int[][] data) {
657 return new IndexClusters(this.cluster(data).defaultHardAssigner().assign(data)).clusters();
658 }
659 }