/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   *  Redistributions of source code must retain the above copyright notice,
 *      this list of conditions and the following disclaimer.
 *
 *   *  Redistributions in binary form must reproduce the above copyright notice,
 *      this list of conditions and the following disclaimer in the documentation
 *      and/or other materials provided with the distribution.
 *
 *   *  Neither the name of the University of Southampton nor the names of its
 *      contributors may be used to endorse or promote products derived from this
 *      software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.tools.clusterquantiser;

import gnu.trove.list.array.TIntArrayList;

import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import org.kohsuke.args4j.CmdLineException;
import org.openimaj.data.RandomData;
import org.openimaj.io.IOUtils;
import org.openimaj.ml.clustering.ByteCentroidsResult;
import org.openimaj.ml.clustering.IntCentroidsResult;
import org.openimaj.ml.clustering.SpatialClusters;
import org.openimaj.ml.clustering.assignment.HardAssigner;
import org.openimaj.ml.clustering.assignment.hard.KDTreeByteEuclideanAssigner;
import org.openimaj.ml.clustering.assignment.hard.KDTreeIntEuclideanAssigner;
import org.openimaj.time.Timer;
import org.openimaj.tools.clusterquantiser.ClusterType.ClusterTypeOp;
import org.openimaj.tools.clusterquantiser.samplebatch.SampleBatch;
import org.openimaj.util.array.ByteArrayConverter;
import org.openimaj.util.parallel.GlobalExecutorPool.DaemonThreadFactory;

/**
 * A tool for clustering and quantising local features.
 * 
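 * <p>
 * A minimal programmatic sketch of how the static helpers can be driven
 * (mirroring {@link #main(String[])}); it assumes {@code args} holds valid
 * command-line arguments for the tool:
 * 
 * <pre>
 * {@code
 * ClusterQuantiserOptions options = ClusterQuantiser.mainOptions(args);
 * if (options.isCreateMode()) {
 *     // cluster the sampled features and write them to the tree file
 *     ClusterQuantiser.do_create(options);
 * } else if (options.isQuantMode()) {
 *     // quantise the input feature files against an existing cluster file
 *     ClusterQuantiser.do_quant(options);
 * } else if (options.isInfoMode()) {
 *     ClusterQuantiser.do_info(options);
 * }
 * }
 * </pre>
 * 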
 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 */
public class ClusterQuantiser {
	/**
	 * Create new clusters from the sampled features and write them to the
	 * tree file specified in the options.
	 * 
	 * @param options
	 *            the options controlling sampling and cluster creation
	 * @return the created clusters
	 * @throws Exception
	 *             if an error occurs during sampling or clustering
	 */
	public static SpatialClusters<?> do_create(ClusterQuantiserOptions options) throws Exception {
		final File treeFile = new File(options.getTreeFile());
		final ClusterTypeOp clusterType = options.getClusterType();

		SpatialClusters<?> cluster = null;

		// perform sampling if required
		if (options.isBatchedSampleMode()) {
			cluster = clusterType.create(do_getSampleBatches(options));
			IOUtils.writeBinary(treeFile, cluster);
		} else {
			final byte[][] data = do_getSamples(options);
			System.err.printf("Using %d records\n", data.length);
			cluster = clusterType.create(data);
			System.err.println("Writing cluster file to " + treeFile);
			IOUtils.writeBinary(treeFile, cluster);
		}
		return cluster;
	}
	/**
	 * Get the sample batches, either by reading a previously saved samples
	 * file or by building batches from the input files.
	 * 
	 * @param options
	 *            the options describing the input files and sampling
	 * @return the sample batches
	 * @throws IOException
	 *             if the input files cannot be read
	 */
	public static List<SampleBatch> do_getSampleBatches(ClusterQuantiserOptions options) throws IOException {
		if (options.isSamplesFileMode()) {
			try {
				System.err.println("Attempting to read sample batch file...");
				return SampleBatch.readSampleBatches(options.getSamplesFile());
			} catch (final Exception e) {
				System.err.println("... Failed! ");
				return null;
			}
		}

		final List<SampleBatch> batches = new ArrayList<SampleBatch>();
		final List<File> input_files = options.getInputFiles();
		final FileType type = options.getFileType();
		final int n_input_files = input_files.size();
		final List<Header> headers = new ArrayList<Header>(n_input_files);

		// read the headers and count the total number of features
		System.err.printf("Reading input %8d / %8d", 0, n_input_files);
		int totalFeatures = 0;
		final int[] cumSum = new int[n_input_files + 1];
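		// cumSum[i + 1] holds the running total of features in files 0..i, so
		// file i owns the global feature indices [cumSum[i], cumSum[i + 1])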
		for (int i = 0; i < n_input_files; i++) {
			final Header h = type.readHeader(input_files.get(i));

			totalFeatures += h.nfeatures;
			cumSum[i + 1] = totalFeatures;
			headers.add(h);

			System.err.printf("\r%8d / %8d", i + 1, n_input_files);
		}

		System.err.println();
		final int samples = options.getSamples();
		if (samples <= 0 || samples > totalFeatures) {
			System.err.printf(
					"Samples requested (%d) is non-positive or exceeds the total number of features (%d); using all features...\n",
					samples, totalFeatures);

			// batch up every feature in every non-empty file
			for (int i = 0; i < n_input_files; i++) {
				if (cumSum[i + 1] - cumSum[i] == 0)
					continue;
				final SampleBatch sb = new SampleBatch(type, input_files.get(i),
						cumSum[i], cumSum[i + 1]);
				batches.add(sb);
				System.err.printf("\rConstructing sample batches %8d / %8d", i + 1,
						n_input_files);
			}
			System.err.println();
			System.err.println("Done...");
		} else {
			System.err.println("Shuffling and sampling ...");
			// generate `samples` unique random indices between 0 and totalFeatures
			int[] rndIndices = null;
			if (options.getRandomSeed() == -1)
				rndIndices = RandomData.getUniqueRandomInts(samples, 0, totalFeatures);
			else
				rndIndices = RandomData.getUniqueRandomInts(samples, 0, totalFeatures,
						new Random(options.getRandomSeed()));
			System.err.println("Done! Extracting the required features...");
			final TIntArrayList intraFileIndices = new TIntArrayList();
			for (int j = 0, s = 0; j < n_input_files; j++) {
				intraFileIndices.clear();

				// go through samples and find ones belonging to this doc
				for (int i = 0; i < samples; i++) {
					final int idx = rndIndices[i];

					if (idx >= cumSum[j] && idx < cumSum[j + 1]) {
						intraFileIndices.add(idx - cumSum[j]);
					}
				}

				if (intraFileIndices.size() > 0) {
					final SampleBatch sb = new SampleBatch(type, input_files.get(j),
							s, s + intraFileIndices.size(),
							intraFileIndices.toArray());
					batches.add(sb);
					s += intraFileIndices.size();
					System.err.printf("\r%8d / %8d", s, samples);
				}

			}
			System.err.println();
		}
		if (batches.size() > 0 && options.getSamplesFile() != null) {
			System.err.println("Writing samples file...");
			SampleBatch.writeSampleBatches(batches, options.getSamplesFile());
		}
		return batches;
	}

	/**
	 * Get the sample features, either from a previously saved samples file or
	 * by sampling the input files directly.
	 * 
	 * @param options
	 *            the options describing the input files and sampling
	 * @return the sampled features
	 * @throws IOException
	 *             if the input files cannot be read
	 */
	public static byte[][] do_getSamples(ClusterQuantiserOptions options)
			throws IOException
	{
		byte[][] data = null;
		if (options.isSamplesFileMode()) {
			data = options.getSampleKeypoints();
		} else {
			final List<File> input_files = options.getInputFiles();
			final FileType type = options.getFileType();
			final int n_input_files = input_files.size();
			final List<Header> headers = new ArrayList<Header>(n_input_files);

			// read the headers and count the total number of features
			System.err.printf("Reading input %8d / %8d", 0, n_input_files);
			int totalFeatures = 0;
			final int[] cumSum = new int[n_input_files + 1];
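			// as in do_getSampleBatches, cumSum[i + 1] is the running total of
			// features in files 0..i, used to map sampled global indices back to files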
			for (int i = 0; i < n_input_files; i++) {
				final Header h = type.readHeader(input_files.get(i));

				totalFeatures += h.nfeatures;
				cumSum[i + 1] = totalFeatures;
				headers.add(h);

				System.err.printf("\r%8d / %8d", i + 1, n_input_files);
			}

			System.err.println();
			final int samples = options.getSamples();
			if (samples <= 0 || samples > totalFeatures) {
				System.err.printf(
						"Samples requested (%d) is non-positive or exceeds the total number of features (%d); using all features...\n",
						samples, totalFeatures);
				// no samples requested, or more samples requested than
				// features exist, so use all the features
				data = new byte[totalFeatures][];

				for (int i = 0, j = 0; i < n_input_files; i++) {
					final byte[][] fd = type.readFeatures(input_files.get(i));

					for (int k = 0; k < fd.length; k++) {
						data[j + k] = fd[k];
						System.err.printf("\r%8d / %8d", j + k + 1, totalFeatures);
					}
					j += fd.length;
				}
			} else {
				System.err.println("Shuffling and sampling ...");

				data = new byte[samples][];

				// generate `samples` unique random indices between 0 and totalFeatures
				int[] rndIndices = null;
				if (options.getRandomSeed() == -1)
					rndIndices = RandomData.getUniqueRandomInts(samples, 0, totalFeatures);
				else
					rndIndices = RandomData.getUniqueRandomInts(samples, 0, totalFeatures,
							new Random(options.getRandomSeed()));
				System.err.println("Done! Extracting the required features...");
				final TIntArrayList intraFileIndices = new TIntArrayList();
				for (int j = 0, s = 0; j < n_input_files; j++) {
					intraFileIndices.clear();

					// go through samples and find ones belonging to this doc
					for (int i = 0; i < samples; i++) {
						final int idx = rndIndices[i];

						if (idx >= cumSum[j] && idx < cumSum[j + 1]) {
							intraFileIndices.add(idx - cumSum[j]);
						}
					}

					if (intraFileIndices.size() > 0) {
						final byte[][] f = type.readFeatures(input_files.get(j),
								intraFileIndices.toArray());
						for (int i = 0; i < intraFileIndices.size(); i++, s++) {
							data[s] = f[i];
							System.err.printf("\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b%8d / %8d",
									s + 1, samples);
						}
					}

				}
				System.err.println();
			}
			if (data != null && options.getSamplesFile() != null) {
				System.err.println("Writing samples file...");
				final FileOutputStream fos = new FileOutputStream(options.getSamplesFile());
				final ObjectOutputStream dos = new ObjectOutputStream(fos);
				dos.writeObject(data);
				dos.close();
			}
		}
		return data;
	}

	/**
	 * Print information about the clusters stored in the tree file.
	 * 
	 * @param options
	 *            the options specifying the cluster (tree) file
	 * @throws IOException
	 *             if the cluster file cannot be read
	 */
	public static void do_info(AbstractClusterQuantiserOptions options)
			throws IOException
	{
		final SpatialClusters<?> cluster = IOUtils.read(new File(options.getTreeFile()), options.getClusterClass());
		System.out.println("Cluster loaded...");
		System.out.println(cluster);
	}

	/**
	 * Quantise the features in the input files against the clusters in the
	 * tree file, processing files in parallel.
	 * 
	 * @param cqo
	 *            the options specifying the cluster file, input files and
	 *            concurrency level
	 * @throws IOException
	 *             if a file cannot be read or written
	 * @throws InterruptedException
	 *             if the quantisation jobs are interrupted
	 */
	public static void do_quant(ClusterQuantiserOptions cqo) throws IOException, InterruptedException {
		final ExecutorService es = Executors.newFixedThreadPool(cqo.getConcurrency(), new DaemonThreadFactory());

		final List<QuantizerJob> jobs = QuantizerJob.getJobs(cqo);

		System.out.format("Using %d processors\n", cqo.getConcurrency());
		es.invokeAll(jobs);
		es.shutdown();
	}

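	/**
	 * A {@link Callable} job that quantises the features from a sublist of the
	 * input files against a shared set of clusters and a shared assigner,
	 * writing one quantised output file per input file.
	 */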
	static class QuantizerJob implements Callable<Boolean> {
		SpatialClusters<?> tree;
		HardAssigner<?, ?, ?> assigner;

		List<File> inputFiles;
		// FileType fileType;
		// String extension;
		// private ClusterType ctype;
		// private File outputFile;
		private ClusterQuantiserOptions cqo;
		private String commonRoot;

		// progress counters shared across all jobs
		static int count = 0;
		static int total;

		static synchronized void incr() {
			count++;
			System.err.printf(
					"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b%8d / %8d", count,
					total);
		}
		protected QuantizerJob(ClusterQuantiserOptions cqo, SpatialClusters<?> tree, HardAssigner<?, ?, ?> assigner)
				throws IOException
		{
			this.cqo = cqo;
			this.tree = tree;
			this.inputFiles = cqo.getInputFiles();
			this.commonRoot = cqo.getInputFileCommonRoot();
			this.assigner = assigner;
		}

		protected QuantizerJob(ClusterQuantiserOptions cqo,
				List<File> inputFiles, SpatialClusters<?> clusters, HardAssigner<?, ?, ?> assigner) throws IOException
		{
			this.cqo = cqo;
			this.tree = clusters;
			this.inputFiles = inputFiles;
			this.commonRoot = cqo.getInputFileCommonRoot();
			this.assigner = assigner;
		}

		public static List<QuantizerJob> getJobs(ClusterQuantiserOptions cqo)
				throws IOException
		{
			final List<QuantizerJob> jobs = new ArrayList<QuantizerJob>(cqo.getConcurrency());
			final int size = cqo.getInputFiles().size() / cqo.getConcurrency();

			final SpatialClusters<?> clusters = IOUtils.read(new File(cqo.getTreeFile()), cqo.getClusterClass());

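			// choose the assigner used to map each feature vector to a cluster
			// id; the implementation depends on the exactQuant option and the
			// type of the clusters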
			HardAssigner<?, ?, ?> assigner;
			if (!cqo.exactQuant) {
				assigner = clusters.defaultHardAssigner();
			} else {
				if (clusters instanceof ByteCentroidsResult)
					assigner = new KDTreeByteEuclideanAssigner((ByteCentroidsResult) clusters);
				else
					assigner = new KDTreeIntEuclideanAssigner((IntCentroidsResult) clusters);
			}

			QuantizerJob.count = 0;
			QuantizerJob.total = cqo.getInputFiles().size();
			for (int i = 0; i < cqo.getConcurrency() - 1; i++) {
				final QuantizerJob job = new QuantizerJob(cqo, cqo.getInputFiles().subList(i * size, (i + 1) * size),
						clusters, assigner);
				jobs.add(job);
			}
			// add remaining
			final QuantizerJob job = new QuantizerJob(cqo,
					cqo.getInputFiles().subList((cqo.getConcurrency() - 1) * size,
							cqo.getInputFiles().size()), clusters, assigner);
			jobs.add(job);

			return jobs;
		}

		@SuppressWarnings("unchecked")
		@Override
		public Boolean call() throws Exception {
			for (int i = 0; i < inputFiles.size(); i++) {
				try {
					File outFile = new File(inputFiles.get(i) + cqo.getExtension());
					if (cqo.getOutputFile() != null) {
						// mirror the input path under the output directory,
						// relative to the common root of the input files
						outFile = new File(cqo.getOutputFile().getAbsolutePath()
								+ File.separator
								+ outFile.getAbsolutePath().substring(this.commonRoot.length()));
					}
					// skip inputs whose (non-empty) output file already exists
					if (outFile.exists() && outFile.length() > 0) {
						incr();
						continue;
					}
					final FeatureFile input = cqo.getFileType().read(inputFiles.get(i));
					PrintWriter pw = null;
					// make the parent directory if needed
					if (!outFile.getParentFile().exists()) {
						if (!outFile.getParentFile().mkdirs())
							throw new IOException("couldn't make output directory: " + outFile.getParentFile());
					}
					final Timer t = Timer.timer();
					try {
						pw = new PrintWriter(new FileWriter(outFile));
						pw.format("%d\n%d\n", input.size(), tree.numClusters());
						// int [] clusters = new int[input.size()];
						for (final FeatureFileFeature fff : input) {
							int cluster = -1;
							if (tree.getClass().getName().contains("Byte"))
								cluster = ((HardAssigner<byte[], ?, ?>) assigner).assign(fff.data);
							else
								cluster = ((HardAssigner<int[], ?, ?>) assigner)
										.assign(ByteArrayConverter.byteToInt(fff.data));
							pw.format("%s %d\n", fff.location.trim(), cluster);
						}
					} catch (final IOException e) {
						e.printStackTrace();
						throw new Error(e); // IO error when writing - die.
					} finally {
						if (pw != null) {
							pw.flush();
							pw.close();
						}
						// close the input even if opening the writer failed
						input.close();
					}
					t.stop();
					if (cqo.printTiming()) {
						System.out.println("Took: " + t.duration());
					}

				} catch (final Exception e) {
					// error processing an individual file; print the error and
					// continue with the next file
					e.printStackTrace();
					System.err.println("Error processing file: " + inputFiles.get(i));
					System.err.println("(Exception was " + e.getMessage() + ")");
				}

				// System.err.printf("\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b%8d / %8d",
				// i+1, input_files.size());
				incr();
			}
			// System.out.println();
			return true;
		}
	}

	/**
	 * Prepare the tool options from the given command-line arguments.
	 * 
	 * @param args
	 *            the command-line arguments
	 * @return the prepared options
	 * @throws InterruptedException
	 *             if preparation is interrupted
	 * @throws CmdLineException
	 *             if the arguments cannot be parsed
	 */
	public static ClusterQuantiserOptions mainOptions(String[] args)
			throws InterruptedException, CmdLineException
	{
		final ClusterQuantiserOptions options = new ClusterQuantiserOptions(args);
		options.prepare();

		return options;
	}

	/**
	 * The main method of the tool.
	 * 
	 * @param args
	 *            the command-line arguments
	 * @throws Exception
	 *             if an error occurs
	 */
	public static void main(String[] args) throws Exception {
		try {
			final ClusterQuantiserOptions options = mainOptions(args);

			final List<File> inputFiles = options.getInputFiles();

			if (options.getVerbosity() >= 0 && !options.isInfoMode())
				System.err.printf("We have %d input files\n", inputFiles.size());

			if (options.isCreateMode()) {
				do_create(options);
			} else if (options.isInfoMode()) {
				do_info(options);
			} else if (options.isQuantMode()) {
				do_quant(options);
			} else if (options.getSamplesFile() != null && inputFiles.size() > 0) {
				if (options.isBatchedSampleMode()) {
					do_getSampleBatches(options);
				} else {
					do_getSamples(options);
				}
			}
		} catch (final CmdLineException cmdline) {
			System.err.println(cmdline.getMessage());
		} catch (final IOException e) {
			System.err.println(e.getMessage());
		}
	}
}