View Javadoc

1   /**
2    * Copyright (c) 2011, The University of Southampton and the individual contributors.
3    * All rights reserved.
4    *
5    * Redistribution and use in source and binary forms, with or without modification,
6    * are permitted provided that the following conditions are met:
7    *
8    *   * 	Redistributions of source code must retain the above copyright notice,
9    * 	this list of conditions and the following disclaimer.
10   *
11   *   *	Redistributions in binary form must reproduce the above copyright notice,
12   * 	this list of conditions and the following disclaimer in the documentation
13   * 	and/or other materials provided with the distribution.
14   *
15   *   *	Neither the name of the University of Southampton nor the names of its
16   * 	contributors may be used to endorse or promote products derived from this
17   * 	software without specific prior written permission.
18   *
19   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21   * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22   * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
23   * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24   * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26   * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27   * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28   * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29   */
30  package org.openimaj.ml.clustering.rac;
31  
32  import java.io.DataInput;
33  import java.io.DataOutput;
34  import java.io.IOException;
35  import java.io.PrintWriter;
36  import java.util.ArrayList;
37  import java.util.List;
38  import java.util.Scanner;
39  
40  import org.apache.commons.math.FunctionEvaluationException;
41  import org.apache.commons.math.MaxIterationsExceededException;
42  import org.apache.commons.math.analysis.UnivariateRealFunction;
43  import org.apache.commons.math.analysis.solvers.BisectionSolver;
44  import org.openimaj.citation.annotation.Reference;
45  import org.openimaj.citation.annotation.ReferenceType;
46  import org.openimaj.data.DataSource;
47  import org.openimaj.data.RandomData;
48  import org.openimaj.ml.clustering.CentroidsProvider;
49  import org.openimaj.ml.clustering.IndexClusters;
50  import org.openimaj.ml.clustering.SpatialClusterer;
51  import org.openimaj.ml.clustering.SpatialClusters;
52  import org.openimaj.ml.clustering.assignment.HardAssigner;
53  import org.openimaj.util.pair.IntFloatPair;
54  
55  /**
56   * An implementation of the RAC algorithm proposed by <a
57   * href="http://eprints.ecs.soton.ac.uk/21401/">Ramanan and Niranjan</a>.
58   * <p>
59   * During training, data points are selected at random. The first data point is
60   * chosen as a centroid. Every following data point is set as a new centroid if
61   * it is outside the threshold of all current centroids. In this way it is
62   * difficult to guarantee the number of clusters, so a minimisation function is
63   * provided to allow a close estimate of the required threshold for a given K.
64   * <p>
65   * This implementation supports int[] cluster centroids.
66   * <p>
67   * In terms of implementation, this class is both a clusterer, an assigner and
68   * the result of the clustering. This is because the RAC algorithm never ends;
69   * that is to say that if a new point is being assigned through the
70   * {@link HardAssigner} interface, and that point is more than the threshold
71   * distance from any other centroid, then a new centroid will be created for the
72   * point. If this behaviour is undesirable, the results of clustering can be
73   * "frozen" by manually constructing an assigner that takes a
74   * {@link CentroidsProvider} (or the centroids provided by calling
75   * {@link #getCentroids()}) as an argument.
76   * 
77   * @author Sina Samangooei (ss@ecs.soton.ac.uk)
78   * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
79   */
80  @Reference(
81  		type = ReferenceType.Inproceedings,
82  		author = { "Amirthalingam Ramanan", "Mahesan Niranjan" },
83  		title = "Resource-Allocating Codebook for Patch-based Face Recognition",
84  		year = "2009",
85  		booktitle = "IIS",
86  		url = "http://eprints.ecs.soton.ac.uk/21401/")
87  public class IntRAC
88  		implements
89  		SpatialClusters<int[]>,
90  		SpatialClusterer<IntRAC, int[]>,
91  		CentroidsProvider<int[]>,
92  		HardAssigner<int[], float[], IntFloatPair>
93  {
94  	private static class ClusterMinimisationFunction implements UnivariateRealFunction {
95  		private int[][] distances;
96  		private int[][] samples;
97  		private int nClusters;
98  
99  		public ClusterMinimisationFunction(int[][] samples, int[][] distances, int nClusters) {
100 			this.distances = distances;
101 			this.samples = samples;
102 			this.nClusters = nClusters;
103 		}
104 
105 		@Override
106 		public double value(double radius) throws FunctionEvaluationException {
107 			final IntRAC r = new IntRAC(radius);
108 			r.train(samples, distances);
109 			final int diff = this.nClusters - r.numClusters();
110 			return diff;
111 		}
112 	}
113 
	/** Header for serialised instances: the shared cluster header followed by "RAIC". */
	private static final String HEADER = SpatialClusters.CLUSTER_HEADER + "RAIC";

	/** The centroids discovered so far, in order of creation. */
	protected ArrayList<int[]> codebook;
	/**
	 * A point closer than this to an existing centroid does not create a new
	 * one. It is compared against squared Euclidean distances.
	 */
	protected double threshold;
	/** Length of the vectors; -1 until it has been learnt or read. */
	protected int nDims;
	/**
	 * Pairwise distance matrix used while estimating a threshold. NOTE(review):
	 * this is static, so it is shared by every instance of the class.
	 */
	protected static int[][] distances;
	/** Initialised to 0 by the constructor; not otherwise used in this class. */
	protected long totalSamples;
121 
122 	/**
123 	 * Sets the threshold to 128
124 	 */
125 	public IntRAC() {
126 		codebook = new ArrayList<int[]>();
127 		this.threshold = 128;
128 		this.nDims = -1;
129 		this.totalSamples = 0;
130 	}
131 
132 	/**
133 	 * Define the threshold at which point a new cluster will be made.
134 	 * 
135 	 * @param radiusSquared
136 	 */
137 	public IntRAC(double radiusSquared) {
138 		this();
139 		this.threshold = radiusSquared;
140 	}
141 
142 	/**
143 	 * Iteratively select subSamples from bKeys and try to choose a threshold
144 	 * which results in nClusters. This is provided to estimate threshold as
145 	 * this is a very data dependant value. The threshold is found using a
146 	 * BisectionSolver with a complete distance matrix (so make sure subSamples
147 	 * is reasonable)
148 	 * 
149 	 * @param bKeys
150 	 *            All keys to be trained against
151 	 * @param subSamples
152 	 *            number of subsamples to select from bKeys each iteration
153 	 * @param nClusters
154 	 *            number of clusters to aim for
155 	 */
156 	public IntRAC(int[][] bKeys, int subSamples, int nClusters) {
157 		this();
158 
159 		distances = new int[subSamples][subSamples];
160 		int j = 0;
161 		this.threshold = 0;
162 		final int thresholdIteration = 5;
163 		while (j++ < thresholdIteration) {
164 			final int[][] randomList = new int[subSamples][];
165 			final int[] randomListIndex = RandomData.getUniqueRandomInts(subSamples, 0, bKeys.length);
166 			int ri = 0;
167 			for (int k = 0; k < randomListIndex.length; k++)
168 				randomList[ri++] = bKeys[randomListIndex[k]];
169 			try {
170 				this.threshold += calculateThreshold(randomList, nClusters);
171 			} catch (final Exception e) {
172 				this.threshold += 200000;
173 			}
174 			System.out.println("Current threshold: " + this.threshold / j);
175 		}
176 		this.threshold /= thresholdIteration;
177 	}
178 
179 	@SuppressWarnings("deprecation")
180 	protected static double calculateThreshold(int[][] samples, int nClusters) throws MaxIterationsExceededException,
181 			FunctionEvaluationException
182 	{
183 		int maxDistance = 0;
184 		for (int i = 0; i < samples.length; i++) {
185 			for (int j = i + 1; j < samples.length; j++) {
186 				distances[i][j] = distanceEuclidianSquared(samples[i], samples[j]);
187 				distances[j][i] = distances[i][j];
188 				if (distances[i][j] > maxDistance)
189 					maxDistance = distances[i][j];
190 			}
191 		}
192 		System.out.println("Distance matrix calculated");
193 		final BisectionSolver b = new BisectionSolver();
194 		b.setAbsoluteAccuracy(100.0);
195 		return b.solve(100, new ClusterMinimisationFunction(samples, distances, nClusters), 0, maxDistance);
196 	}
197 
198 	int train(int[][] samples, int[][] distances) {
199 		int foundLength = -1;
200 		final List<Integer> codebookIndex = new ArrayList<Integer>();
201 		for (int i = 0; i < samples.length; i++) {
202 			final int[] entry = samples[i];
203 			if (foundLength == -1)
204 				foundLength = entry.length;
205 
206 			// all the data entries must be the same length otherwise this
207 			// doesn't make sense
208 			if (foundLength != entry.length) {
209 				this.codebook = new ArrayList<int[]>();
210 				return -1;
211 			}
212 			boolean found = false;
213 			for (final int j : codebookIndex) {
214 				if (distances[i][j] < threshold) {
215 					found = true;
216 					break;
217 				}
218 			}
219 			if (!found) {
220 				this.codebook.add(entry);
221 				codebookIndex.add(i);
222 			}
223 		}
224 		this.nDims = foundLength;
225 		return 0;
226 	}
227 
228 	@Override
229 	public IntRAC cluster(int[][] data) {
230 		int foundLength = -1;
231 
232 		for (final int[] entry : data) {
233 			if (foundLength == -1)
234 				foundLength = entry.length;
235 
236 			// all the data entries must be the same length otherwise this
237 			// doesn't make sense
238 			if (foundLength != entry.length) {
239 				this.codebook = new ArrayList<int[]>();
240 				throw new RuntimeException();
241 			}
242 			boolean found = false;
243 			for (final int[] existing : this.codebook) {
244 				if (distanceEuclidianSquared(entry, existing) < threshold) {
245 					found = true;
246 					break;
247 				}
248 			}
249 			if (!found) {
250 				this.codebook.add(entry);
251 				if (this.codebook.size() % 1000 == 0) {
252 					System.out.println("Codebook increased to size " + this.codebook.size());
253 				}
254 			}
255 		}
256 
257 		return this;
258 	}
259 
260 	@Override
261 	public IntRAC cluster(DataSource<int[]> data) {
262 		final int[][] dataArr = new int[data.size()][data.numDimensions()];
263 
264 		return cluster(dataArr);
265 	}
266 
267 	static int distanceEuclidianSquared(int[] a, int[] b) {
268 		int sum = 0;
269 		for (int i = 0; i < a.length; i++) {
270 			final int diff = a[i] - b[i];
271 			sum += diff * diff;
272 		}
273 		return sum;
274 	}
275 
276 	static int distanceEuclidianSquared(int[] a, int[] b, int threshold2) {
277 		int sum = 0;
278 
279 		for (int i = 0; i < a.length; i++) {
280 			final int diff = a[i] - b[i];
281 			sum += diff * diff;
282 			if (sum > threshold2)
283 				return threshold2;
284 		}
285 		return sum;
286 	}
287 
288 	@Override
289 	public int numClusters() {
290 		return this.codebook.size();
291 	}
292 
293 	@Override
294 	public int numDimensions() {
295 		return this.nDims;
296 	}
297 
298 	@Override
299 	public int[] assign(int[][] data) {
300 		final int[] centroids = new int[data.length];
301 		for (int i = 0; i < data.length; i++) {
302 			final int[] entry = data[i];
303 			centroids[i] = this.assign(entry);
304 		}
305 		return centroids;
306 	}
307 
308 	@Override
309 	public int assign(int[] data) {
310 		int mindiff = -1;
311 		int centroid = -1;
312 
313 		for (int i = 0; i < this.numClusters(); i++) {
314 			final int[] centroids = this.codebook.get(i);
315 			int sum = 0;
316 			boolean set = true;
317 
318 			for (int j = 0; j < centroids.length; j++) {
319 				final int diff = centroids[j] - data[j];
320 				sum += diff * diff;
321 				if (mindiff != -1 && mindiff < sum) {
322 					set = false;
323 					break; // Stop checking the distance if you
324 				}
325 			}
326 
327 			if (set) {
328 				mindiff = sum;
329 				centroid = i;
330 				// if(mindiff < this.threshold){
331 				// return centroid;
332 				// }
333 			}
334 		}
335 		return centroid;
336 	}
337 
338 	@Override
339 	public String asciiHeader() {
340 		return "ASCII" + HEADER;
341 	}
342 
343 	@Override
344 	public byte[] binaryHeader() {
345 		return HEADER.getBytes();
346 	}
347 
348 	@Override
349 	public void readASCII(Scanner in) throws IOException {
350 		throw new UnsupportedOperationException("Not done!");
351 	}
352 
353 	@Override
354 	public void readBinary(DataInput dis) throws IOException {
355 		threshold = dis.readDouble();
356 		nDims = dis.readInt();
357 		final int nClusters = dis.readInt();
358 		assert (threshold > 0);
359 		codebook = new ArrayList<int[]>();
360 		for (int i = 0; i < nClusters; i++) {
361 			final byte[] wang = new byte[nDims];
362 			dis.readFully(wang, 0, nDims);
363 			final int[] cluster = new int[nDims];
364 			for (int j = 0; j < nDims; j++)
365 				cluster[j] = wang[j] & 0xFF;
366 			codebook.add(cluster);
367 		}
368 	}
369 
370 	@Override
371 	public void writeASCII(PrintWriter writer) throws IOException {
372 		writer.format("%d\n", this.threshold);
373 		writer.format("%d\n", this.nDims);
374 		writer.format("%d\n", this.numClusters());
375 		for (final int[] a : this.codebook) {
376 			writer.format("%d,", a);
377 		}
378 	}
379 
380 	@Override
381 	public void writeBinary(DataOutput dos) throws IOException {
382 		dos.writeDouble(this.threshold);
383 		dos.writeInt(this.nDims);
384 		dos.writeInt(this.numClusters());
385 		for (final int[] arr : this.codebook) {
386 			for (final int a : arr) {
387 				dos.write(a);
388 			}
389 		}
390 	}
391 
392 	@Override
393 	public int[][] getCentroids() {
394 		return this.codebook.toArray(new int[0][]);
395 	}
396 
397 	@Override
398 	public void assignDistance(int[][] data, int[] indices, float[] distances) {
399 		throw new UnsupportedOperationException("Not implemented");
400 	}
401 
402 	@Override
403 	public IntFloatPair assignDistance(int[] data) {
404 		throw new UnsupportedOperationException("Not implemented");
405 	}
406 
407 	@Override
408 	public HardAssigner<int[], ?, ?> defaultHardAssigner() {
409 		return this;
410 	}
411 
412 	/**
413 	 * The number of centroids; this potentially grows as assignments are made.
414 	 * 
415 	 * @see org.openimaj.ml.clustering.assignment.HardAssigner#size()
416 	 */
417 	@Override
418 	public int size() {
419 		return this.nDims;
420 	}
421 
422 	@Override
423 	public int[][] performClustering(int[][] data) {
424 		return new IndexClusters(this.cluster(data).defaultHardAssigner().assign(data)).clusters();
425 	}
426 }