/*
    AUTOMATICALLY GENERATED BY jTemp FROM
    /Users/jsh2/Work/openimaj/target/checkout/machine-learning/nearest-neighbour/src/main/jtemp/org/openimaj/knn/#T#NearestNeighboursExact.jtemp
*/
/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.knn;

import java.util.ArrayList;
import java.util.List;

import org.openimaj.feature.ShortFVComparison;
import org.openimaj.feature.ShortFVComparator;
import org.openimaj.util.pair.IntFloatPair;
import org.openimaj.util.queue.BoundedPriorityQueue;

/**
 * Exact (brute-force) k-nearest-neighbour implementation.
 * <p>
 * Each query is compared against every point of the dataset; the best
 * candidates are collected in a {@link BoundedPriorityQueue} ordered by
 * {@link IntFloatPair#SECOND_ITEM_ASCENDING_COMPARATOR}.
 * <p>
 * The dataset array is stored by reference (it is not copied), so changes
 * made to it by the caller are visible to subsequent searches.
 *
 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 */
public class ShortNearestNeighboursExact extends ShortNearestNeighbours {
    /**
     * {@link NearestNeighboursFactory} for producing
     * {@link ShortNearestNeighboursExact}s.
     *
     * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
     */
    public static final class Factory implements NearestNeighboursFactory<ShortNearestNeighboursExact, short[]> {
        /** Distance function handed to every created instance; may be null. */
        private final ShortFVComparator distance;

        /**
         * Construct the factory using Euclidean distance for the
         * produced ShortNearestNeighbours instances.
         */
        public Factory() {
            this.distance = null;
        }

        /**
         * Construct the factory with the given distance function
         * for the produced ShortNearestNeighbours instances.
         *
         * @param distance
         *            the distance function
         */
        public Factory(ShortFVComparator distance) {
            this.distance = distance;
        }

        /**
         * Create a new exact nearest-neighbour object over the given data,
         * using the distance function this factory was constructed with.
         *
         * @param data
         *            the dataset
         * @return the new {@link ShortNearestNeighboursExact}
         */
        @Override
        public ShortNearestNeighboursExact create(short[][] data) {
            return new ShortNearestNeighboursExact(data, distance);
        }
    }

    /** The dataset being searched (held by reference, not copied). */
    protected final short[][] pnts;

    /** The distance function; null selects the default distance. */
    protected final ShortFVComparator distance;

    /**
     * Construct the ShortNearestNeighboursExact over the provided
     * dataset and using Euclidean distance.
     *
     * @param pnts the dataset
     */
    public ShortNearestNeighboursExact(final short[][] pnts) {
        this(pnts, null);
    }

    /**
     * Construct the ShortNearestNeighboursExact over the provided
     * dataset with the given distance function.
     * <p>
     * Note: If the distance function provides similarities rather
     * than distances they are automatically inverted.
     *
     * @param pnts the dataset
     * @param distance the distance function
     */
    public ShortNearestNeighboursExact(final short[][] pnts, final ShortFVComparator distance) {
        this.pnts = pnts;
        this.distance = distance;
    }

    /**
     * Find the single nearest neighbour of each query. For query {@code n}
     * the index of the best dataset point is written to {@code indices[n]}
     * and its distance to {@code distances[n]}.
     *
     * @param qus the queries
     * @param indices output array of dataset indices, one per query
     * @param distances output array of distances, one per query
     */
    @Override
    public void searchNN(final short[][] qus, int[] indices, float[] distances) {
        final int N = qus.length;

        // the queue and working pairs are shared by all queries
        final BoundedPriorityQueue<IntFloatPair> queue = createResultQueue(1);
        final List<IntFloatPair> list = createWorkingPairs(1);

        for (int n = 0; n < N; ++n) {
            final IntFloatPair p = search(qus[n], queue, list).get(0);
            indices[n] = p.first;
            distances[n] = p.second;
        }
    }

