001/*
002        AUTOMATICALLY GENERATED BY jTemp FROM
003        /Users/jsh2/Work/openimaj/target/checkout/machine-learning/nearest-neighbour/src/main/jtemp/org/openimaj/knn/approximate/#T#KDTreeEnsemble.jtemp
004*/
005/**
006 * Copyright (c) 2011, The University of Southampton and the individual contributors.
007 * All rights reserved.
008 *
009 * Redistribution and use in source and binary forms, with or without modification,
010 * are permitted provided that the following conditions are met:
011 *
012 *   *  Redistributions of source code must retain the above copyright notice,
013 *      this list of conditions and the following disclaimer.
014 *
015 *   *  Redistributions in binary form must reproduce the above copyright notice,
016 *      this list of conditions and the following disclaimer in the documentation
017 *      and/or other materials provided with the distribution.
018 *
019 *   *  Neither the name of the University of Southampton nor the names of its
020 *      contributors may be used to endorse or promote products derived from this
021 *      software without specific prior written permission.
022 *
023 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
024 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
025 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
026 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
027 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
028 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
029 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
030 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
031 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
032 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
033 */
034package org.openimaj.knn.approximate;
035
036import java.util.ArrayList;
037import java.util.Comparator;
038import java.util.List;
039import java.util.PriorityQueue;
040
041import cern.jet.random.Uniform;
042import cern.jet.random.engine.MersenneTwister;
043    
044import org.openimaj.knn.ByteNearestNeighbours;
045import org.openimaj.util.array.IntArrayView;
046import org.openimaj.util.pair.*;
047
048import jal.objects.BinaryPredicate;
049import jal.objects.Sorting;
050
051/**
052 * Ensemble of Best-Bin-First KDTrees for byte data.
053 * 
054 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
055 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
056 */
057public class ByteKDTreeEnsemble {
058        private static final int leaf_max_points = 14;
059        private static final int varest_max_points = 128;
060        private static final int varest_max_randsz = 5;
061        
062        Uniform rng;
063
064    /**
065         * An internal node of the KDTree
066         */     
067        public static class ByteKDTreeNode {
068                class NodeData {}
069                
070            class InternalNodeData extends NodeData {
071                ByteKDTreeNode right;
072                float disc;
073                int disc_dim;
074            }
075            
076            class LeafNodeData extends NodeData {
077                int [] indices;
078            }
079            
080            /**
081             * left == null iff this node is a leaf.
082             */
083            ByteKDTreeNode left;
084
085            NodeData node_data;
086            
087            private Uniform rng;
088            
089            boolean is_leaf() { 
090                return left==null; 
091            }
092
093            IntFloatPair choose_split(final byte [][] pnts, final IntArrayView inds) {
094                int D = pnts[0].length;
095                
096                // Find mean & variance of each dimension.
097                float [] sum_x = new float[D];
098                float [] sum_xx = new float[D];
099                
100                int count = Math.min(inds.size(), varest_max_points);
101                for (int n=0; n<count; ++n) {
102                    for (int d=0; d<D; ++d) {
103                        sum_x[d]  += pnts[inds.getFast(n)][d];
104                        sum_xx[d] += (pnts[inds.getFast(n)][d]*pnts[inds.getFast(n)][d]);
105                    }
106                }
107
108                FloatIntPair[] var_dim = new FloatIntPair[D];
109                for (int d=0; d < D; ++d) {
110                        var_dim[d] = new FloatIntPair();
111                    if (count <= 1)
112                        var_dim[d].first = 0;
113                    else
114                        var_dim[d].first = (sum_xx[d] - ((float)1/count)*sum_x[d]*sum_x[d])/(count - 1);
115                    var_dim[d].second = d;
116                }
117
118                // Partial sort makes a BIG difference to the build time.
119                int nrand = Math.min(varest_max_randsz, D);
120                Sorting.partial_sort(var_dim, 0, nrand, var_dim.length, new BinaryPredicate() {
121                                @Override
122                                public boolean apply(Object arg0, Object arg1) {
123                                        FloatIntPair p1 = (FloatIntPair) arg0;
124                                        FloatIntPair p2 = (FloatIntPair) arg1;
125                                        
126                                        if (p1.first > p2.first) return true;
127                                        if (p2.first > p1.first) return false;
128                                        return (p1.second > p2.second);
129                                }});
130                
131                int randd = var_dim[rng.nextIntFromTo(0, nrand-1)].second;
132                
133                return new IntFloatPair(randd, sum_x[randd]/count);
134            }
135
136            void split_points(final byte [][] pnts, IntArrayView inds) {
137                IntFloatPair spl = choose_split(pnts, inds);
138
139                ((InternalNodeData)node_data).disc_dim = spl.first;
140                ((InternalNodeData)node_data).disc = spl.second;
141
142                int N = inds.size();
143                int l = 0;
144                int r = N;
145                while (l!=r) {
146                  if (pnts[inds.getFast(l)][((InternalNodeData)node_data).disc_dim] < ((InternalNodeData)node_data).disc) l++;
147                  else {
148                    r--;
149                    int t = inds.getFast(l);
150                    inds.setFast(l, inds.getFast(r));
151                    inds.setFast(r, t);
152                  }
153                }
154            
155                // If either partition is empty -> vectors identical!
156                if (l==0 || l==N) { l = N/2; } // The vectors are identical, so keep nlogn performance.
157
158                left = new ByteKDTreeNode(pnts, inds.subView(0, l), rng);
159                
160                ((InternalNodeData)node_data).right = new ByteKDTreeNode(pnts, inds.subView(l, N), rng);
161            }
162
163                /** Construct a new node */
164            public ByteKDTreeNode() { }
165
166                /** 
167                 * Construct a new node with the given data
168                 *
169                 * @param pnts the data for the node and its children
170                 * @param inds a list of indices that point to the relevant
171                 *                      parts of the pnts array that should be used
172                 * @param rng the random number generator
173                 */
174            public ByteKDTreeNode(final byte [][] pnts, IntArrayView inds, Uniform rng) {
175                this.rng = rng;
176                if (inds.size() > leaf_max_points) { // Internal node
177                        node_data = new InternalNodeData();
178                    split_points(pnts, inds);
179                }
180                else {
181                        node_data = new LeafNodeData();
182                        ((LeafNodeData)node_data).indices = inds.toArray();
183                }
184            }
185
186            void search(final byte [] qu, PriorityQueue<FloatObjectPair<ByteKDTreeNode>> pri_branch, List<IntFloatPair> nns, boolean[] seen, byte [][] pnts, float mindsq)
187            {
188                ByteKDTreeNode cur = this;
189                ByteKDTreeNode other = null;
190
191                while (!cur.is_leaf()) { // Follow best bin first until we hit a leaf
192                        float diff = qu[((InternalNodeData)cur.node_data).disc_dim] - ((InternalNodeData)cur.node_data).disc;
193
194                    if (diff < 0) {
195                        other = ((InternalNodeData)cur.node_data).right;
196                        cur = cur.left;
197                    }
198                    else {
199                        other = cur.left;
200                        cur = ((InternalNodeData)cur.node_data).right;
201                    }
202
203                    pri_branch.add(new FloatObjectPair<ByteKDTreeNode>(mindsq + diff*diff, other));
204                }
205
206                int [] cur_inds = ((LeafNodeData)cur.node_data).indices;
207                int ncur_inds = cur_inds.length;
208                
209                int i;
210                float [] dsq = new float[1];
211                for (i = 0; i < ncur_inds; ++i) {
212                        int ci = cur_inds[i];
213                    if (!seen[ci]) {
214                        ByteNearestNeighbours.distanceFunc(qu, new byte[][] {pnts[ci]}, dsq);
215                        
216                        nns.add(new IntFloatPair(ci, dsq[0]));
217                        
218                        seen[ci] = true;
219                    }
220                }
221            }
222        }
223        
224        /** The tree roots */ 
225        public final ByteKDTreeNode [] trees;
226        
227        /** The underlying data array */
228        public final byte [][] pnts;
229    
230    /**
231     * Construct a ByteKDTreeEnsemble with the provided data,
232     * using the default of 8 trees.
233     * @param pnts the data array 
234     */
235    public ByteKDTreeEnsemble(final byte [][] pnts) {
236        this(pnts, 8, 42);
237    }
238    
239    /**
240     * Construct a ByteKDTreeEnsemble with the provided data and
241     * number of trees.
242     * @param pnts the data array 
243     * @param ntrees the number of KDTrees in the ensemble 
244     */
245    public ByteKDTreeEnsemble(final byte [][] pnts, int ntrees) {
246        this(pnts, ntrees, 42);
247    }
248    
249    /**
250     * Construct a ByteKDTreeEnsemble with the provided data and
251     * number of trees.
252     * @param pnts the data array 
253     * @param ntrees the number of KDTrees in the ensemble
254     * @param seed the seed for the random number generator used in 
255     *                  tree construction 
256     */
257    public ByteKDTreeEnsemble(final byte [][] pnts, int ntrees, int seed) {
258        final int N = pnts.length;
259        this.pnts = pnts;
260        this.rng = new Uniform(new MersenneTwister(seed));
261
262        // Create inds.
263        IntArrayView inds = new IntArrayView(N);
264        for (int n=0; n<N; ++n) inds.setFast(n, n);
265
266        // Create trees.
267        trees = new ByteKDTreeNode[ntrees];
268        for (int t=0; t<ntrees; ++t) {
269            trees[t] = new ByteKDTreeNode(pnts, inds,rng);
270        }
271    }
272
273    void search(final byte [] qu, int numnn, IntFloatPair[] ret_nns, int nchecks) {
274        final int N = pnts.length;
275        
276        if (nchecks < numnn) nchecks = numnn;
277        if (nchecks > N) nchecks = N;
278        
279        PriorityQueue<FloatObjectPair<ByteKDTreeNode>> pri_branch = new PriorityQueue<FloatObjectPair<ByteKDTreeNode>>(
280                11, 
281                new Comparator<FloatObjectPair<ByteKDTreeNode>>() {
282                        @Override
283                        public int compare(FloatObjectPair<ByteKDTreeNode> o1, FloatObjectPair<ByteKDTreeNode> o2) {
284                                if (o1.first > o2.first) return 1;
285                                if (o2.first > o1.first) return -1;
286                                return 0;
287                        }}
288        );
289
290        List<IntFloatPair> nns = new ArrayList<IntFloatPair>((3*nchecks)/2);
291        boolean [] seen = new boolean[N];
292
293        // Search each tree at least once.
294        for (int t=0; t<trees.length; ++t) {
295            trees[t].search(qu, pri_branch, nns, seen, pnts, 0);
296        }
297
298        // Continue search until we've performed enough distances
299        while (nns.size() < nchecks) {
300                FloatObjectPair<ByteKDTreeNode> pr = pri_branch.poll();
301            
302            pr.second.search(qu, pri_branch, nns, seen, pnts, pr.first);
303        }
304
305        IntFloatPair [] nns_arr = nns.toArray(new IntFloatPair[nns.size()]); 
306        Sorting.partial_sort(nns_arr, 0, numnn, nns_arr.length, new BinaryPredicate() {
307                        @Override
308                        public boolean apply(Object lhs, Object rhs) {
309                                return ((IntFloatPair)lhs).second < ((IntFloatPair)rhs).second;
310                        }});
311
312        System.arraycopy(nns_arr, 0, ret_nns, 0, Math.min(numnn, nchecks));
313    }
314}