001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.clustering.rac;
031
032import java.io.DataInput;
033import java.io.DataOutput;
034import java.io.IOException;
035import java.io.PrintWriter;
036import java.util.ArrayList;
037import java.util.List;
038import java.util.Scanner;
039
040import org.apache.commons.math.FunctionEvaluationException;
041import org.apache.commons.math.MaxIterationsExceededException;
042import org.apache.commons.math.analysis.UnivariateRealFunction;
043import org.apache.commons.math.analysis.solvers.BisectionSolver;
044import org.openimaj.citation.annotation.Reference;
045import org.openimaj.citation.annotation.ReferenceType;
046import org.openimaj.data.DataSource;
047import org.openimaj.data.RandomData;
048import org.openimaj.ml.clustering.CentroidsProvider;
049import org.openimaj.ml.clustering.IndexClusters;
050import org.openimaj.ml.clustering.SpatialClusterer;
051import org.openimaj.ml.clustering.SpatialClusters;
052import org.openimaj.ml.clustering.assignment.HardAssigner;
053import org.openimaj.util.pair.IntFloatPair;
054
055/**
056 * An implementation of the RAC algorithm proposed by <a
057 * href="http://eprints.ecs.soton.ac.uk/21401/">Ramanan and Niranjan</a>.
058 * <p>
059 * During training, data points are selected at random. The first data point is
060 * chosen as a centroid. Every following data point is set as a new centroid if
061 * it is outside the threshold of all current centroids. In this way it is
062 * difficult to guarantee number of clusters so a minimisation function is
063 * provided to allow a close estimate of the required threshold for a given K.
064 * <p>
065 * This implementation supports int[] cluster centroids.
066 * <p>
 * In terms of implementation, this class is both a clusterer, an assigner and
068 * the result of the clustering. This is because the RAC algorithm never ends;
069 * that is to say that if a new point is being assigned through the
070 * {@link HardAssigner} interface, and that point is more than the threshold
071 * distance from any other centroid, then a new centroid will be created for the
072 * point. If this behaviour is undesirable, the results of clustering can be
073 * "frozen" by manually constructing an assigner that takes a
074 * {@link CentroidsProvider} (or the centroids provided by calling
075 * {@link #getCentroids()}) as an argument.
076 * 
077 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
078 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
079 */
080@Reference(
081                type = ReferenceType.Inproceedings,
082                author = { "Amirthalingam Ramanan", "Mahesan Niranjan" },
083                title = "Resource-Allocating Codebook for Patch-based Face Recognition",
084                year = "2009",
085                booktitle = "IIS",
086                url = "http://eprints.ecs.soton.ac.uk/21401/")
087public class IntRAC
088                implements
089                SpatialClusters<int[]>,
090                SpatialClusterer<IntRAC, int[]>,
091                CentroidsProvider<int[]>,
092                HardAssigner<int[], float[], IntFloatPair>
093{
094        private static class ClusterMinimisationFunction implements UnivariateRealFunction {
095                private int[][] distances;
096                private int[][] samples;
097                private int nClusters;
098
099                public ClusterMinimisationFunction(int[][] samples, int[][] distances, int nClusters) {
100                        this.distances = distances;
101                        this.samples = samples;
102                        this.nClusters = nClusters;
103                }
104
105                @Override
106                public double value(double radius) throws FunctionEvaluationException {
107                        final IntRAC r = new IntRAC(radius);
108                        r.train(samples, distances);
109                        final int diff = this.nClusters - r.numClusters();
110                        return diff;
111                }
112        }
113
        // Identifier used to build both the ASCII and the binary file headers.
        private static final String HEADER = SpatialClusters.CLUSTER_HEADER + "RAIC";

        // The centroids found so far; entries are appended by train(..) and cluster(..).
        protected ArrayList<int[]> codebook;
        // A sample becomes a new centroid unless its squared Euclidean distance to
        // some existing centroid is strictly below this value.
        protected double threshold;
        // Length of the centroid vectors; -1 until set by train(..) or readBinary(..).
        protected int nDims;
        // Pairwise distance matrix used during threshold estimation. NOTE(review):
        // this is static, so it is shared by every IntRAC instance - confirm that
        // concurrent threshold estimation is never required.
        protected static int[][] distances;
        // Initialised to 0 by the default constructor; not updated anywhere in this class.
        protected long totalSamples;
121
122        /**
123         * Sets the threshold to 128
124         */
125        public IntRAC() {
126                codebook = new ArrayList<int[]>();
127                this.threshold = 128;
128                this.nDims = -1;
129                this.totalSamples = 0;
130        }
131
132        /**
133         * Define the threshold at which point a new cluster will be made.
134         * 
135         * @param radiusSquared
136         */
137        public IntRAC(double radiusSquared) {
138                this();
139                this.threshold = radiusSquared;
140        }
141
142        /**
143         * Iteratively select subSamples from bKeys and try to choose a threshold
144         * which results in nClusters. This is provided to estimate threshold as
145         * this is a very data dependant value. The threshold is found using a
146         * BisectionSolver with a complete distance matrix (so make sure subSamples
147         * is reasonable)
148         * 
149         * @param bKeys
150         *            All keys to be trained against
151         * @param subSamples
152         *            number of subsamples to select from bKeys each iteration
153         * @param nClusters
154         *            number of clusters to aim for
155         */
156        public IntRAC(int[][] bKeys, int subSamples, int nClusters) {
157                this();
158
159                distances = new int[subSamples][subSamples];
160                int j = 0;
161                this.threshold = 0;
162                final int thresholdIteration = 5;
163                while (j++ < thresholdIteration) {
164                        final int[][] randomList = new int[subSamples][];
165                        final int[] randomListIndex = RandomData.getUniqueRandomInts(subSamples, 0, bKeys.length);
166                        int ri = 0;
167                        for (int k = 0; k < randomListIndex.length; k++)
168                                randomList[ri++] = bKeys[randomListIndex[k]];
169                        try {
170                                this.threshold += calculateThreshold(randomList, nClusters);
171                        } catch (final Exception e) {
172                                this.threshold += 200000;
173                        }
174                        System.out.println("Current threshold: " + this.threshold / j);
175                }
176                this.threshold /= thresholdIteration;
177        }
178
179        @SuppressWarnings("deprecation")
180        protected static double calculateThreshold(int[][] samples, int nClusters) throws MaxIterationsExceededException,
181                        FunctionEvaluationException
182        {
183                int maxDistance = 0;
184                for (int i = 0; i < samples.length; i++) {
185                        for (int j = i + 1; j < samples.length; j++) {
186                                distances[i][j] = distanceEuclidianSquared(samples[i], samples[j]);
187                                distances[j][i] = distances[i][j];
188                                if (distances[i][j] > maxDistance)
189                                        maxDistance = distances[i][j];
190                        }
191                }
192                System.out.println("Distance matrix calculated");
193                final BisectionSolver b = new BisectionSolver();
194                b.setAbsoluteAccuracy(100.0);
195                return b.solve(100, new ClusterMinimisationFunction(samples, distances, nClusters), 0, maxDistance);
196        }
197
198        int train(int[][] samples, int[][] distances) {
199                int foundLength = -1;
200                final List<Integer> codebookIndex = new ArrayList<Integer>();
201                for (int i = 0; i < samples.length; i++) {
202                        final int[] entry = samples[i];
203                        if (foundLength == -1)
204                                foundLength = entry.length;
205
206                        // all the data entries must be the same length otherwise this
207                        // doesn't make sense
208                        if (foundLength != entry.length) {
209                                this.codebook = new ArrayList<int[]>();
210                                return -1;
211                        }
212                        boolean found = false;
213                        for (final int j : codebookIndex) {
214                                if (distances[i][j] < threshold) {
215                                        found = true;
216                                        break;
217                                }
218                        }
219                        if (!found) {
220                                this.codebook.add(entry);
221                                codebookIndex.add(i);
222                        }
223                }
224                this.nDims = foundLength;
225                return 0;
226        }
227
228        @Override
229        public IntRAC cluster(int[][] data) {
230                int foundLength = -1;
231
232                for (final int[] entry : data) {
233                        if (foundLength == -1)
234                                foundLength = entry.length;
235
236                        // all the data entries must be the same length otherwise this
237                        // doesn't make sense
238                        if (foundLength != entry.length) {
239                                this.codebook = new ArrayList<int[]>();
240                                throw new RuntimeException();
241                        }
242                        boolean found = false;
243                        for (final int[] existing : this.codebook) {
244                                if (distanceEuclidianSquared(entry, existing) < threshold) {
245                                        found = true;
246                                        break;
247                                }
248                        }
249                        if (!found) {
250                                this.codebook.add(entry);
251                                if (this.codebook.size() % 1000 == 0) {
252                                        System.out.println("Codebook increased to size " + this.codebook.size());
253                                }
254                        }
255                }
256
257                return this;
258        }
259
260        @Override
261        public IntRAC cluster(DataSource<int[]> data) {
262                final int[][] dataArr = new int[data.size()][data.numDimensions()];
263
264                return cluster(dataArr);
265        }
266
267        static int distanceEuclidianSquared(int[] a, int[] b) {
268                int sum = 0;
269                for (int i = 0; i < a.length; i++) {
270                        final int diff = a[i] - b[i];
271                        sum += diff * diff;
272                }
273                return sum;
274        }
275
276        static int distanceEuclidianSquared(int[] a, int[] b, int threshold2) {
277                int sum = 0;
278
279                for (int i = 0; i < a.length; i++) {
280                        final int diff = a[i] - b[i];
281                        sum += diff * diff;
282                        if (sum > threshold2)
283                                return threshold2;
284                }
285                return sum;
286        }
287
288        @Override
289        public int numClusters() {
290                return this.codebook.size();
291        }
292
293        @Override
294        public int numDimensions() {
295                return this.nDims;
296        }
297
298        @Override
299        public int[] assign(int[][] data) {
300                final int[] centroids = new int[data.length];
301                for (int i = 0; i < data.length; i++) {
302                        final int[] entry = data[i];
303                        centroids[i] = this.assign(entry);
304                }
305                return centroids;
306        }
307
308        @Override
309        public int assign(int[] data) {
310                int mindiff = -1;
311                int centroid = -1;
312
313                for (int i = 0; i < this.numClusters(); i++) {
314                        final int[] centroids = this.codebook.get(i);
315                        int sum = 0;
316                        boolean set = true;
317
318                        for (int j = 0; j < centroids.length; j++) {
319                                final int diff = centroids[j] - data[j];
320                                sum += diff * diff;
321                                if (mindiff != -1 && mindiff < sum) {
322                                        set = false;
323                                        break; // Stop checking the distance if you
324                                }
325                        }
326
327                        if (set) {
328                                mindiff = sum;
329                                centroid = i;
330                                // if(mindiff < this.threshold){
331                                // return centroid;
332                                // }
333                        }
334                }
335                return centroid;
336        }
337
338        @Override
339        public String asciiHeader() {
340                return "ASCII" + HEADER;
341        }
342
343        @Override
344        public byte[] binaryHeader() {
345                return HEADER.getBytes();
346        }
347
348        @Override
349        public void readASCII(Scanner in) throws IOException {
350                throw new UnsupportedOperationException("Not done!");
351        }
352
353        @Override
354        public void readBinary(DataInput dis) throws IOException {
355                threshold = dis.readDouble();
356                nDims = dis.readInt();
357                final int nClusters = dis.readInt();
358                assert (threshold > 0);
359                codebook = new ArrayList<int[]>();
360                for (int i = 0; i < nClusters; i++) {
361                        final byte[] wang = new byte[nDims];
362                        dis.readFully(wang, 0, nDims);
363                        final int[] cluster = new int[nDims];
364                        for (int j = 0; j < nDims; j++)
365                                cluster[j] = wang[j] & 0xFF;
366                        codebook.add(cluster);
367                }
368        }
369
370        @Override
371        public void writeASCII(PrintWriter writer) throws IOException {
372                writer.format("%d\n", this.threshold);
373                writer.format("%d\n", this.nDims);
374                writer.format("%d\n", this.numClusters());
375                for (final int[] a : this.codebook) {
376                        writer.format("%d,", a);
377                }
378        }
379
380        @Override
381        public void writeBinary(DataOutput dos) throws IOException {
382                dos.writeDouble(this.threshold);
383                dos.writeInt(this.nDims);
384                dos.writeInt(this.numClusters());
385                for (final int[] arr : this.codebook) {
386                        for (final int a : arr) {
387                                dos.write(a);
388                        }
389                }
390        }
391
392        @Override
393        public int[][] getCentroids() {
394                return this.codebook.toArray(new int[0][]);
395        }
396
397        @Override
398        public void assignDistance(int[][] data, int[] indices, float[] distances) {
399                throw new UnsupportedOperationException("Not implemented");
400        }
401
402        @Override
403        public IntFloatPair assignDistance(int[] data) {
404                throw new UnsupportedOperationException("Not implemented");
405        }
406
407        @Override
408        public HardAssigner<int[], ?, ?> defaultHardAssigner() {
409                return this;
410        }
411
412        /**
413         * The number of centroids; this potentially grows as assignments are made.
414         * 
415         * @see org.openimaj.ml.clustering.assignment.HardAssigner#size()
416         */
417        @Override
418        public int size() {
419                return this.nDims;
420        }
421
422        @Override
423        public int[][] performClustering(int[][] data) {
424                return new IndexClusters(this.cluster(data).defaultHardAssigner().assign(data)).clusters();
425        }
426}