001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.clustering.rforest;
031
032import java.io.DataInput;
033import java.io.DataOutput;
034import java.io.IOException;
035import java.io.PrintWriter;
036import java.util.HashMap;
037import java.util.LinkedList;
038import java.util.List;
039import java.util.Map;
040import java.util.Map.Entry;
041import java.util.Random;
042import java.util.Scanner;
043
044import org.openimaj.citation.annotation.Reference;
045import org.openimaj.citation.annotation.ReferenceType;
046import org.openimaj.data.DataSource;
047import org.openimaj.ml.clustering.IndexClusters;
048import org.openimaj.ml.clustering.SpatialClusterer;
049import org.openimaj.ml.clustering.SpatialClusters;
050import org.openimaj.ml.clustering.assignment.HardAssigner;
051import org.openimaj.util.hash.HashCodeUtil;
052import org.openimaj.util.pair.IntFloatPair;
053
054/**
055 * An implementation of the RandomForest clustering algorithm proposed by <a
056 * href
057 * ="http://users.info.unicaen.fr/~jurie/papers/moosman-nowak-jurie-pami08.pdf"
058 * >Jurie et al</a>.
059 * <p>
060 * In this implementation the training phase is used to identify the limits of
061 * the data (for which a very small subset may be provided). Once this is known
062 * N decision trees are constructed each with M decisions (see
063 * {@link RandomDecisionTree}). In the clustering phase each feature projected
064 * is assigned a letter for each decision tree.
065 * 
066 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
067 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
068 */
069@Reference(
070                type = ReferenceType.Article,
071                author = { "Frank Moosmann", "Eric Nowak", "Fr{\'e}d{\'e}ric Jurie" },
072                title = "Randomized Clustering Forests for Image Classification",
073                year = "2008",
074                journal = "IEEE PAMI",
075                url = "http://dx.doi.org/10.1109/TPAMI.2007.70822")
076public class IntRandomForest
077                implements
078                SpatialClusters<int[]>,
079                SpatialClusterer<IntRandomForest, int[]>,
080                HardAssigner<int[], float[], IntFloatPair>
081{
        /** Header tag for serialised forests: the shared cluster header followed by "RFIC". */
        private static final String HEADER = SpatialClusters.CLUSTER_HEADER + "RFIC";
        /** Number of decisions made by each tree. */
        int nDecisions;
        /** Number of trees in the forest. */
        int nTrees;
        /** Length of the feature vectors; taken from the first training row or read from a stream. */
        int featureLength;
        /** The decision trees; not created until the forest is trained or read from a stream. */
        List<RandomDecisionTree> trees;
        /** Per-dimension maximum handed to the trees when they are built. */
        int[] maxVal;
        /** Per-dimension minimum handed to the trees when they are built. */
        int[] minVal;
        /** Identifier assigned to each distinct letter seen so far. */
        Map<Letter, Integer> letterToInt;
        /** Next identifier to hand out for a previously unseen letter. */
        private int currentInt = 0;
        /** Identifier assigned to each distinct word seen so far. */
        private HashMap<Word, Integer> wordToInt;
        /** Next identifier to hand out for a previously unseen word. */
        private int currentWordInt = 0;
        /** Seed for the trees' random generator; -1 means "use an unseeded generator". */
        private int randomSeed = -1;
094
095        private Word wordFromString(String str) {
096                final String[] values = str.split("_");
097                final Letter[] intValues = new Letter[values.length];
098                int i = 0;
099                for (final String s : values) {
100                        intValues[i++] = letterFromString(s);
101                }
102                return new Word(intValues);
103        }
104
105        private Letter letterFromString(String str) {
106                final String[] values = str.split("-");
107                final boolean[] intValues = new boolean[values.length];
108                int i = 0;
109                for (int j = 0; j < values.length - 1; j++) {
110                        intValues[i++] = Boolean.parseBoolean(values[j]);
111                }
112                return new Letter(intValues, Integer.parseInt(values[values.length - 1]));
113        }
114
115        class Word {
116                private Letter[] letters;
117
118                Word(Letter[] value) {
119                        letters = value;
120                }
121
122                @Override
123                public int hashCode() {
124                        int result = HashCodeUtil.SEED;
125                        for (final Letter l : letters) {
126                                result = HashCodeUtil.hash(result, l);
127                        }
128                        return result;
129                }
130
131                @Override
132                public boolean equals(Object obj) {
133                        if (!(obj instanceof Word))
134                                return false;
135
136                        final Word that = (Word) obj;
137                        boolean same = true;
138                        for (int i = 0; i < letters.length; i++) {
139                                same &= letters[i].equals(that.letters[i]);
140
141                                if (!same)
142                                        return false;
143                        }
144
145                        return same;
146                }
147
148                @Override
149                public String toString() {
150                        String outString = "";
151                        for (int i = 0; i < this.letters.length; i++) {
152                                outString += "_" + letters[i];
153                        }
154                        return outString.substring(1);
155                }
156
157                public int hashedWord() {
158                        if (!wordToInt.containsKey(this)) {
159                                wordToInt.put(this, currentWordInt++);
160                                if (currentWordInt == Integer.MAX_VALUE) {
161                                        System.err.println("Too many words!");
162                                        currentWordInt = 0;
163                                }
164                        }
165                        return wordToInt.get(this);
166                }
167        }
168
169        class Letter {
170                boolean[] value;
171                int treeIndex;
172
173                public Letter(boolean[] value, int treeIndex) {
174                        this.value = value;
175                        this.treeIndex = treeIndex;
176                }
177
178                @Override
179                public int hashCode() {
180                        int result = HashCodeUtil.SEED;
181                        result = HashCodeUtil.hash(result, this.value);
182                        result = HashCodeUtil.hash(result, this.treeIndex);
183                        return result;
184                }
185
186                public int getTreeIndex() {
187                        return this.treeIndex;
188                }
189
190                @Override
191                public boolean equals(Object obj) {
192                        if (!(obj instanceof Letter))
193                                return false;
194                        final Letter that = (Letter) obj;
195                        boolean same = true;
196                        for (int i = 0; i < this.value.length; i++) {
197                                same &= this.value[i] == that.value[i];
198                                if (!same)
199                                        return false;
200                        }
201                        same &= this.treeIndex == that.treeIndex;
202                        return same;
203                }
204
205                @Override
206                public String toString() {
207                        String outString = "";
208                        for (int i = 0; i < this.value.length; i++) {
209                                outString += "-" + (value[i] ? 1 : 0);
210                        }
211                        return outString.substring(1) + "-" + this.treeIndex;
212                }
213
214                public int hashedLetter() {
215                        if (!letterToInt.containsKey(this)) {
216                                letterToInt.put(this, currentInt++);
217                                if (currentInt == Integer.MAX_VALUE) {
218                                        System.err.println("... too many letters!");
219                                        currentInt = 0;
220                                }
221                        }
222                        return letterToInt.get(this);
223                }
224        }
225
        /**
         * Makes a default random forest with 32 trees, each with 32 decisions.
         * Delegates to {@link #IntRandomForest(int, int)}.
         */
        public IntRandomForest() {
                this(32, 32);
        }
233
        /**
         * Makes an untrained random forest with nTrees trees, each with nDecisions
         * decisions. The letter and word tables start out empty; the trees
         * themselves are only built once the forest is trained.
         * <p>
         * NOTE(review): the previous documentation stated that this results in
         * nTrees ^ nDecisions potential words; that figure has not been verified
         * against the tree implementation.
         * 
         * @param nTrees
         *            number of trees
         * @param nDecisions
         *            number of decisions per tree
         */
        public IntRandomForest(int nTrees, int nDecisions) {
                this.nTrees = nTrees;
                this.nDecisions = nDecisions;
                this.letterToInt = new HashMap<Letter, Integer>();
                this.wordToInt = new HashMap<Word, Integer>();
        }
249
250        private void initMinMax(int[][] data) {
251                final int[] min = new int[this.featureLength];
252                final int[] max = new int[this.featureLength];
253                boolean isSet = false;
254                for (int i = 0; i < data.length; i++) {
255                        for (int j = 0; j < this.featureLength; j++) {
256                                final int val = data[i][j];
257                                if (!isSet) {
258                                        min[j] = val;
259                                        max[j] = val;
260                                } else {
261                                        if (max[j] < val)
262                                                max[j] = val;
263                                        else if (min[j] > val)
264                                                min[j] = val;
265                                }
266                        }
267                        isSet = true;
268                }
269                setMinMax(min, max);
270        }
271
        /**
         * Sets the minimum and maximum values of the various dimensions, against
         * which the random decisions will be based. The arrays are stored as
         * given (no copy is made).
         * 
         * @param min
         *            per-dimension minimum
         * @param max
         *            per-dimension maximum
         */
        public void setMinMax(int[] min, int[] max) {
                this.minVal = min;
                this.maxVal = max;

        }
284
285        private void initTrees() {
286                this.trees = new LinkedList<RandomDecisionTree>();
287                Random r = new Random();
288                if (this.randomSeed != -1)
289                        r = new Random(this.randomSeed);
290                for (int i = 0; i < nTrees; i++) {
291                        final RandomDecisionTree tree = new RandomDecisionTree(this.nDecisions, this.featureLength, this.minVal,
292                                        this.maxVal, r);
293                        this.trees.add(tree);
294                }
295        }
296
        /**
         * Trains the forest on the given data: the feature length is taken from
         * the first row, the per-dimension minimum and maximum are computed, and
         * the decision trees are then rebuilt.
         *
         * @param data
         *            the training samples; the first row is read unconditionally,
         *            so the array must not be empty
         * @return this forest
         */
        @Override
        public IntRandomForest cluster(int[][] data) {
                this.featureLength = data[0].length;

                initMinMax(data);
                initTrees();

                return this;
        }
306
307        @Override
308        public IntRandomForest cluster(DataSource<int[]> data) {
309                final int[][] dataArr = new int[data.size()][data.numDimensions()];
310
311                return cluster(dataArr);
312        }
313
        /**
         * Returns the current value of the letter counter.
         * <p>
         * NOTE(review): {@link #assign(int[])} returns word identifiers, which are
         * counted by a separate word counter; confirm that the letter counter is
         * the intended cluster count.
         *
         * @return the current letter counter
         */
        @Override
        public int numClusters() {
                return this.currentInt;
        }
318
        /**
         * @return the feature length this forest was trained with or read from a
         *         stream
         */
        @Override
        public int numDimensions() {
                return featureLength;
        }
323
324        @Override
325        public int[] assign(int[][] data) {
326                final int[] proj = new int[data.length];
327
328                for (int i = 0; i < data.length; i++) {
329                        proj[i] = this.assign(data[i]);
330                }
331                return proj;
332        }
333
334        /**
335         * Push each data point provided to a set of letters, i.e. a word. Each
336         * letter represents a set of decisions made in a single decision tree.
337         * 
338         * @param data
339         * @return A word per data point
340         */
341        public Word[] assignLetters(int[][] data) {
342                final Word[] pushedLetters = new Word[data.length];
343
344                int i = 0;
345                for (final int[] k : data) {
346                        pushedLetters[i++] = this.assignWord(k);
347                }
348
349                return pushedLetters;
350        }
351
352        /**
353         * Push a single data point to a set of letters, return the letters as word.
354         * This is achieved by pushing the data point down each decision tree. This
355         * returns the result of the n decisions made for that tree, this is a
356         * single letter. As a letter is seen for the first time, it is assigned a
357         * number. The letters are combined in sequence to construct a word.
358         * 
359         * @param data
360         *            to be projected
361         * @return A single word containing the letters containing the decisions
362         *         made on each tree
363         */
364        public Word assignWord(int[] data) {
365                final Letter[] pushed = new Letter[this.nTrees];
366                for (int i = 0; i < this.nTrees; i++) {
367                        final boolean[] justLetter = this.trees.get(i).getLetter(data);
368                        final Letter letter = new Letter(justLetter, i);
369                        letter.hashedLetter();
370                        pushed[i] = letter;
371                }
372                return new Word(pushed);
373        }
374
375        /**
376         * Uses the {@link #assignWord(int[])} function to construct the word
377         * representing this data point. If this exact word has been seen before
378         * (i.e. these letters in this order) the same int is used. If not, a new
379         * int is assigned for this word.
380         * 
381         * @param data
382         *            a data point to be clustered to a word
383         * @return a cluster centroid from a word
384         */
385        @Override
386        public int assign(int[] data) {
387                final Word word = this.assignWord(data);
388                return word.hashedWord();
389        }
390
        /**
         * @return the number of decision trees this forest is configured with
         */
        public int getNTrees() {
                return this.nTrees;
        }
397
        /**
         * @return the number of decisions each tree is configured with
         */
        public int getNDecisions() {
                return this.nDecisions;
        }
404
        /**
         * @return the internal list of decision trees (not a copy); the list is
         *         only created when the forest is trained or read from a stream
         */
        public List<RandomDecisionTree> getTrees() {
                return this.trees;
        }
411
412        @Override
413        public boolean equals(Object r) {
414                if (!(r instanceof IntRandomForest))
415                        return false;
416
417                final IntRandomForest that = (IntRandomForest) r;
418
419                boolean same = true;
420
421                same &= this.numDimensions() == that.numDimensions();
422                same &= this.getNTrees() == that.getNTrees();
423                same &= this.getNDecisions() == that.getNDecisions();
424
425                for (int i = 0; i < that.trees.size(); i++) {
426                        this.trees.get(i).equals(that.trees.get(i));
427                }
428
429                for (final Entry<Letter, Integer> a : that.letterToInt.entrySet()) {
430                        same &= that.letterToInt.get(a.getKey()).equals(this.letterToInt.get(a.getKey()));
431                }
432
433                for (final Entry<Word, Integer> a : that.wordToInt.entrySet()) {
434                        same &= that.wordToInt.get(a.getKey()).equals(this.wordToInt.get(a.getKey()));
435                }
436
437                return same;
438        }
439
        /**
         * @return the header for the text serialisation: "ASCII" followed by the
         *         forest's header tag
         */
        @Override
        public String asciiHeader() {
                return "ASCII" + HEADER;
        }
444
        /**
         * @return the forest's header tag as bytes, for the binary serialisation.
         *         NOTE(review): the bytes are produced by {@code getBytes()}
         *         without an explicit charset.
         */
        @Override
        public byte[] binaryHeader() {
                return HEADER.getBytes();
        }
449
450        @Override
451        public void readASCII(Scanner br) throws IOException {
452                nDecisions = Integer.parseInt(br.nextLine());
453                nTrees = Integer.parseInt(br.nextLine());
454                this.letterToInt = new HashMap<Letter, Integer>();
455                featureLength = Integer.parseInt(br.nextLine());
456
457                if (this.trees == null || this.trees.size() != nTrees) {
458                        trees = new LinkedList<RandomDecisionTree>();
459                        for (int i = 0; i < nTrees; i++)
460                                trees.add(new RandomDecisionTree().readASCII(br));
461                } else {
462                        // We have an existing tree, try to read it!
463                        for (final RandomDecisionTree rt : trees) {
464                                rt.readASCII(br);
465                        }
466                }
467
468                // Only rebuild an array of the wrong size
469                String[] line = br.nextLine().split(" ");
470                if (maxVal == null || maxVal.length != featureLength)
471                        maxVal = new int[featureLength];
472                for (int i = 0; i < featureLength; i++)
473                        maxVal[i] = Integer.parseInt(line[i]);
474                if (minVal == null || minVal.length != featureLength)
475                        minVal = new int[featureLength];
476                line = br.nextLine().split(" ");
477                for (int i = 0; i < featureLength; i++)
478                        minVal[i] = Integer.parseInt(line[i]);
479                currentInt = Integer.parseInt(br.nextLine());
480
481                line = br.nextLine().split(" ");
482                assert ((line.length - 1) == currentInt);
483                for (int i = 0; i < currentInt; i++) {
484                        final String[] part = line[i].split(",");
485
486                        letterToInt.put(letterFromString(part[0]), Integer.parseInt(part[1]));
487                }
488                currentWordInt = Integer.parseInt(br.nextLine());
489                if (currentWordInt != 0) {
490                        line = br.nextLine().split(" ");
491                        assert ((line.length - 1) == currentWordInt);
492                        for (int i = 0; i < currentWordInt; i++) {
493                                final String[] part = line[i].split(",");
494
495                                wordToInt.put(wordFromString(part[0]), Integer.parseInt(part[1]));
496                        }
497                }
498        }
499
500        @Override
501        public void readBinary(DataInput dis) throws IOException {
502                nDecisions = dis.readInt();
503                nTrees = dis.readInt();
504                this.letterToInt = new HashMap<Letter, Integer>();
505                featureLength = dis.readInt();
506                if (this.trees == null || this.trees.size() != nTrees) {
507                        trees = new LinkedList<RandomDecisionTree>();
508                        for (int i = 0; i < nTrees; i++)
509                                trees.add(new RandomDecisionTree().readBinary(dis));
510                } else {
511                        // We have an existing tree, try to read it!
512                        for (final RandomDecisionTree rt : trees) {
513                                rt.readBinary(dis);
514                        }
515                }
516
517                if (maxVal == null || maxVal.length != featureLength)
518                        maxVal = new int[featureLength];
519                for (int i = 0; i < featureLength; i++)
520                        maxVal[i] = dis.readInt();
521
522                if (minVal == null || minVal.length != featureLength)
523                        minVal = new int[featureLength];
524                for (int i = 0; i < featureLength; i++)
525                        minVal[i] = dis.readInt();
526
527                currentInt = dis.readInt();
528                for (int i = 0; i < currentInt; i++) {
529                        final int letterLen = dis.readInt();
530                        final boolean[] stringBytes = new boolean[letterLen];
531                        // dis.read(stringBytes, 0, stringBytes .length); // Entry key
532                        for (int j = 0; j < letterLen; j++)
533                                stringBytes[j] = dis.readBoolean();
534                        letterToInt.put(new Letter(stringBytes, dis.readInt()), dis.readInt());
535                }
536
537                currentWordInt = dis.readInt();
538                if (currentWordInt != 0) {
539                        for (int i = 0; i < currentWordInt; i++) {
540                                final Letter[] letters = new Letter[dis.readInt()];
541                                // dis.read(stringBytes, 0, stringBytes .length); // Entry key
542                                for (int j = 0; j < letters.length; j++) {
543                                        final int letterLen = dis.readInt();
544                                        final boolean[] stringBytes = new boolean[letterLen];
545                                        for (int k = 0; k < stringBytes.length; k++) {
546                                                stringBytes[k] = dis.readBoolean();
547                                        }
548                                        letters[j] = new Letter(stringBytes, dis.readInt());
549                                }
550                                wordToInt.put(new Word(letters), dis.readInt());
551                        }
552                }
553        }
554
555        @Override
556        public void writeASCII(PrintWriter writer) throws IOException {
557                writer.println(this.nDecisions);
558                writer.println(this.nTrees);
559                writer.println(this.featureLength);
560                for (final RandomDecisionTree tree : trees) {
561                        tree.writeASCII(writer);
562                        writer.println();
563                }
564                for (int i = 0; i < maxVal.length; i++)
565                        writer.print(maxVal[i] + " ");
566                writer.println();
567                for (int i = 0; i < minVal.length; i++)
568                        writer.print(minVal[i] + " ");
569                writer.println();
570                writer.println(currentInt);
571                for (final Entry<Letter, Integer> p : letterToInt.entrySet()) {
572                        writer.print(p.getKey() + "," + p.getValue() + " ");
573                }
574                writer.println();
575                writer.println(currentWordInt);
576                for (final Entry<Word, Integer> p : wordToInt.entrySet()) {
577                        writer.print(p.getKey() + "," + p.getValue() + " ");
578                }
579                writer.println();
580                writer.flush();
581        }
582
583        @Override
584        public void writeBinary(DataOutput o) throws IOException {
585                o.writeInt(nDecisions);
586                o.writeInt(nTrees);
587                o.writeInt(featureLength);
588                for (final RandomDecisionTree tree : trees) {
589                        tree.write(o);
590                }
591
592                for (int i = 0; i < maxVal.length; i++)
593                        o.writeInt(maxVal[i]);
594                for (int i = 0; i < minVal.length; i++)
595                        o.writeInt(minVal[i]);
596
597                o.writeInt(currentInt);
598                for (final Entry<Letter, Integer> p : letterToInt.entrySet()) {
599                        o.writeInt(p.getKey().value.length);
600                        for (final boolean i : p.getKey().value)
601                                o.writeBoolean(i);
602                        o.writeInt(p.getKey().treeIndex);
603                        o.writeInt(p.getValue());
604                }
605                o.writeInt(currentWordInt);
606                for (final Entry<Word, Integer> p : wordToInt.entrySet()) {
607                        o.writeInt(p.getKey().letters.length);
608                        for (final Letter i : p.getKey().letters) {
609                                o.writeInt(i.value.length);
610                                for (final boolean j : i.value)
611                                        o.writeBoolean(j);
612                                o.writeInt(i.treeIndex);
613                        }
614                        o.writeInt(p.getValue());
615                }
616        }
617
        /**
         * Sets the seed used when the trees are next built. A value of -1 is
         * treated as "no seed": the trees are then built with an unseeded
         * {@link Random}.
         * 
         * @param random
         *            the seed of the java {@link Random} instance used by the
         *            decision trees
         */
        public void setRandomSeed(int random) {
                this.randomSeed = random;
        }
626
        /**
         * Creates a letter bound to this forest.
         *
         * @param bs
         *            the decisions held by the letter
         * @param i
         *            the index of the tree the letter belongs to
         * @return the new letter
         */
        Letter newLetter(boolean[] bs, int i) {
                return new Letter(bs, i);
        }
630
        /**
         * Creates a word bound to this forest.
         *
         * @param letters
         *            the letters making up the word
         * @return the new word
         */
        Word newWord(Letter[] letters) {
                return new Word(letters);
        }
634
        /**
         * Not supported by this forest.
         *
         * @throws UnsupportedOperationException
         *             always
         */
        @Override
        public void assignDistance(int[][] data, int[] indices, float[] distances) {
                throw new UnsupportedOperationException("Not implemented");
        }
639
        /**
         * Not supported by this forest.
         *
         * @throws UnsupportedOperationException
         *             always
         */
        @Override
        public IntFloatPair assignDistance(int[] data) {
                throw new UnsupportedOperationException("Not implemented");
        }
644
        /**
         * @return this forest, which acts as its own hard assigner
         */
        @Override
        public HardAssigner<int[], ?, ?> defaultHardAssigner() {
                return this;
        }
649
        /**
         * @return the current value of the letter counter, the same value as
         *         {@link #numClusters()}
         */
        @Override
        public int size() {
                return this.currentInt;
        }
654
655        @Override
656        public int[][] performClustering(int[][] data) {
657                return new IndexClusters(this.cluster(data).defaultHardAssigner().assign(data)).clusters();
658        }
659}