/*
    AUTOMATICALLY GENERATED BY jTemp FROM
    /Users/jsh2/Work/openimaj/target/checkout/machine-learning/nearest-neighbour/src/main/jtemp/org/openimaj/knn/approximate/#T#KDTreeEnsemble.jtemp
*/
/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.knn.approximate;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

import cern.jet.random.Uniform;
import cern.jet.random.engine.MersenneTwister;

import org.openimaj.knn.ShortNearestNeighbours;
import org.openimaj.util.array.IntArrayView;
import org.openimaj.util.pair.*;

import jal.objects.BinaryPredicate;
import jal.objects.Sorting;

/**
 * Ensemble of Best-Bin-First KDTrees for short data.
 *
 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 */
public class ShortKDTreeEnsemble {
    private static final int leaf_max_points = 14;
    private static final int varest_max_points = 128;
    private static final int varest_max_randsz = 5;

    Uniform rng;

    /**
     * An internal node of the KDTree
     */
    public static class ShortKDTreeNode {
        class NodeData {}

        class InternalNodeData extends NodeData {
            ShortKDTreeNode right;
            float disc;
            int disc_dim;
        }

        class LeafNodeData extends NodeData {
            int [] indices;
        }

        /**
         * left == null iff this node is a leaf.
         */
        ShortKDTreeNode left;

        NodeData node_data;

        private Uniform rng;

        boolean is_leaf() {
            return left == null;
        }

        IntFloatPair choose_split(final short [][] pnts, final IntArrayView inds) {
            int D = pnts[0].length;

            // Find mean & variance of each dimension.
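            // The estimates are computed over a subsample of at most varest_max_points
            // points, using the running sums below:
            //    mean[d] = sum_x[d] / count
            //    var[d]  = (sum_xx[d] - sum_x[d]*sum_x[d]/count) / (count - 1)
            // The split dimension is then drawn uniformly at random from the
            // varest_max_randsz highest-variance dimensions, and the split value is the
            // mean along that dimension.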
            float [] sum_x = new float[D];
            float [] sum_xx = new float[D];

            int count = Math.min(inds.size(), varest_max_points);
            for (int n=0; n<count; ++n) {
                for (int d=0; d<D; ++d) {
                    sum_x[d] += pnts[inds.getFast(n)][d];
                    sum_xx[d] += (pnts[inds.getFast(n)][d] * pnts[inds.getFast(n)][d]);
                }
            }

            FloatIntPair[] var_dim = new FloatIntPair[D];
            for (int d=0; d<D; ++d) {
                var_dim[d] = new FloatIntPair();
                if (count <= 1)
                    var_dim[d].first = 0;
                else
                    var_dim[d].first = (sum_xx[d] - ((float)1/count)*sum_x[d]*sum_x[d]) / (count - 1);
                var_dim[d].second = d;
            }

            // Partial sort makes a BIG difference to the build time.
            int nrand = Math.min(varest_max_randsz, D);
            Sorting.partial_sort(var_dim, 0, nrand, var_dim.length, new BinaryPredicate() {
                @Override
                public boolean apply(Object arg0, Object arg1) {
                    FloatIntPair p1 = (FloatIntPair) arg0;
                    FloatIntPair p2 = (FloatIntPair) arg1;

                    if (p1.first > p2.first) return true;
                    if (p2.first > p1.first) return false;
                    return (p1.second > p2.second);
                }});

            int randd = var_dim[rng.nextIntFromTo(0, nrand-1)].second;

            return new IntFloatPair(randd, sum_x[randd] / count);
        }

        void split_points(final short [][] pnts, IntArrayView inds) {
            IntFloatPair spl = choose_split(pnts, inds);

            ((InternalNodeData)node_data).disc_dim = spl.first;
            ((InternalNodeData)node_data).disc = spl.second;

            int N = inds.size();
            int l = 0;
            int r = N;
            while (l != r) {
                if (pnts[inds.getFast(l)][((InternalNodeData)node_data).disc_dim] < ((InternalNodeData)node_data).disc) l++;
                else {
                    r--;
                    int t = inds.getFast(l);
                    inds.setFast(l, inds.getFast(r));
                    inds.setFast(r, t);
                }
            }

            // If either partition is empty -> vectors identical!
            if (l==0 || l==N) { l = N/2; } // The vectors are identical, so keep nlogn performance.
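            // Recurse on the two partitions: indices [0, l) form the left subtree and
            // indices [l, N) the right subtree. Recursion bottoms out in the node
            // constructor once a partition holds no more than leaf_max_points points.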
            left = new ShortKDTreeNode(pnts, inds.subView(0, l), rng);

            ((InternalNodeData)node_data).right = new ShortKDTreeNode(pnts, inds.subView(l, N), rng);
        }

        /** Construct a new node */
        public ShortKDTreeNode() { }

        /**
         * Construct a new node with the given data
         *
         * @param pnts the data for the node and its children
         * @param inds a list of indices that point to the relevant
         *        parts of the pnts array that should be used
         * @param rng the random number generator
         */
        public ShortKDTreeNode(final short [][] pnts, IntArrayView inds, Uniform rng) {
            this.rng = rng;
            if (inds.size() > leaf_max_points) { // Internal node
                node_data = new InternalNodeData();
                split_points(pnts, inds);
            }
            else {
                node_data = new LeafNodeData();
                ((LeafNodeData)node_data).indices = inds.toArray();
            }
        }

        /**
         * Descend from this node, following the better child at each internal node and
         * pushing the unexplored sibling onto pri_branch keyed by a lower bound on its
         * squared distance to the query. At the leaf, distances to any points not yet
         * marked in seen are computed and appended to nns.
         */
        void search(final short [] qu, PriorityQueue<FloatObjectPair<ShortKDTreeNode>> pri_branch, List<IntFloatPair> nns, boolean[] seen, short [][] pnts, float mindsq)
        {
            ShortKDTreeNode cur = this;
            ShortKDTreeNode other = null;

            while (!cur.is_leaf()) { // Follow best bin first until we hit a leaf
                float diff = qu[((InternalNodeData)cur.node_data).disc_dim] - ((InternalNodeData)cur.node_data).disc;

                if (diff < 0) {
                    other = ((InternalNodeData)cur.node_data).right;
                    cur = cur.left;
                }
                else {
                    other = cur.left;
                    cur = ((InternalNodeData)cur.node_data).right;
                }

                pri_branch.add(new FloatObjectPair<ShortKDTreeNode>(mindsq + diff*diff, other));
            }

            int [] cur_inds = ((LeafNodeData)cur.node_data).indices;
            int ncur_inds = cur_inds.length;

            int i;
            float [] dsq = new float[1];
            for (i = 0; i < ncur_inds; ++i) {
                int ci = cur_inds[i];
                if (!seen[ci]) {
                    ShortNearestNeighbours.distanceFunc(qu, new short[][] {pnts[ci]}, dsq);

                    nns.add(new IntFloatPair(ci, dsq[0]));

                    seen[ci] = true;
                }
            }
        }
    }

    /** The tree roots */
    public final ShortKDTreeNode [] trees;

    /** The underlying data array */
    public final short [][] pnts;

    /**
     * Construct a ShortKDTreeEnsemble with the provided data,
     * using the default of 8 trees.
     * @param pnts the data array
     */
    public ShortKDTreeEnsemble(final short [][] pnts) {
        this(pnts, 8, 42);
    }

    /**
     * Construct a ShortKDTreeEnsemble with the provided data and
     * number of trees.
     * @param pnts the data array
     * @param ntrees the number of KDTrees in the ensemble
     */
    public ShortKDTreeEnsemble(final short [][] pnts, int ntrees) {
        this(pnts, ntrees, 42);
    }

    /**
     * Construct a ShortKDTreeEnsemble with the provided data and
     * number of trees.
     * @param pnts the data array
     * @param ntrees the number of KDTrees in the ensemble
     * @param seed the seed for the random number generator used in
     *        tree construction
     */
    public ShortKDTreeEnsemble(final short [][] pnts, int ntrees, int seed) {
        final int N = pnts.length;
        this.pnts = pnts;
        this.rng = new Uniform(new MersenneTwister(seed));

        // Create inds.
        IntArrayView inds = new IntArrayView(N);
        for (int n=0; n<N; ++n) inds.setFast(n, n);

        // Create trees.
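        // Every tree is built over the same index view of the full data set; the trees
        // differ only because choose_split() picks its split dimension at random from
        // the highest-variance candidates, which is what gives the ensemble its diversity.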
        trees = new ShortKDTreeNode[ntrees];
        for (int t=0; t<ntrees; ++t) {
            trees[t] = new ShortKDTreeNode(pnts, inds, rng);
        }
    }

    /**
     * Search the ensemble for (approximately) the numnn nearest neighbours of qu,
     * writing them into ret_nns in ascending order of distance. nchecks controls how
     * many points are compared against the query before the search stops; it is
     * clamped to lie between numnn and the size of the data set.
     */
    void search(final short [] qu, int numnn, IntFloatPair[] ret_nns, int nchecks) {
        final int N = pnts.length;

        if (nchecks < numnn) nchecks = numnn;
        if (nchecks > N) nchecks = N;

        PriorityQueue<FloatObjectPair<ShortKDTreeNode>> pri_branch = new PriorityQueue<FloatObjectPair<ShortKDTreeNode>>(
                11,
                new Comparator<FloatObjectPair<ShortKDTreeNode>>() {
                    @Override
                    public int compare(FloatObjectPair<ShortKDTreeNode> o1, FloatObjectPair<ShortKDTreeNode> o2) {
                        if (o1.first > o2.first) return 1;
                        if (o2.first > o1.first) return -1;
                        return 0;
                    }}
                );

        List<IntFloatPair> nns = new ArrayList<IntFloatPair>((3*nchecks)/2);
        boolean [] seen = new boolean[N];

        // Search each tree at least once.
        for (int t=0; t<trees.length; ++t) {
            trees[t].search(qu, pri_branch, nns, seen, pnts, 0);
        }

        // Continue the search until we've performed enough distance comparisons.
        while (nns.size() < nchecks) {
            FloatObjectPair<ShortKDTreeNode> pr = pri_branch.poll();

            pr.second.search(qu, pri_branch, nns, seen, pnts, pr.first);
        }

        IntFloatPair [] nns_arr = nns.toArray(new IntFloatPair[nns.size()]);
        Sorting.partial_sort(nns_arr, 0, numnn, nns_arr.length, new BinaryPredicate() {
            @Override
            public boolean apply(Object lhs, Object rhs) {
                return ((IntFloatPair)lhs).second < ((IntFloatPair)rhs).second;
            }});

        System.arraycopy(nns_arr, 0, ret_nns, 0, Math.min(numnn, nchecks));
    }
}
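
/*
 * Usage sketch (illustrative only -- the variable names and parameter values below are
 * examples, not part of the API; search() has package-private visibility, so this
 * assumes it is driven from within org.openimaj.knn.approximate):
 *
 *   short[][] data = ...;                                  // N x D data points
 *   ShortKDTreeEnsemble ensemble = new ShortKDTreeEnsemble(data, 8);
 *
 *   short[] query = ...;                                   // a single D-dimensional query
 *   IntFloatPair[] neighbours = new IntFloatPair[10];
 *   ensemble.search(query, 10, neighbours, 768);           // 10 NNs, ~768 distance checks
 */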