001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.workinprogress.featlearn.cifarexps;
031
import java.io.IOException;
import java.util.concurrent.atomic.AtomicInteger;

import org.openimaj.data.dataset.GroupedDataset;
import org.openimaj.data.dataset.ListDataset;
import org.openimaj.experiment.evaluation.classification.ClassificationResult;
import org.openimaj.feature.DoubleFV;
import org.openimaj.feature.FeatureExtractor;
import org.openimaj.image.MBFImage;
import org.openimaj.image.annotation.evaluation.datasets.CIFAR10Dataset;
import org.openimaj.math.statistics.normalisation.ZScore;
import org.openimaj.ml.annotation.AnnotatedObject;
import org.openimaj.ml.annotation.linear.LiblinearAnnotator;
import org.openimaj.util.function.Operation;
import org.openimaj.util.parallel.Parallel;
import org.openimaj.workinprogress.featlearn.RandomPatchSampler;

import de.bwaldvogel.liblinear.SolverType;
049
050public abstract class CIFARExperimentFramework {
051        protected final int patchSize = 6;
052        protected final int numPatches = 400000;
053        protected final int C = 1;
054
055        protected abstract void learnFeatures(double[][] patches);
056
057        protected abstract double[] extractFeatures(MBFImage image);
058
059        public double run() throws IOException {
060                // load training data
061                final GroupedDataset<String, ListDataset<MBFImage>, MBFImage> trainingdata = CIFAR10Dataset
062                                .getTrainingImages(CIFAR10Dataset.MBFIMAGE_READER);
063
064                // create random patches
065                final RandomPatchSampler<MBFImage> sampler =
066                                new RandomPatchSampler<MBFImage>(trainingdata, patchSize, patchSize, numPatches);
067
068                final double[][] patches = new double[numPatches][];
069                int i = 0;
070                for (final MBFImage p : sampler) {
071                        patches[i++] = p.getDoublePixelVector();
072
073                        if (i % 10000 == 0)
074                                System.out.format("Extracting patch %d / %d\n", i, numPatches);
075                }
076
077                // Perform feature learning
078                learnFeatures(patches);
079
080                // extract features
081                final MBFImage[] trimages = new MBFImage[trainingdata.numInstances()];
082                final String[] trclasses = new String[trainingdata.numInstances()];
083                final double[][] trfeatures = new double[trainingdata.numInstances()][];
084                i = 0;
085                for (final String clz : trainingdata.getGroups()) {
086                        for (final MBFImage p : trainingdata.get(clz)) {
087                                // trfeatures[i] = extractFeatures(p);
088                                trimages[i] = p;
089                                trclasses[i] = clz;
090                                i++;
091                        }
092                }
093                // for (i = 0; i < trimages.length; i++) {
094                // if (i % 100 == 0)
095                // System.out.format("Extracting features %d / %d\n", i,
096                // trainingdata.numInstances());
097                // trfeatures[i] = extractFeatures(trimages[i]);
098                // }
099                Parallel.forRange(0, trimages.length, 1, new Operation<Parallel.IntRange>() {
100                        volatile int count = 0;
101
102                        @Override
103                        public void perform(Parallel.IntRange range) {
104                                for (int ii = range.start; ii < range.stop; ii++) {
105                                        if (count % 100 == 0)
106                                                System.out.format("Extracting features %d / %d\n", count, trainingdata.numInstances());
107                                        trfeatures[ii] = extractFeatures(trimages[ii]);
108                                        count++;
109                                }
110                        }
111                });
112
113                // feature normalisation
114                final ZScore z = new ZScore(0.01);
115                z.train(trfeatures);
116                final double[][] trfeaturesz = z.normalise(trfeatures);
117
118                // train linear SVM
119                final LiblinearAnnotator<double[], String> ann =
120                                new LiblinearAnnotator<double[], String>(new FeatureExtractor<DoubleFV, double[]>() {
121                                        @Override
122                                        public DoubleFV extractFeature(double[] object) {
123                                                return new DoubleFV(object);
124                                        }
125                                }, LiblinearAnnotator.Mode.MULTICLASS,
126                                SolverType.L2R_L2LOSS_SVC_DUAL, C, 0.1, 1 /* bias */, true);
127
128                ann.train(AnnotatedObject.createList(trfeaturesz, trclasses));
129
130                // load test data
131                final GroupedDataset<String, ListDataset<MBFImage>, MBFImage> testdata = CIFAR10Dataset
132                                .getTestImages(CIFAR10Dataset.MBFIMAGE_READER);
133
134                // extract test features
135                final String[] teclasses = new String[testdata.numInstances()];
136                double[][] tefeatures = new double[testdata.numInstances()][];
137                i = 0;
138                for (final String clz : testdata.getGroups()) {
139                        for (final MBFImage p : testdata.get(clz)) {
140                                tefeatures[i] = extractFeatures(p);
141                                teclasses[i] = clz;
142                                i++;
143                        }
144                }
145
146                // feature normalisation (using mean and stddev learned from training
147                // data)
148                tefeatures = z.normalise(tefeatures);
149
150                // perform classification and calculate accuracy
151                double correct = 0, incorrect = 0;
152                for (i = 0; i < tefeatures.length; i++) {
153                        final ClassificationResult<String> res = ann.classify(tefeatures[i]);
154
155                        if (res.getPredictedClasses().iterator().next().equals(teclasses[i]))
156                                correct++;
157                        else
158                                incorrect++;
159                }
160                final double acc = correct / (correct + incorrect);
161                System.out.println("Test accuracy " + acc);
162                return acc;
163        }
164}