001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.clustering.kmeans;
031
032import gnu.trove.list.array.TIntArrayList;
033import gnu.trove.procedure.TIntProcedure;
034
035import java.io.DataInput;
036import java.io.DataOutput;
037import java.io.IOException;
038import java.io.PrintWriter;
039import java.util.Arrays;
040import java.util.Scanner;
041
042import org.openimaj.feature.DoubleFVComparison;
043import org.openimaj.ml.clustering.CentroidsProvider;
044import org.openimaj.ml.clustering.Clusters;
045import org.openimaj.ml.clustering.SpatialClusterer;
046import org.openimaj.ml.clustering.SpatialClusters;
047import org.openimaj.ml.clustering.assignment.HardAssigner;
048import org.openimaj.ml.clustering.assignment.hard.ExactDoubleAssigner;
049import org.openimaj.util.pair.IntDoublePair;
050
051/**
052 * The result of a {@link SpatialClusterer} that just produces a flat set of
053 * centroids.
054 *
055 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
056 */
057public class SphericalKMeansResult implements SpatialClusters<double[]>, CentroidsProvider<double[]> {
058        final static String HEADER = Clusters.CLUSTER_HEADER + "SpKM";
059
060        /** The centroids of the clusters */
061        public double[][] centroids;
062
063        /** The assignments of the training data to clusters */
064        public int[] assignments;
065
066        @Override
067        public boolean equals(Object obj) {
068                if (!(obj instanceof SphericalKMeansResult))
069                        return false;
070
071                final SphericalKMeansResult other = (SphericalKMeansResult) obj;
072                for (int i = 0; i < this.centroids.length; i++) {
073                        if (!Arrays.equals(this.centroids[i], other.centroids[i]))
074                                return false;
075                }
076                return true;
077        }
078
079        @Override
080        public String asciiHeader() {
081                return "ASCII" + HEADER;
082        }
083
084        @Override
085        public byte[] binaryHeader() {
086                return HEADER.getBytes();
087        }
088
089        @Override
090        public void readASCII(Scanner br) throws IOException {
091                // Read Header
092                final int K = Integer.parseInt(br.nextLine().trim());
093                final int M = Integer.parseInt(br.nextLine().trim());
094
095                centroids = new double[K][M];
096                for (int k = 0; k < K; k++) {
097                        final String[] parts = br.nextLine().split(",");
098
099                        for (int d = 0; d < M; d++) {
100                                centroids[k][d] = Double.parseDouble(parts[d]);
101                        }
102                }
103
104                final int A = Integer.parseInt(br.nextLine().trim());
105                assignments = new int[A];
106                for (int a = 0; a < A; a++) {
107                        assignments[a] = Integer.parseInt(br.nextLine().trim());
108                }
109        }
110
111        @Override
112        public void readBinary(DataInput in) throws IOException {
113                final int K = in.readInt();
114                final int M = in.readInt();
115
116                centroids = new double[K][M];
117
118                for (int k = 0; k < K; k++) {
119                        for (int d = 0; d < M; d++) {
120                                centroids[k][d] = in.readDouble();
121                        }
122                }
123
124                final int A = in.readInt();
125                assignments = new int[A];
126                for (int a = 0; a < A; a++) {
127                        assignments[a] = in.readInt();
128                }
129        }
130
131        @Override
132        public void writeASCII(PrintWriter writer) throws IOException {
133                writer.println(centroids.length);
134                writer.println(centroids[0].length);
135
136                for (int k = 0; k < centroids.length; k++) {
137                        for (int d = 0; d < centroids[0].length; d++) {
138                                writer.print(centroids[k][d] + ",");
139                        }
140                        writer.println();
141                }
142
143                writer.println(assignments.length);
144                for (int a = 0; a < assignments.length; a++) {
145                        writer.println(assignments[a]);
146                }
147        }
148
149        @Override
150        public void writeBinary(DataOutput out) throws IOException {
151                out.writeInt(centroids.length);
152                out.writeInt(centroids[0].length);
153
154                for (int k = 0; k < centroids.length; k++) {
155                        for (int d = 0; d < centroids[0].length; d++) {
156                                out.writeDouble(centroids[k][d]);
157                        }
158                }
159
160                out.writeInt(assignments.length);
161                for (int a = 0; a < assignments.length; a++) {
162                        out.writeInt(assignments[a]);
163                }
164        }
165
166        @Override
167        public String toString() {
168                String str = "";
169                str += "DoubleCentroidsResult" + "\n";
170                str += "No. of Clusters: " + centroids.length + "\n";
171                str += "No. of Dimensions: " + centroids[0].length + "\n";
172                return str;
173        }
174
175        @Override
176        public double[][] getCentroids() {
177                return this.centroids;
178        }
179
180        @Override
181        public HardAssigner<double[], double[], IntDoublePair> defaultHardAssigner() {
182                return new ExactDoubleAssigner(this, DoubleFVComparison.INNER_PRODUCT);
183        }
184
185        @Override
186        public int numDimensions() {
187                return centroids[0].length;
188        }
189
190        @Override
191        public int numClusters() {
192                return centroids.length;
193        }
194
195        /**
196         * Compute the histogram of number of assignments to each cluster
197         *
198         * @return the histogram
199         */
200        public int[] getAssignmentHistogram() {
201                final int[] hist = new int[centroids.length];
202
203                for (int i = 0; i < assignments.length; i++) {
204                        hist[assignments[i]]++;
205                }
206
207                return hist;
208        }
209
210        /**
211         * Filter the cluster centroids be removing those with less than threshold
212         * items
213         *
214         * @param threshold
215         *            minimum number of items
216         * @return the filtered clusters
217         */
218        public double[][] filter(int threshold) {
219                final int[] hist = getAssignmentHistogram();
220                final TIntArrayList toKeep = new TIntArrayList();
221                for (int i = 0; i < hist.length; i++) {
222                        if (hist[i] > threshold) {
223                                toKeep.add(i);
224                        }
225                }
226
227                final double[][] fcen = new double[toKeep.size()][];
228                toKeep.forEach(new TIntProcedure() {
229                        int i = 0;
230
231                        @Override
232                        public boolean execute(int value) {
233                                fcen[i++] = centroids[value];
234                                return true;
235                        }
236                });
237
238                return fcen;
239        }
240}