/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   *	Redistributions of source code must retain the above copyright notice,
 *	this list of conditions and the following disclaimer.
 *
 *   *	Redistributions in binary form must reproduce the above copyright notice,
 *	this list of conditions and the following disclaimer in the documentation
 *	and/or other materials provided with the distribution.
 *
 *   *	Neither the name of the University of Southampton nor the names of its
 *	contributors may be used to endorse or promote products derived from this
 *	software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.hadoop.tools.fastkmeans;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.util.Random;

import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.openimaj.hadoop.sequencefile.ExtractionState;
import org.openimaj.hadoop.sequencefile.KeyValueDump;
import org.openimaj.hadoop.sequencefile.NamingStrategy;
import org.openimaj.hadoop.sequencefile.SequenceFileUtility;
import org.openimaj.io.IOUtils;
import org.openimaj.ml.clustering.ByteCentroidsResult;
import org.openimaj.ml.clustering.assignment.HardAssigner;
import org.openimaj.ml.clustering.assignment.hard.ExactByteAssigner;
import org.openimaj.ml.clustering.assignment.hard.KDTreeByteEuclideanAssigner;
import org.openimaj.ml.clustering.kmeans.ByteKMeans;
import org.openimaj.util.pair.IntFloatPair;

/**
 * Approximate k-means MapReduce implementation.
 *
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 *
 */
public class AKMeans {
	/**
	 * Config option for the path to the current centroids
	 */
	public static final String CENTROIDS_PATH = "uk.ac.soton.ecs.jsh2.clusterquantiser.CentroidsPath";

	/**
	 * Config option for the number of centroids, K
	 */
	public static final String CENTROIDS_K = "uk.ac.soton.ecs.jsh2.clusterquantiser.CentroidsK";

	/**
	 * Config option for whether exact assignment mode should be used
	 */
	public static final String CENTROIDS_EXACT = "uk.ac.soton.ecs.jsh2.clusterquantiser.CentroidsExact";

	private static final String CENTROIDS_FALLBACK_CHANCE = "uk.ac.soton.ecs.jsh2.clusterquantiser.FallbackChance";
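	/**
	 * Minimal illustrative helper (a sketch; this method and its name are not
	 * part of the original tool): shows how the configuration options above
	 * might be set on a job configuration. All values are stored as strings
	 * because {@link Map#loadCluster} reads them back with
	 * {@code Configuration.getStrings(...)}.
	 */
	static void setClusterOptions(org.apache.hadoop.conf.Configuration conf, String centroidsPath, int k, boolean exact)
	{
		// The key names are the real constants defined above
		conf.set(CENTROIDS_PATH, centroidsPath);
		conf.set(CENTROIDS_K, String.valueOf(k));
		conf.set(CENTROIDS_EXACT, String.valueOf(exact));
	}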
	/**
	 * The mapper for approximate k-means. Uses centroids built by
	 * {@link ByteKMeans} under the hood. Each feature is assigned to its
	 * nearest centroid and emitted with the centroid index as the key.
	 *
	 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
	 *
	 */
	public static class Map extends Mapper<Text, BytesWritable, IntWritable, BytesWritable> {
		private static Path centroidsPath = null;
		private static int k = -1;
		private static HardAssigner<byte[], float[], IntFloatPair> assigner = null;
		private static double randomFallbackChance;
		private static boolean exact;
		private static final Random random = new Random();

		@Override
		protected void setup(Mapper<Text, BytesWritable, IntWritable, BytesWritable>.Context context)
				throws IOException, InterruptedException
		{
			loadCluster(context);
		}

		protected static synchronized void loadCluster(
				Mapper<Text, BytesWritable, IntWritable, BytesWritable>.Context context) throws IOException
		{
			final Path newPath = new Path(context.getConfiguration().getStrings(CENTROIDS_PATH)[0]);
			// Only reload the centroids and rebuild the assigner if the
			// centroids path has changed since the last setup call in this JVM
			final boolean current = centroidsPath != null && centroidsPath.toString().equals(newPath.toString());
			if (!current) {
				k = Integer.parseInt(context.getConfiguration().getStrings(CENTROIDS_K)[0]);
				exact = Boolean.parseBoolean(context.getConfiguration().getStrings(CENTROIDS_EXACT)[0]);
				System.out.println("This is exact mode: " + exact);
				randomFallbackChance = 0.01;
				if (context.getConfiguration().getStrings(CENTROIDS_FALLBACK_CHANCE) != null) {
					randomFallbackChance = Double.parseDouble(context.getConfiguration().getStrings(
							CENTROIDS_FALLBACK_CHANCE)[0]);
				}

				centroidsPath = newPath;
				System.out.println("Loading centroids from: " + centroidsPath);
				final URI uri = centroidsPath.toUri();
				final FileSystem fs = HadoopFastKMeansOptions.getFileSystem(uri);
				final InputStream is = fs.open(centroidsPath);
				final ByteCentroidsResult centroids = IOUtils.read(is, ByteCentroidsResult.class);

				if (exact)
					assigner = new ExactByteAssigner(centroids);
				else
					assigner = new KDTreeByteEuclideanAssigner(centroids);
			}
		}

		@Override
		public void map(Text key, BytesWritable value, Context context) throws IOException, InterruptedException {
			// Copy only the valid portion of the backing array
			final byte[] values = value.getBytes();
			final byte[] points = new byte[value.getLength()];
			System.arraycopy(values, 0, points, 0, points.length);

			final int cluster = assigner.assign(points);

			context.write(new IntWritable(cluster), new BytesWritable(points));

			// With a small probability, also emit the feature under the
			// out-of-range key k + 1, giving downstream stages a pool of
			// random features for replacing empty clusters
			if (random.nextDouble() < randomFallbackChance) {
				context.write(new IntWritable(k + 1), new BytesWritable(points));
			}
		}
	}
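	/**
	 * Illustrative encoder (a sketch; the original code writes this format
	 * inline in {@link Combine}, and this helper name is hypothetical): a
	 * partial-sum record is the count of accumulated features followed by the
	 * per-dimension integer sums. This is the layout consumed by
	 * {@code accumulateFromSum} below.
	 */
	static byte[] encodePartialSum(int totalAssigned, int[] sum) throws IOException {
		final ByteArrayOutputStream bos = new ByteArrayOutputStream();
		final DataOutputStream dos = new DataOutputStream(bos);
		dos.writeInt(totalAssigned); // number of features folded into this sum
		for (final int s : sum)
			dos.writeInt(s); // one 4-byte sum per dimension
		return bos.toByteArray();
	}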
	private static int accumulateFromFeature(int[] sum, byte[] assigned) throws IOException {
		if (assigned.length != sum.length)
			throw new IOException("Inconsistency in sum and feature length");
		for (int i = 0; i < sum.length; i++) {
			sum[i] += assigned[i];
		}
		return 1;
	}

	private static int accumulateFromSum(int[] sum, byte[] assigned) throws IOException {
		// A partial-sum record holds one int count followed by one int per
		// dimension, each 4 bytes long
		final int flen = (assigned.length / 4) - 1;
		final DataInputStream dis = new DataInputStream(new ByteArrayInputStream(assigned));
		if (flen != sum.length)
			throw new IOException("Inconsistency in sum and feature length");
		final int totalAssigned = dis.readInt();
		for (int i = 0; i < sum.length; i++) {
			sum[i] += dis.readInt();
		}
		return totalAssigned;
	}

	/**
	 * For efficiency, combine assignments early: emit the per-centroid sum
	 * together with the number of features combined into it.
	 *
	 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
	 *
	 */
	public static class Combine extends Reducer<IntWritable, BytesWritable, IntWritable, BytesWritable> {
		private int k;

		@Override
		public void setup(Context context) throws IOException, InterruptedException {
			k = Integer.parseInt(context.getConfiguration().getStrings(CENTROIDS_K)[0]);
		}

		@Override
		public void reduce(IntWritable key, Iterable<BytesWritable> values, Context context)
				throws IOException, InterruptedException
		{
			final int[] sum = new int[128];
			int totalAssigned = 0;
			for (final BytesWritable val : values) {
				// Copy the important part of the array
				final byte[] assigned = new byte[val.getLength()];
				System.arraycopy(val.getBytes(), 0, assigned, 0, assigned.length);
				// Pass the random fallback emissions straight through
				if (key.get() > k) {
					context.write(key, new BytesWritable(assigned));
					continue;
				}
				int added = 0;
				// Accumulate either from a raw feature or from an existing
				// partial sum of features
				if (assigned.length == 128)
					added += accumulateFromFeature(sum, assigned);
				else
					added += accumulateFromSum(sum, assigned);
				totalAssigned += added;
			}
			if (key.get() > k)
				return;
			// Write the accumulated sum together with the current count
			final ByteArrayOutputStream bos = new ByteArrayOutputStream();
			final DataOutputStream dos = new DataOutputStream(bos);
			dos.writeInt(totalAssigned);
			for (final int i : sum) {
				dos.writeInt(i);
			}
			context.write(key, new BytesWritable(bos.toByteArray()));
		}
	}
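	/**
	 * Illustrative finalisation step (a sketch of what {@link Reduce} does
	 * inline; this helper does not exist in the original code): average an
	 * accumulated sum into a new byte centroid. Note that integer division is
	 * used, matching the reducer below, so each dimension is truncated.
	 */
	static byte[] averageToCentroid(int[] sum, int totalAssigned) {
		final byte[] out = new byte[sum.length];
		for (int i = 0; i < sum.length; i++) {
			out[i] = (byte) (sum[i] / totalAssigned); // truncating per-dimension average
		}
		return out;
	}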
	/**
	 * The AKMeans reducer. Averages the combined features assigned to each
	 * centroid and emits the new centroids. May (if nothing was assigned to a
	 * centroid) result in some centroids with no value.
	 *
	 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
	 *
	 */
	public static class Reduce extends Reducer<IntWritable, BytesWritable, IntWritable, BytesWritable> {
		private int k;

		@Override
		public void setup(Context context) throws IOException, InterruptedException {
			k = Integer.parseInt(context.getConfiguration().getStrings(CENTROIDS_K)[0]);
		}

		@Override
		public void reduce(IntWritable key, Iterable<BytesWritable> values, Context context)
				throws IOException, InterruptedException
		{
			final int[] sum = new int[128];
			final byte[] out = new byte[128];

			int totalAssigned = 0;
			for (final BytesWritable val : values) {
				final byte[] assigned = new byte[val.getLength()];

				System.arraycopy(val.getBytes(), 0, assigned, 0, assigned.length);

				// Pass the random fallback features straight through
				if (key.get() > k) {
					context.write(key, new BytesWritable(assigned));
					continue;
				}

				int added = 0;
				if (assigned.length == 128)
					added += accumulateFromFeature(sum, assigned);
				else
					added += accumulateFromSum(sum, assigned);

				totalAssigned += added;
			}

			if (key.get() > k)
				return;

			// Average the accumulated sums to form the new centroid
			for (int i = 0; i < sum.length; i++) {
				out[i] = (byte) (sum[i] / totalAssigned);
			}

			context.write(key, new BytesWritable(out));
		}
	}

	static class SelectTopKDump extends KeyValueDump<IntWritable, BytesWritable> {
		int index = 0;
		int randomGens = 0;
		byte[][] centroids;

		SelectTopKDump(int k) {
			centroids = new byte[k][];
		}

		@Override
		public void dumpValue(IntWritable key, BytesWritable val) {
			if (index >= centroids.length)
				return;
			final byte[] bytes = new byte[val.getLength()];
			System.arraycopy(val.getBytes(), 0, bytes, 0, bytes.length);
			centroids[index] = bytes;
			index++;
			// Count entries that came from the random fallback key (k + 1)
			if (key.get() == centroids.length + 1) {
				randomGens++;
			}
		}

	}
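	/**
	 * Illustrative driver (a sketch, not part of the original tool; the method
	 * name and job name are hypothetical): wires the mapper, combiner and
	 * reducer above into a single approximate k-means iteration over sequence
	 * files of features.
	 */
	static org.apache.hadoop.mapreduce.Job createIterationJob(org.apache.hadoop.conf.Configuration conf,
			Path input, Path output, String centroids, int k, boolean exact) throws IOException
	{
		// Options must be set before the Job copies the configuration
		conf.set(CENTROIDS_PATH, centroids);
		conf.set(CENTROIDS_K, String.valueOf(k));
		conf.set(CENTROIDS_EXACT, String.valueOf(exact));

		final org.apache.hadoop.mapreduce.Job job = new org.apache.hadoop.mapreduce.Job(conf, "akmeans-iteration");
		job.setJarByClass(AKMeans.class);
		job.setMapperClass(Map.class);
		job.setCombinerClass(Combine.class);
		job.setReducerClass(Reduce.class);
		job.setOutputKeyClass(IntWritable.class);
		job.setOutputValueClass(BytesWritable.class);
		job.setInputFormatClass(org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class);
		job.setOutputFormatClass(org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class);
		org.apache.hadoop.mapreduce.lib.input.FileInputFormat.setInputPaths(job, input);
		org.apache.hadoop.mapreduce.lib.output.FileOutputFormat.setOutputPath(job, output);
		return job;
	}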
	/**
	 * Given the location of a binary dump of centroids on the HDFS, load the
	 * binary dump and construct a proper {@link ByteCentroidsResult} instance.
	 *
	 * @param centroids path to the reduced centroids on the HDFS
	 * @param selected path to the sequence file of selected features
	 * @param options the tool options
	 * @return {@link ByteCentroidsResult} for the centroids on the HDFS
	 * @throws Exception
	 */
	public static ByteCentroidsResult completeCentroids(String centroids, String selected,
			HadoopFastKMeansOptions options) throws Exception
	{
		System.out.println("Attempting to complete");
		final Path centroidsPath = new Path(centroids);
		SequenceFileUtility<IntWritable, BytesWritable> utility = new IntBytesSequenceMemoryUtility(
				centroidsPath.toUri(), true);
		final SelectTopKDump dump = new SelectTopKDump(options.k);
		utility.exportData(NamingStrategy.KEY, new ExtractionState(), 0, dump);

		byte[][] newcentroids;
		newcentroids = dump.centroids;

		// We need to pick k - (dump.index - dump.randomGens) new items
		System.out.println("Expecting " + options.k + " got " + dump.index + " of which " + dump.randomGens
				+ " were random");
		if (dump.index < options.k) {
			final int randomNeeded = options.k - (dump.index - dump.randomGens);

			final SequenceFileByteFeatureSelector sfbs = new SequenceFileByteFeatureSelector(selected,
					options.output + "/randomswap", options);
			final String initialCentroids = sfbs.getRandomFeatures(randomNeeded);
			final Path newcentroidsPath = new Path(initialCentroids);
			utility = new IntBytesSequenceMemoryUtility(newcentroidsPath.toUri(), true);
			final SelectTopKDump neededdump = new SelectTopKDump(randomNeeded);
			utility.exportData(NamingStrategy.KEY, new ExtractionState(), 0, neededdump);
			newcentroids = neededdump.centroids;
		}

		final ByteCentroidsResult newFastKMeansCluster = new ByteCentroidsResult();
		newFastKMeansCluster.centroids = newcentroids;

		return newFastKMeansCluster;
	}

	/**
	 * Load some initially selected centroids from {@link FeatureSelect} as a
	 * {@link ByteCentroidsResult} instance.
	 *
	 * @param initialCentroids path to the sequence file of initial centroids
	 * @param k the number of centroids to load
	 * @return a {@link ByteCentroidsResult}
	 * @throws IOException
	 */
	public static ByteCentroidsResult sequenceFileToCluster(String initialCentroids, int k) throws IOException {
		final SelectTopKDump neededdump = new SelectTopKDump(k);
		final IntBytesSequenceMemoryUtility utility = new IntBytesSequenceMemoryUtility(initialCentroids, true);

		utility.exportData(NamingStrategy.KEY, new ExtractionState(), 0, neededdump);

		final ByteCentroidsResult newFastKMeansCluster = new ByteCentroidsResult();
		newFastKMeansCluster.centroids = neededdump.centroids;

		return newFastKMeansCluster;
	}
}