    /**
     * Find the K nearest neighbours of each query. For query {@code n} the
     * k-th best dataset index is written to {@code indices[n][k]} and its
     * distance to {@code distances[n][k]}.
     * <p>
     * K is clamped to the size of the dataset, so only the first
     * {@code min(K, size())} entries of each output row are written. If
     * that number is zero the output arrays are left untouched.
     *
     * @param qus the queries
     * @param K the number of neighbours to find
     * @param indices output array of dataset indices, one row per query
     * @param distances output array of distances, one row per query
     */
    @Override
    public void searchKNN(final short[][] qus, int K, int[][] indices, float[][] distances) {
        // Fix for when the user asks for too many points.
        K = Math.min(K, pnts.length);

        // No neighbours requested or available: nothing to write, and it
        // avoids constructing a zero-capacity queue.
        if (K == 0) {
            return;
        }

        final int N = qus.length;

        final BoundedPriorityQueue<IntFloatPair> queue = createResultQueue(K);
        final List<IntFloatPair> list = createWorkingPairs(K);

        // search on each query
        for (int n = 0; n < N; ++n) {
            final List<IntFloatPair> result = search(qus[n], queue, list);

            for (int k = 0; k < K; ++k) {
                final IntFloatPair p = result.get(k);
                indices[n][k] = p.first;
                distances[n][k] = p.second;
            }
        }
    }

    /**
     * Find the single nearest neighbour of each query. For query {@code n}
     * the index of the best dataset point is written to {@code indices[n]}
     * and its distance to {@code distances[n]}.
     *
     * @param qus the queries
     * @param indices output array of dataset indices, one per query
     * @param distances output array of distances, one per query
     */
    @Override
    public void searchNN(final List<short[]> qus, int[] indices, float[] distances) {
        final int N = qus.size();

        // the queue and working pairs are shared by all queries
        final BoundedPriorityQueue<IntFloatPair> queue = createResultQueue(1);
        final List<IntFloatPair> list = createWorkingPairs(1);

        for (int n = 0; n < N; ++n) {
            final IntFloatPair p = search(qus.get(n), queue, list).get(0);
            indices[n] = p.first;
            distances[n] = p.second;
        }
    }

    /**
     * Find the K nearest neighbours of each query. For query {@code n} the
     * k-th best dataset index is written to {@code indices[n][k]} and its
     * distance to {@code distances[n][k]}.
     * <p>
     * K is clamped to the size of the dataset, so only the first
     * {@code min(K, size())} entries of each output row are written. If
     * that number is zero the output arrays are left untouched.
     *
     * @param qus the queries
     * @param K the number of neighbours to find
     * @param indices output array of dataset indices, one row per query
     * @param distances output array of distances, one row per query
     */
    @Override
    public void searchKNN(final List<short[]> qus, int K, int[][] indices, float[][] distances) {
        // Fix for when the user asks for too many points.
        K = Math.min(K, pnts.length);

        // No neighbours requested or available: nothing to write, and it
        // avoids constructing a zero-capacity queue.
        if (K == 0) {
            return;
        }

        final int N = qus.size();

        final BoundedPriorityQueue<IntFloatPair> queue = createResultQueue(K);
        final List<IntFloatPair> list = createWorkingPairs(K);

        // search on each query
        for (int n = 0; n < N; ++n) {
            final List<IntFloatPair> result = search(qus.get(n), queue, list);

            for (int k = 0; k < K; ++k) {
                final IntFloatPair p = result.get(k);
                indices[n][k] = p.first;
                distances[n][k] = p.second;
            }
        }
    }

    /**
     * Find the K nearest neighbours of a single query.
     * <p>
     * K is clamped to the size of the dataset. If the clamped value is zero
     * an empty list is returned.
     *
     * @param query the query vector
     * @param K the number of neighbours to find
     * @return the (index, distance) pairs produced by the search, in the
     *         order given by the queue's ordered list
     */
    @Override
    public List<IntFloatPair> searchKNN(short[] query, int K) {
        // Fix for when the user asks for too many points.
        K = Math.min(K, pnts.length);

        // No neighbours requested or available; avoids constructing a
        // zero-capacity queue.
        if (K == 0) {
            return new ArrayList<IntFloatPair>(0);
        }

        final BoundedPriorityQueue<IntFloatPair> queue = createResultQueue(K);
        final List<IntFloatPair> list = createWorkingPairs(K);

        return search(query, queue, list);
    }

