/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * 	Redistributions of source code must retain the above copyright notice,
 * 	this list of conditions and the following disclaimer.
 *
 *   *	Redistributions in binary form must reproduce the above copyright notice,
 * 	this list of conditions and the following disclaimer in the documentation
 * 	and/or other materials provided with the distribution.
 *
 *   *	Neither the name of the University of Southampton nor the names of its
 * 	contributors may be used to endorse or promote products derived from this
 * 	software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.workinprogress.featlearn.cifarexps;

import java.io.IOException;
import java.util.concurrent.atomic.AtomicInteger;

import org.openimaj.data.dataset.GroupedDataset;
import org.openimaj.data.dataset.ListDataset;
import org.openimaj.experiment.evaluation.classification.ClassificationResult;
import org.openimaj.feature.DoubleFV;
import org.openimaj.feature.FeatureExtractor;
import org.openimaj.image.MBFImage;
import org.openimaj.image.annotation.evaluation.datasets.CIFAR10Dataset;
import org.openimaj.math.statistics.normalisation.ZScore;
import org.openimaj.ml.annotation.AnnotatedObject;
import org.openimaj.ml.annotation.linear.LiblinearAnnotator;
import org.openimaj.util.function.Operation;
import org.openimaj.util.parallel.Parallel;
import org.openimaj.workinprogress.featlearn.RandomPatchSampler;

import de.bwaldvogel.liblinear.SolverType;

/**
 * Skeleton of a feature-learning experiment on the CIFAR-10 images.
 * <p>
 * {@link #run()} samples random patches from the training images, hands them
 * to {@link #learnFeatures(double[][])}, describes every training and test
 * image with {@link #extractFeatures(MBFImage)}, z-score normalises the
 * resulting vectors, trains a liblinear multiclass annotator on the training
 * vectors and reports the fraction of test images whose first predicted class
 * matches the dataset group they came from.
 * <p>
 * Subclasses supply the feature learning and feature extraction steps.
 * {@link #extractFeatures(MBFImage)} is invoked from within
 * {@link Parallel#forRange} for the training images, so implementations should
 * be safe to call from several threads at once.
 */
public abstract class CIFARExperimentFramework {
	/** Width and height, in pixels, of each sampled patch. */
	protected final int patchSize = 6;

	/** Number of patches requested from the sampler for feature learning. */
	protected final int numPatches = 400000;

	/**
	 * Value passed to the {@link LiblinearAnnotator} constructor as its C
	 * argument.
	 */
	protected final int C = 1;

	/**
	 * Learn the feature representation from the sampled patches.
	 *
	 * @param patches
	 *            one row per sampled patch, each row being the patch's pixel
	 *            vector as returned by {@link MBFImage#getDoublePixelVector()}
	 */
	protected abstract void learnFeatures(double[][] patches);

	/**
	 * Compute the feature vector describing a single image. Called after
	 * {@link #learnFeatures(double[][])} has completed.
	 *
	 * @param image
	 *            the image to describe
	 * @return the feature vector for the image
	 */
	protected abstract double[] extractFeatures(MBFImage image);

	/**
	 * Run the complete experiment: patch sampling, feature learning, feature
	 * extraction, normalisation, classifier training and evaluation on the
	 * test images. Progress and the final accuracy are written to
	 * {@link System#out}.
	 *
	 * @return the test accuracy, i.e. the number of correctly classified test
	 *         images divided by the total number of test images
	 * @throws IOException
	 *             if the training or test data cannot be loaded
	 */
	public double run() throws IOException {
		// load training data
		final GroupedDataset<String, ListDataset<MBFImage>, MBFImage> trainingdata = CIFAR10Dataset
				.getTrainingImages(CIFAR10Dataset.MBFIMAGE_READER);

		// create random patches
		final RandomPatchSampler<MBFImage> sampler =
				new RandomPatchSampler<MBFImage>(trainingdata, patchSize, patchSize, numPatches);

		final double[][] patches = new double[numPatches][];
		int i = 0;
		for (final MBFImage p : sampler) {
			patches[i++] = p.getDoublePixelVector();

			if (i % 10000 == 0) {
				System.out.format("Extracting patch %d / %d\n", i, numPatches);
			}
		}

		// Perform feature learning
		learnFeatures(patches);

		// flatten the grouped training set into parallel image / label arrays
		// so that feature extraction can be driven by index range
		final int numTraining = trainingdata.numInstances();
		final MBFImage[] trimages = new MBFImage[numTraining];
		final String[] trclasses = new String[numTraining];
		final double[][] trfeatures = new double[numTraining][];
		i = 0;
		for (final String clz : trainingdata.getGroups()) {
			for (final MBFImage p : trainingdata.get(clz)) {
				trimages[i] = p;
				trclasses[i] = clz;
				i++;
			}
		}

		// extract training features in parallel
		Parallel.forRange(0, trimages.length, 1, new Operation<Parallel.IntRange>() {
			// Progress counter shared by every invocation of perform(). An
			// AtomicInteger is used because a plain or volatile int updated
			// with ++ is a non-atomic read-modify-write and would lose
			// increments if perform() runs on several threads at once.
			private final AtomicInteger count = new AtomicInteger();

			@Override
			public void perform(Parallel.IntRange range) {
				for (int ii = range.start; ii < range.stop; ii++) {
					// claim a unique progress index before doing the work so
					// each multiple of 100 is reported exactly once
					final int started = count.getAndIncrement();
					if (started % 100 == 0) {
						System.out.format("Extracting features %d / %d\n", started, numTraining);
					}
					trfeatures[ii] = extractFeatures(trimages[ii]);
				}
			}
		});

		// feature normalisation
		final ZScore z = new ZScore(0.01);
		z.train(trfeatures);
		final double[][] trfeaturesz = z.normalise(trfeatures);

		// train linear SVM
		final LiblinearAnnotator<double[], String> ann =
				new LiblinearAnnotator<double[], String>(new FeatureExtractor<DoubleFV, double[]>() {
					@Override
					public DoubleFV extractFeature(double[] object) {
						return new DoubleFV(object);
					}
				}, LiblinearAnnotator.Mode.MULTICLASS,
						SolverType.L2R_L2LOSS_SVC_DUAL, C, 0.1, 1 /* bias */, true);

		ann.train(AnnotatedObject.createList(trfeaturesz, trclasses));

		// load test data
		final GroupedDataset<String, ListDataset<MBFImage>, MBFImage> testdata = CIFAR10Dataset
				.getTestImages(CIFAR10Dataset.MBFIMAGE_READER);

		// extract test features
		final int numTest = testdata.numInstances();
		final String[] teclasses = new String[numTest];
		double[][] tefeatures = new double[numTest][];
		i = 0;
		for (final String clz : testdata.getGroups()) {
			for (final MBFImage p : testdata.get(clz)) {
				tefeatures[i] = extractFeatures(p);
				teclasses[i] = clz;
				i++;
			}
		}

		// feature normalisation (using the statistics learned from the
		// training data)
		tefeatures = z.normalise(tefeatures);

		// perform classification and calculate accuracy
		double correct = 0, incorrect = 0;
		for (i = 0; i < tefeatures.length; i++) {
			final ClassificationResult<String> res = ann.classify(tefeatures[i]);

			// only the first predicted class is compared with the ground truth
			if (res.getPredictedClasses().iterator().next().equals(teclasses[i])) {
				correct++;
			} else {
				incorrect++;
			}
		}
		final double acc = correct / (correct + incorrect);
		System.out.println("Test accuracy " + acc);
		return acc;
	}
}