001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.clustering.rforest;
031
032import java.io.DataInput;
033import java.io.DataOutput;
034import java.io.IOException;
035import java.io.PrintWriter;
036import java.util.LinkedList;
037import java.util.List;
038import java.util.Random;
039import java.util.Scanner;
040
041/**
042 * A tree of {@link RandomDecision} nodes used for constructing a string of bits which represent a cluster
043 * point for a single data point
044 * 
045 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
046 *
047 */
048public class RandomDecisionTree {
049        List<RandomDecision> decisions;
050        private Random random = new Random();
051        /**
052         * Construct a new RandomDecisionTree setting the number of decisions and the values needed
053         * to choose a random index and min/max values for each feature vector index.
054         * 
055         * @param nDecisions
056         * @param featureLength
057         * @param minVal
058         * @param maxVal
059         */
060        public RandomDecisionTree(int nDecisions, int featureLength, int[] minVal, int[] maxVal) {
061                initDecisions(nDecisions,featureLength,minVal,maxVal);
062        }
063        
064        /**
065         * Construct a new RandomDecisionTree setting the number of decisions and the values needed
066         * to choose a random index and min/max values for each feature vector index.
067         * 
068         * @param nDecisions
069         * @param featureLength
070         * @param minVal
071         * @param maxVal
072         * @param r
073         */
074        public RandomDecisionTree(int nDecisions, int featureLength, int[] minVal, int[] maxVal,Random r) {
075                this.random = r;
076                initDecisions(nDecisions,featureLength,minVal,maxVal);
077        }
078        
079        private void initDecisions(int nDecisions, int featureLength, int[] minVal,int[] maxVal) {
080                decisions = new LinkedList<RandomDecision>();
081                for(int i = 0; i < nDecisions; i++){
082                        RandomDecision dec = new RandomDecision(featureLength,minVal,maxVal,this.random);
083                        decisions.add(dec);
084                }
085        }
086
087        /**
088         * A convenience function allowing the RandomDecisionTree to be written and read.
089         */
090        public RandomDecisionTree() {
091                decisions = new LinkedList<RandomDecision>();
092        }
093
094        /**
095         * The function which finds the path down this random tree for a given feature. Tests each
096         * required feature vector index against the threshold and assigns booleans.
097         * 
098         * @param feature
099         * @return return the letter as a string of bytes
100         */
101        public boolean[] getLetter(int[] feature){
102                boolean[] out = new boolean[decisions.size()];
103                int i = 0;
104                for(RandomDecision r : decisions){
105                        if(feature[r.feature] > r.threshold)
106                                out[i] = true;
107                        else
108                                out[i] = false;
109                        i++;
110                }
111                return out;
112        }
113
114        /**
115         * Read/Write RandomDecisionTree (including decision nodes)
116         * @param o
117         * @throws IOException
118         */
119        public void write(DataOutput o) throws IOException {
120                o.writeInt(this.decisions.size());
121                for(RandomDecision r : this.decisions){
122                        r.write(o);
123                }
124        }
125
126        /**
127         * Read/Write RandomDecisionTree (including decision nodes)
128         * @param writer
129         */
130        public void writeASCII(PrintWriter writer) {
131                for(RandomDecision r : this.decisions){
132                        r.writeASCII(writer);
133                        writer.print(" ");
134                }
135        }
136
137        /**
138         * Read/Write RandomDecisionTree (including decision nodes)
139         * @param dis
140         * @throws IOException
141         * @return this
142         */
143        public RandomDecisionTree readBinary(DataInput dis) throws IOException {
144                int nDecisions = dis.readInt();
145                if(this.decisions.size() != nDecisions){
146                        this.decisions = new LinkedList<RandomDecision>();
147                        for(int i = 0 ; i < nDecisions; i ++){
148                                RandomDecision r = new RandomDecision().readBinary(dis);
149                                this.decisions.add(r);
150                        }
151                }
152                else{
153                        for(RandomDecision rd : this.decisions){
154                                rd.readBinary(dis);
155                        }
156                }
157                return this;
158        }
159
160        /**
161         * Read/Write RandomDecisionTree (including decision nodes)
162         * @param br
163         * @throws IOException
164         * @return this
165         */
166        public RandomDecisionTree readASCII(Scanner br) throws IOException {
167                String[] lines = br.nextLine().split(" ");
168                if(this.decisions.size() != lines.length){
169                        this.decisions = new LinkedList<RandomDecision>();
170                        for(String line : lines){
171                                this.decisions.add(new RandomDecision().readString(line));
172                        }
173                }
174                else{
175                        int index = 0;
176                        for(RandomDecision rd : this.decisions){
177                                rd.readString(lines[index++]);
178                        }
179                }
180                
181                return this;
182        }
183        
184        @Override
185        public String toString(){
186                String s = "{";
187                for (RandomDecision r : this.decisions){
188                        s += r.toString() + ",";
189                }
190                s+="}";
191                return s;
192        }
193        
194        @Override
195        public boolean equals(Object o)
196        {
197                if(!(o instanceof RandomDecisionTree)) return false;
198                RandomDecisionTree rdt = (RandomDecisionTree) o;
199                for(int i = 0; i < this.decisions.size(); i++){
200                        RandomDecision d1 = rdt.decisions.get(i);
201                        RandomDecision d2 = this.decisions.get(i);
202                        
203                        if(d1.feature != d2.feature || d1.threshold != d2.threshold)
204                                return false;
205                }
206                return true;
207        }
208}