    /**
     * Find the single nearest neighbour of a query.
     *
     * @param query the query vector
     * @return the (index, distance) pair at the head of the search result
     */
    @Override
    public IntFloatPair searchNN(final short[] query) {
        final BoundedPriorityQueue<IntFloatPair> queue = createResultQueue(1);
        final List<IntFloatPair> list = createWorkingPairs(1);

        return search(query, queue, list).get(0);
    }

    /**
     * Create the bounded queue used to collect the best candidates.
     *
     * @param capacity the maximum number of pairs the queue retains
     * @return a new queue ordered by ascending second item
     */
    private static BoundedPriorityQueue<IntFloatPair> createResultQueue(final int capacity) {
        return new BoundedPriorityQueue<IntFloatPair>(capacity, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
    }

    /**
     * Allocate the reusable working pairs for a search: one per queue slot
     * plus one spare that is used as scratch space while scanning.
     *
     * @param capacity the capacity of the queue the pairs will be offered to
     * @return a list of {@code capacity + 1} fresh pairs
     */
    private static List<IntFloatPair> createWorkingPairs(final int capacity) {
        final List<IntFloatPair> pairs = new ArrayList<IntFloatPair>(capacity + 1);
        for (int i = 0; i < capacity + 1; i++) {
            pairs.add(new IntFloatPair());
        }
        return pairs;
    }

    /**
     * Run one brute-force search, re-using the supplied queue and pairs so
     * that no pair objects are allocated per query.
     *
     * @param query the query vector
     * @param queue the bounded queue collecting the best candidates
     * @param results the working pairs; all are reset and offered to the
     *            queue before the scan starts
     * @return the contents of the queue as an ordered list (the queue is
     *         drained by {@code toOrderedListDestructive()})
     */
    private List<IntFloatPair> search(short[] query, BoundedPriorityQueue<IntFloatPair> queue, List<IntFloatPair> results) {
        IntFloatPair wp = null;

        // reset all values in the queue to MAX, -1
        for (final IntFloatPair p : results) {
            p.second = Float.MAX_VALUE;
            p.first = -1;
            // Keep whatever offerItem hands back as the spare working pair.
            // NOTE(review): presumably this is the pair that did not fit in
            // the queue - confirm against BoundedPriorityQueue.offerItem.
            wp = queue.offerItem(p);
        }

        // perform the search
        for (int i = 0; i < this.pnts.length; i++) {
            // distanceFunc is not defined in this file (presumably inherited
            // from ShortNearestNeighbours).
            wp.second = distanceFunc(distance, query, pnts[i]);
            wp.first = i;
            wp = queue.offerItem(wp);
        }

        return queue.toOrderedListDestructive();
    }

    /**
     * Get the dimensionality of the data, taken from the first point of
     * the dataset.
     *
     * @return the length of the first data point, or 0 if the dataset is
     *         empty
     */
    @Override
    public int numDimensions() {
        // an empty dataset has no first point to measure
        if (pnts.length == 0) {
            return 0;
        }
        return pnts[0].length;
    }

    /**
     * Get the number of points in the dataset.
     *
     * @return the number of data points
     */
    @Override
    public int size() {
        return pnts.length;
    }

    /**
     * Get the underlying data points.
     *
     * @return the data points (the internal array itself, not a copy)
     */
    public short[][] getPoints() {
        return this.pnts;
    }

    /**
     * Compute the distance between two vectors using the underlying distance
     * comparison used by this class. When no comparator was supplied,
     * {@link ShortFVComparison#SUM_SQUARE} is used.
     * <p>
     * NOTE(review): the raw comparator value is returned here, whereas the
     * search methods go through {@code distanceFunc}; whether the two agree
     * for similarity-style comparators should be confirmed.
     *
     * @param a
     *            the first vector
     * @param b
     *            the second vector
     * @return the distance between the two vectors
     */
    public float computeDistance(short[] a, short[] b) {
        if (distance == null) {
            return (float) ShortFVComparison.SUM_SQUARE.compare(a, b);
        }
        return (float) distance.compare(a, b);
    }

    /**
     * Get the distance comparator
     *
     * @return the distance comparator, or null if none was supplied
     */
    public ShortFVComparator distanceComparator() {
        return this.distance;
    }
}