/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   *  Redistributions of source code must retain the above copyright notice,
 *      this list of conditions and the following disclaimer.
 *
 *   *  Redistributions in binary form must reproduce the above copyright notice,
 *      this list of conditions and the following disclaimer in the documentation
 *      and/or other materials provided with the distribution.
 *
 *   *  Neither the name of the University of Southampton nor the names of its
 *      contributors may be used to endorse or promote products derived from this
 *      software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.hadoop.tools.fastkmeans;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.util.Random;

import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.openimaj.hadoop.sequencefile.ExtractionState;
import org.openimaj.hadoop.sequencefile.KeyValueDump;
import org.openimaj.hadoop.sequencefile.NamingStrategy;
import org.openimaj.hadoop.sequencefile.SequenceFileUtility;
import org.openimaj.io.IOUtils;
import org.openimaj.ml.clustering.ByteCentroidsResult;
import org.openimaj.ml.clustering.assignment.HardAssigner;
import org.openimaj.ml.clustering.assignment.hard.ExactByteAssigner;
import org.openimaj.ml.clustering.assignment.hard.KDTreeByteEuclideanAssigner;
import org.openimaj.ml.clustering.kmeans.ByteKMeans;
import org.openimaj.util.pair.IntFloatPair;

/**
 * Approximate k-means MapReduce implementation.
 *
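 * <p>
 * A driver for a single iteration might be wired up roughly as in the sketch
 * below; the job name, HDFS paths and configuration values are illustrative
 * assumptions rather than anything defined by this class:
 *
 * <pre>
 * {@code
 * // illustrative paths and values; substitute real HDFS locations
 * Configuration conf = new Configuration();
 * conf.setStrings(AKMeans.CENTROIDS_PATH, "hdfs:///tmp/akmeans/centroids");
 * conf.setStrings(AKMeans.CENTROIDS_K, "1000");
 * conf.setStrings(AKMeans.CENTROIDS_EXACT, "false");
 *
 * Job job = new Job(conf, "akmeans-iteration");
 * job.setJarByClass(AKMeans.class);
 * job.setMapperClass(AKMeans.Map.class);
 * job.setCombinerClass(AKMeans.Combine.class);
 * job.setReducerClass(AKMeans.Reduce.class);
 * job.setOutputKeyClass(IntWritable.class);
 * job.setOutputValueClass(BytesWritable.class);
 * job.setInputFormatClass(SequenceFileInputFormat.class);
 * job.setOutputFormatClass(SequenceFileOutputFormat.class);
 * SequenceFileInputFormat.addInputPath(job, new Path("hdfs:///tmp/akmeans/features"));
 * SequenceFileOutputFormat.setOutputPath(job, new Path("hdfs:///tmp/akmeans/newcentroids"));
 * }
 * </pre>
 *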
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 *
 */
public class AKMeans {
    /**
     * Config option for the path of the current centroids
     */
    public static final String CENTROIDS_PATH = "uk.ac.soton.ecs.jsh2.clusterquantiser.CentroidsPath";

    /**
     * Config option for the number of centroids K
     */
    public static final String CENTROIDS_K = "uk.ac.soton.ecs.jsh2.clusterquantiser.CentroidsK";

    /**
     * Config option for whether exact assignment mode is used
     */
    public static final String CENTROIDS_EXACT = "uk.ac.soton.ecs.jsh2.clusterquantiser.CentroidsExact";

    private static final String CENTROIDS_FALLBACK_CHANCE = "uk.ac.soton.ecs.jsh2.clusterquantiser.FallbackChance";

    /**
     * The map for approximate k-means. Uses {@link ByteKMeans} under the hood.
     * Each feature is assigned to its nearest centroid and emitted with the
     * centroid index as the key.
     *
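     * <p>
     * The mapper expects {@link #CENTROIDS_PATH}, {@link #CENTROIDS_K} and
     * {@link #CENTROIDS_EXACT} to be set on the job configuration before it
     * runs; a minimal sketch (the path and values are illustrative
     * assumptions):
     *
     * <pre>
     * {@code
     * // illustrative values only
     * conf.setStrings(AKMeans.CENTROIDS_PATH, "hdfs:///tmp/akmeans/centroids");
     * conf.setStrings(AKMeans.CENTROIDS_K, "1000");
     * conf.setStrings(AKMeans.CENTROIDS_EXACT, "true");
     * }
     * </pre>
     *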
     * @author Sina Samangooei (ss@ecs.soton.ac.uk)
     *
     */
    public static class Map extends Mapper<Text, BytesWritable, IntWritable, BytesWritable> {
        private static Path centroidsPath = null;
        private static int k = -1;
        private static HardAssigner<byte[], float[], IntFloatPair> assigner = null;
        private static double randomFallbackChance;
        private static boolean exact;

        @Override
        protected void setup(Mapper<Text, BytesWritable, IntWritable, BytesWritable>.Context context)
                throws IOException, InterruptedException
        {
            loadCluster(context);
        }

        protected static synchronized void loadCluster(
                Mapper<Text, BytesWritable, IntWritable, BytesWritable>.Context context) throws IOException
        {
            final Path newPath = new Path(context.getConfiguration().getStrings(CENTROIDS_PATH)[0]);
            final boolean current = centroidsPath != null && centroidsPath.toString().equals(newPath.toString());
            if (!current) {
                k = Integer.parseInt(context.getConfiguration().getStrings(CENTROIDS_K)[0]);
                exact = Boolean.parseBoolean(context.getConfiguration().getStrings(CENTROIDS_EXACT)[0]);
                System.out.println("This is exact mode: " + exact);
                randomFallbackChance = 0.01;
                if (context.getConfiguration().getStrings(CENTROIDS_FALLBACK_CHANCE) != null) {
                    randomFallbackChance = Double.parseDouble(context.getConfiguration().getStrings(
                            CENTROIDS_FALLBACK_CHANCE)[0]);
                }

                // (Re)load the centroids and build the appropriate assigner
                centroidsPath = newPath;
                System.out.println("Loading centroids from: " + centroidsPath);
                final URI uri = centroidsPath.toUri();
                final FileSystem fs = HadoopFastKMeansOptions.getFileSystem(uri);
                final InputStream is = fs.open(centroidsPath);
                final ByteCentroidsResult centroids = IOUtils.read(is, ByteCentroidsResult.class);

                if (exact)
                    assigner = new ExactByteAssigner(centroids);
                else
                    assigner = new KDTreeByteEuclideanAssigner(centroids);
            } else {
                // Centroids are unchanged; no need to reload the assigner
            }
        }

        @Override
        public void map(Text key, BytesWritable value, Context context) throws IOException, InterruptedException {
            // Copy only the valid portion of the backing array
            final byte[] values = value.getBytes();
            final byte[] points = new byte[value.getLength()];
            System.arraycopy(values, 0, points, 0, points.length);

            final int cluster = assigner.assign(points);

            context.write(new IntWritable(cluster), new BytesWritable(points));

            // With a small probability, also emit the feature under the overflow
            // key (k + 1) so it can later be used as a random replacement centroid
            if (new Random().nextDouble() < randomFallbackChance) {
                context.write(new IntWritable(k + 1), new BytesWritable(points));
            }
        }
    }

    /**
     * Accumulate a single 128-dimensional byte feature into the running sum.
     */
    private static int accumulateFromFeature(int[] sum, byte[] assigned) throws IOException {
        if (assigned.length != sum.length)
            throw new IOException("Inconsistency in sum and feature length");
        for (int i = 0; i < sum.length; i++) {
            sum[i] += assigned[i];
        }
        return 1;
    }

    /**
     * Accumulate a partial sum produced by {@link Combine} (an int count
     * followed by 128 int component sums) into the running sum.
     */
    private static int accumulateFromSum(int[] sum, byte[] assigned) throws IOException {
        final int flen = (assigned.length / 4) - 1;
        final DataInputStream dis = new DataInputStream(new ByteArrayInputStream(assigned));
        if (flen != sum.length)
            throw new IOException("Inconsistency in sum and feature length");
        final int totalAssigned = dis.readInt();
        for (int i = 0; i < sum.length; i++) {
            sum[i] += dis.readInt();
        }
        return totalAssigned;
    }

    /**
     * For efficiency, combine the features assigned to each centroid early,
     * emitting the per-centroid component sums together with the number of
     * features combined.
     *
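     * <p>
     * Each combined value is laid out as a single int holding the number of
     * features combined, followed by 128 ints holding the component sums. A
     * sketch of how such a record can be decoded, given the raw value bytes
     * (mirroring {@code accumulateFromSum}):
     *
     * <pre>
     * {@code
     * // bytes is assumed to hold the combined record
     * DataInputStream dis = new DataInputStream(new ByteArrayInputStream(bytes));
     * int count = dis.readInt();
     * int[] componentSums = new int[128];
     * for (int i = 0; i < 128; i++)
     *     componentSums[i] = dis.readInt();
     * }
     * </pre>
     *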
     * @author Sina Samangooei (ss@ecs.soton.ac.uk)
     *
     */
    public static class Combine extends Reducer<IntWritable, BytesWritable, IntWritable, BytesWritable> {
        private int k;

        @Override
        public void setup(Context context) throws IOException, InterruptedException {
            k = Integer.parseInt(context.getConfiguration().getStrings(CENTROIDS_K)[0]);
        }

        @Override
        public void reduce(IntWritable key, Iterable<BytesWritable> values, Context context)
                throws IOException, InterruptedException
        {
            final int[] sum = new int[128];
            int totalAssigned = 0;
            for (final BytesWritable val : values) {
                // Copy the important part of the array
                final byte[] assigned = new byte[val.getLength()];
                System.arraycopy(val.getBytes(), 0, assigned, 0, assigned.length);
                // Pass the random fallback emissions straight through
                if (key.get() > k) {
                    context.write(key, new BytesWritable(assigned));
                    continue;
                }
                int added = 0;
                // Accumulate either from a raw feature or from an existing sum of
                // features
                if (assigned.length == 128)
                    added += accumulateFromFeature(sum, assigned);
                else
                    added += accumulateFromSum(sum, assigned);
                totalAssigned += added;
            }
            if (key.get() > k)
                return;
            // Write the accumulated sums and the current count
            final ByteArrayOutputStream bos = new ByteArrayOutputStream();
            final DataOutputStream dos = new DataOutputStream(bos);
            dos.writeInt(totalAssigned);
            for (final int i : sum) {
                dos.writeInt(i);
            }
            context.write(key, new BytesWritable(bos.toByteArray()));
        }
    }

    /**
     * The AKMeans reducer. Averages the combined features assigned to each
     * centroid and emits the new centroids. Centroids to which no features
     * were assigned may produce no output.
     *
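     * <p>
     * The output is a sequence file mapping centroid index ({@link IntWritable})
     * to a 128-byte averaged centroid ({@link BytesWritable}). A sketch of how
     * it might be turned back into a {@link ByteCentroidsResult} (the paths and
     * options instance are illustrative assumptions):
     *
     * <pre>
     * {@code
     * // outputPath, selectedPath and options are assumed to be set up elsewhere
     * ByteCentroidsResult newCentroids = AKMeans.completeCentroids(outputPath, selectedPath, options);
     * }
     * </pre>
     *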
     * @author Sina Samangooei (ss@ecs.soton.ac.uk)
     *
     */
    public static class Reduce extends Reducer<IntWritable, BytesWritable, IntWritable, BytesWritable> {
        private int k;

        @Override
        public void setup(Context context) throws IOException, InterruptedException {
            k = Integer.parseInt(context.getConfiguration().getStrings(CENTROIDS_K)[0]);
        }

        @Override
        public void reduce(IntWritable key, Iterable<BytesWritable> values, Context context)
                throws IOException, InterruptedException
        {
            final int[] sum = new int[128];
            final byte[] out = new byte[128];

            int totalAssigned = 0;
            for (final BytesWritable val : values) {
                // Copy the important part of the array
                final byte[] assigned = new byte[val.getLength()];

                System.arraycopy(val.getBytes(), 0, assigned, 0, assigned.length);

                // Pass the random fallback emissions straight through
                if (key.get() > k) {
                    context.write(key, new BytesWritable(assigned));
                    continue;
                }

                int added = 0;
                // Accumulate either from a raw feature or from a combined sum
                if (assigned.length == 128)
                    added += accumulateFromFeature(sum, assigned);
                else
                    added += accumulateFromSum(sum, assigned);

                totalAssigned += added;
            }

            if (key.get() > k)
                return;

            // Average the accumulated sums to produce the new centroid
            for (int i = 0; i < sum.length; i++) {
                out[i] = (byte) (sum[i] / totalAssigned);
            }

            context.write(key, new BytesWritable(out));
        }
    }

    static class SelectTopKDump extends KeyValueDump<IntWritable, BytesWritable> {
        int index = 0;
        int randomGens = 0;
        byte[][] centroids;

        SelectTopKDump(int k) {
            centroids = new byte[k][];
        }

        @Override
        public void dumpValue(IntWritable key, BytesWritable val) {
            if (index >= centroids.length)
                return;
            final byte[] bytes = new byte[val.getLength()];
            System.arraycopy(val.getBytes(), 0, bytes, 0, bytes.length);
            centroids[index] = bytes;
            index++;
            // Track how many entries came from the random fallback key (k + 1)
            if (key.get() == centroids.length + 1) {
                randomGens++;
            }
        }

    }

    /**
     * Given the location of a binary dump of centroids on the HDFS, load the
     * binary dump and construct a proper {@link ByteCentroidsResult} instance
     *
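     * <p>
     * A minimal usage sketch; the paths are illustrative assumptions and
     * {@code options} is assumed to be an already configured
     * {@link HadoopFastKMeansOptions}:
     *
     * <pre>
     * {@code
     * // illustrative paths only
     * ByteCentroidsResult result = AKMeans.completeCentroids(
     *         "hdfs:///tmp/akmeans/centroids", "hdfs:///tmp/akmeans/selected", options);
     * }
     * </pre>
     *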
     * @param centroids the path of the sequence file holding the emitted centroids
     * @param selected the path of the features from which random replacements may be drawn
     * @param options the Hadoop fast k-means options
     * @return a {@link ByteCentroidsResult} for the centroids on the HDFS
     * @throws Exception
     */
    public static ByteCentroidsResult completeCentroids(String centroids, String selected,
            HadoopFastKMeansOptions options) throws Exception
    {
        System.out.println("Attempting to complete");
        final Path centroidsPath = new Path(centroids);
        SequenceFileUtility<IntWritable, BytesWritable> utility = new IntBytesSequenceMemoryUtility(
                centroidsPath.toUri(), true);
        final SelectTopKDump dump = new SelectTopKDump(options.k);
        utility.exportData(NamingStrategy.KEY, new ExtractionState(), 0, dump);

        byte[][] newcentroids;
        newcentroids = dump.centroids;

        // We need to pick k - (dump.index - dump.randomGens) new items
        System.out.println("Expecting " + options.k + " got " + dump.index + " of which " + dump.randomGens
                + " were random");
        if (dump.index < options.k) {
            final int randomNeeded = options.k - (dump.index - dump.randomGens);

            final SequenceFileByteFeatureSelector sfbs = new SequenceFileByteFeatureSelector(selected,
                    options.output + "/randomswap", options);
            final String initialCentroids = sfbs.getRandomFeatures(randomNeeded);
            final Path newcentroidsPath = new Path(initialCentroids);
            utility = new IntBytesSequenceMemoryUtility(newcentroidsPath.toUri(), true);
            final SelectTopKDump neededdump = new SelectTopKDump(randomNeeded);
            utility.exportData(NamingStrategy.KEY, new ExtractionState(), 0, neededdump);
            newcentroids = neededdump.centroids;
        }

        final ByteCentroidsResult newFastKMeansCluster = new ByteCentroidsResult();
        newFastKMeansCluster.centroids = newcentroids;

        return newFastKMeansCluster;
    }

    /**
     * Load some initially selected centroids from {@link FeatureSelect} as a
     * {@link ByteCentroidsResult} instance
     *
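     * <p>
     * A minimal usage sketch; the path and value of k are illustrative
     * assumptions:
     *
     * <pre>
     * {@code
     * // illustrative path and k only
     * ByteCentroidsResult initial = AKMeans.sequenceFileToCluster("hdfs:///tmp/akmeans/initial-centroids", 1000);
     * }
     * </pre>
     *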
     * @param initialCentroids the path of the sequence file holding the initially selected centroids
     * @param k the number of centroids to load
     * @return a {@link ByteCentroidsResult}
     * @throws IOException
     */
    public static ByteCentroidsResult sequenceFileToCluster(String initialCentroids, int k) throws IOException {
        final SelectTopKDump neededdump = new SelectTopKDump(k);
        final IntBytesSequenceMemoryUtility utility = new IntBytesSequenceMemoryUtility(initialCentroids, true);

        utility.exportData(NamingStrategy.KEY, new ExtractionState(), 0, neededdump);

        final ByteCentroidsResult newFastKMeansCluster = new ByteCentroidsResult();
        newFastKMeansCluster.centroids = neededdump.centroids;

        return newFastKMeansCluster;
    }
}