001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.ml.clustering.rac; 031 032import java.io.DataInput; 033import java.io.DataOutput; 034import java.io.IOException; 035import java.io.PrintWriter; 036import java.util.ArrayList; 037import java.util.List; 038import java.util.Scanner; 039 040import org.apache.commons.math.FunctionEvaluationException; 041import org.apache.commons.math.MaxIterationsExceededException; 042import org.apache.commons.math.analysis.UnivariateRealFunction; 043import org.apache.commons.math.analysis.solvers.BisectionSolver; 044import org.openimaj.citation.annotation.Reference; 045import org.openimaj.citation.annotation.ReferenceType; 046import org.openimaj.data.DataSource; 047import org.openimaj.data.RandomData; 048import org.openimaj.ml.clustering.CentroidsProvider; 049import org.openimaj.ml.clustering.IndexClusters; 050import org.openimaj.ml.clustering.SpatialClusterer; 051import org.openimaj.ml.clustering.SpatialClusters; 052import org.openimaj.ml.clustering.assignment.HardAssigner; 053import org.openimaj.util.pair.IntFloatPair; 054 055/** 056 * An implementation of the RAC algorithm proposed by <a 057 * href="http://eprints.ecs.soton.ac.uk/21401/">Ramanan and Niranjan</a>. 058 * <p> 059 * During training, data points are selected at random. The first data point is 060 * chosen as a centroid. Every following data point is set as a new centroid if 061 * it is outside the threshold of all current centroids. In this way it is 062 * difficult to guarantee number of clusters so a minimisation function is 063 * provided to allow a close estimate of the required threshold for a given K. 064 * <p> 065 * This implementation supports int[] cluster centroids. 066 * <p> 067 * In terms of implementation, this class is a both a clusterer, assigner and 068 * the result of the clustering. This is because the RAC algorithm never ends; 069 * that is to say that if a new point is being assigned through the 070 * {@link HardAssigner} interface, and that point is more than the threshold 071 * distance from any other centroid, then a new centroid will be created for the 072 * point. If this behaviour is undesirable, the results of clustering can be 073 * "frozen" by manually constructing an assigner that takes a 074 * {@link CentroidsProvider} (or the centroids provided by calling 075 * {@link #getCentroids()}) as an argument. 076 * 077 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 078 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 079 */ 080@Reference( 081 type = ReferenceType.Inproceedings, 082 author = { "Amirthalingam Ramanan", "Mahesan Niranjan" }, 083 title = "Resource-Allocating Codebook for Patch-based Face Recognition", 084 year = "2009", 085 booktitle = "IIS", 086 url = "http://eprints.ecs.soton.ac.uk/21401/") 087public class IntRAC 088 implements 089 SpatialClusters<int[]>, 090 SpatialClusterer<IntRAC, int[]>, 091 CentroidsProvider<int[]>, 092 HardAssigner<int[], float[], IntFloatPair> 093{ 094 private static class ClusterMinimisationFunction implements UnivariateRealFunction { 095 private int[][] distances; 096 private int[][] samples; 097 private int nClusters; 098 099 public ClusterMinimisationFunction(int[][] samples, int[][] distances, int nClusters) { 100 this.distances = distances; 101 this.samples = samples; 102 this.nClusters = nClusters; 103 } 104 105 @Override 106 public double value(double radius) throws FunctionEvaluationException { 107 final IntRAC r = new IntRAC(radius); 108 r.train(samples, distances); 109 final int diff = this.nClusters - r.numClusters(); 110 return diff; 111 } 112 } 113 114 private static final String HEADER = SpatialClusters.CLUSTER_HEADER + "RAIC"; 115 116 protected ArrayList<int[]> codebook; 117 protected double threshold; 118 protected int nDims; 119 protected static int[][] distances; 120 protected long totalSamples; 121 122 /** 123 * Sets the threshold to 128 124 */ 125 public IntRAC() { 126 codebook = new ArrayList<int[]>(); 127 this.threshold = 128; 128 this.nDims = -1; 129 this.totalSamples = 0; 130 } 131 132 /** 133 * Define the threshold at which point a new cluster will be made. 134 * 135 * @param radiusSquared 136 */ 137 public IntRAC(double radiusSquared) { 138 this(); 139 this.threshold = radiusSquared; 140 } 141 142 /** 143 * Iteratively select subSamples from bKeys and try to choose a threshold 144 * which results in nClusters. This is provided to estimate threshold as 145 * this is a very data dependant value. The threshold is found using a 146 * BisectionSolver with a complete distance matrix (so make sure subSamples 147 * is reasonable) 148 * 149 * @param bKeys 150 * All keys to be trained against 151 * @param subSamples 152 * number of subsamples to select from bKeys each iteration 153 * @param nClusters 154 * number of clusters to aim for 155 */ 156 public IntRAC(int[][] bKeys, int subSamples, int nClusters) { 157 this(); 158 159 distances = new int[subSamples][subSamples]; 160 int j = 0; 161 this.threshold = 0; 162 final int thresholdIteration = 5; 163 while (j++ < thresholdIteration) { 164 final int[][] randomList = new int[subSamples][]; 165 final int[] randomListIndex = RandomData.getUniqueRandomInts(subSamples, 0, bKeys.length); 166 int ri = 0; 167 for (int k = 0; k < randomListIndex.length; k++) 168 randomList[ri++] = bKeys[randomListIndex[k]]; 169 try { 170 this.threshold += calculateThreshold(randomList, nClusters); 171 } catch (final Exception e) { 172 this.threshold += 200000; 173 } 174 System.out.println("Current threshold: " + this.threshold / j); 175 } 176 this.threshold /= thresholdIteration; 177 } 178 179 @SuppressWarnings("deprecation") 180 protected static double calculateThreshold(int[][] samples, int nClusters) throws MaxIterationsExceededException, 181 FunctionEvaluationException 182 { 183 int maxDistance = 0; 184 for (int i = 0; i < samples.length; i++) { 185 for (int j = i + 1; j < samples.length; j++) { 186 distances[i][j] = distanceEuclidianSquared(samples[i], samples[j]); 187 distances[j][i] = distances[i][j]; 188 if (distances[i][j] > maxDistance) 189 maxDistance = distances[i][j]; 190 } 191 } 192 System.out.println("Distance matrix calculated"); 193 final BisectionSolver b = new BisectionSolver(); 194 b.setAbsoluteAccuracy(100.0); 195 return b.solve(100, new ClusterMinimisationFunction(samples, distances, nClusters), 0, maxDistance); 196 } 197 198 int train(int[][] samples, int[][] distances) { 199 int foundLength = -1; 200 final List<Integer> codebookIndex = new ArrayList<Integer>(); 201 for (int i = 0; i < samples.length; i++) { 202 final int[] entry = samples[i]; 203 if (foundLength == -1) 204 foundLength = entry.length; 205 206 // all the data entries must be the same length otherwise this 207 // doesn't make sense 208 if (foundLength != entry.length) { 209 this.codebook = new ArrayList<int[]>(); 210 return -1; 211 } 212 boolean found = false; 213 for (final int j : codebookIndex) { 214 if (distances[i][j] < threshold) { 215 found = true; 216 break; 217 } 218 } 219 if (!found) { 220 this.codebook.add(entry); 221 codebookIndex.add(i); 222 } 223 } 224 this.nDims = foundLength; 225 return 0; 226 } 227 228 @Override 229 public IntRAC cluster(int[][] data) { 230 int foundLength = -1; 231 232 for (final int[] entry : data) { 233 if (foundLength == -1) 234 foundLength = entry.length; 235 236 // all the data entries must be the same length otherwise this 237 // doesn't make sense 238 if (foundLength != entry.length) { 239 this.codebook = new ArrayList<int[]>(); 240 throw new RuntimeException(); 241 } 242 boolean found = false; 243 for (final int[] existing : this.codebook) { 244 if (distanceEuclidianSquared(entry, existing) < threshold) { 245 found = true; 246 break; 247 } 248 } 249 if (!found) { 250 this.codebook.add(entry); 251 if (this.codebook.size() % 1000 == 0) { 252 System.out.println("Codebook increased to size " + this.codebook.size()); 253 } 254 } 255 } 256 257 return this; 258 } 259 260 @Override 261 public IntRAC cluster(DataSource<int[]> data) { 262 final int[][] dataArr = new int[data.size()][data.numDimensions()]; 263 264 return cluster(dataArr); 265 } 266 267 static int distanceEuclidianSquared(int[] a, int[] b) { 268 int sum = 0; 269 for (int i = 0; i < a.length; i++) { 270 final int diff = a[i] - b[i]; 271 sum += diff * diff; 272 } 273 return sum; 274 } 275 276 static int distanceEuclidianSquared(int[] a, int[] b, int threshold2) { 277 int sum = 0; 278 279 for (int i = 0; i < a.length; i++) { 280 final int diff = a[i] - b[i]; 281 sum += diff * diff; 282 if (sum > threshold2) 283 return threshold2; 284 } 285 return sum; 286 } 287 288 @Override 289 public int numClusters() { 290 return this.codebook.size(); 291 } 292 293 @Override 294 public int numDimensions() { 295 return this.nDims; 296 } 297 298 @Override 299 public int[] assign(int[][] data) { 300 final int[] centroids = new int[data.length]; 301 for (int i = 0; i < data.length; i++) { 302 final int[] entry = data[i]; 303 centroids[i] = this.assign(entry); 304 } 305 return centroids; 306 } 307 308 @Override 309 public int assign(int[] data) { 310 int mindiff = -1; 311 int centroid = -1; 312 313 for (int i = 0; i < this.numClusters(); i++) { 314 final int[] centroids = this.codebook.get(i); 315 int sum = 0; 316 boolean set = true; 317 318 for (int j = 0; j < centroids.length; j++) { 319 final int diff = centroids[j] - data[j]; 320 sum += diff * diff; 321 if (mindiff != -1 && mindiff < sum) { 322 set = false; 323 break; // Stop checking the distance if you 324 } 325 } 326 327 if (set) { 328 mindiff = sum; 329 centroid = i; 330 // if(mindiff < this.threshold){ 331 // return centroid; 332 // } 333 } 334 } 335 return centroid; 336 } 337 338 @Override 339 public String asciiHeader() { 340 return "ASCII" + HEADER; 341 } 342 343 @Override 344 public byte[] binaryHeader() { 345 return HEADER.getBytes(); 346 } 347 348 @Override 349 public void readASCII(Scanner in) throws IOException { 350 throw new UnsupportedOperationException("Not done!"); 351 } 352 353 @Override 354 public void readBinary(DataInput dis) throws IOException { 355 threshold = dis.readDouble(); 356 nDims = dis.readInt(); 357 final int nClusters = dis.readInt(); 358 assert (threshold > 0); 359 codebook = new ArrayList<int[]>(); 360 for (int i = 0; i < nClusters; i++) { 361 final byte[] wang = new byte[nDims]; 362 dis.readFully(wang, 0, nDims); 363 final int[] cluster = new int[nDims]; 364 for (int j = 0; j < nDims; j++) 365 cluster[j] = wang[j] & 0xFF; 366 codebook.add(cluster); 367 } 368 } 369 370 @Override 371 public void writeASCII(PrintWriter writer) throws IOException { 372 writer.format("%d\n", this.threshold); 373 writer.format("%d\n", this.nDims); 374 writer.format("%d\n", this.numClusters()); 375 for (final int[] a : this.codebook) { 376 writer.format("%d,", a); 377 } 378 } 379 380 @Override 381 public void writeBinary(DataOutput dos) throws IOException { 382 dos.writeDouble(this.threshold); 383 dos.writeInt(this.nDims); 384 dos.writeInt(this.numClusters()); 385 for (final int[] arr : this.codebook) { 386 for (final int a : arr) { 387 dos.write(a); 388 } 389 } 390 } 391 392 @Override 393 public int[][] getCentroids() { 394 return this.codebook.toArray(new int[0][]); 395 } 396 397 @Override 398 public void assignDistance(int[][] data, int[] indices, float[] distances) { 399 throw new UnsupportedOperationException("Not implemented"); 400 } 401 402 @Override 403 public IntFloatPair assignDistance(int[] data) { 404 throw new UnsupportedOperationException("Not implemented"); 405 } 406 407 @Override 408 public HardAssigner<int[], ?, ?> defaultHardAssigner() { 409 return this; 410 } 411 412 /** 413 * The number of centroids; this potentially grows as assignments are made. 414 * 415 * @see org.openimaj.ml.clustering.assignment.HardAssigner#size() 416 */ 417 @Override 418 public int size() { 419 return this.nDims; 420 } 421 422 @Override 423 public int[][] performClustering(int[][] data) { 424 return new IndexClusters(this.cluster(data).defaultHardAssigner().assign(data)).clusters(); 425 } 426}