/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   *	Redistributions of source code must retain the above copyright notice,
 * 	this list of conditions and the following disclaimer.
 *
 *   *	Redistributions in binary form must reproduce the above copyright notice,
 * 	this list of conditions and the following disclaimer in the documentation
 * 	and/or other materials provided with the distribution.
 *
 *   *	Neither the name of the University of Southampton nor the names of its
 * 	contributors may be used to endorse or promote products derived from this
 * 	software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.workinprogress.featlearn.cifarexps;

import java.io.IOException;
import java.util.concurrent.atomic.AtomicInteger;

import org.openimaj.data.dataset.GroupedDataset;
import org.openimaj.data.dataset.ListDataset;
import org.openimaj.experiment.evaluation.classification.ClassificationResult;
import org.openimaj.feature.DoubleFV;
import org.openimaj.feature.FeatureExtractor;
import org.openimaj.image.MBFImage;
import org.openimaj.image.annotation.evaluation.datasets.CIFAR10Dataset;
import org.openimaj.math.statistics.normalisation.ZScore;
import org.openimaj.ml.annotation.AnnotatedObject;
import org.openimaj.ml.annotation.linear.LiblinearAnnotator;
import org.openimaj.util.function.Operation;
import org.openimaj.util.parallel.Parallel;
import org.openimaj.workinprogress.featlearn.RandomPatchSampler;

import de.bwaldvogel.liblinear.SolverType;

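/**
 * Base class for feature-learning experiments on the CIFAR-10 dataset.
 * {@link #run()} samples random patches from the training images, asks the
 * subclass to learn features from them, extracts and z-score normalises
 * features for the training and test sets, trains a linear SVM with
 * liblinear, and reports classification accuracy on the test set.
 * <p>
 * A minimal sketch of a concrete subclass (the class name and the method
 * bodies below are illustrative placeholders, not part of this framework):
 *
 * <pre>
 * {@code
 * class MyCIFARExperiment extends CIFARExperimentFramework {
 *     protected void learnFeatures(double[][] patches) {
 *         // e.g. whiten the patches and learn a dictionary from them
 *     }
 *
 *     protected double[] extractFeatures(MBFImage image) {
 *         // e.g. encode the image's patches against the dictionary and pool
 *         return new double[0]; // placeholder
 *     }
 * }
 *
 * final double accuracy = new MyCIFARExperiment().run(); // run() throws IOException
 * }
 * </pre>
 */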
public abstract class CIFARExperimentFramework {
	protected final int patchSize = 6;
	protected final int numPatches = 400000;
	protected final int C = 1;

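	/**
	 * Learn features from the given set of sampled patches. Each row of the
	 * array is a single flattened image patch.
	 *
	 * @param patches the training patches to learn features from
	 */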
	protected abstract void learnFeatures(double[][] patches);

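	/**
	 * Extract a feature vector from the given image using the learned
	 * features.
	 *
	 * @param image the image to describe
	 * @return the feature vector for the image
	 */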
	protected abstract double[] extractFeatures(MBFImage image);

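	/**
	 * Run the experiment: sample random patches from the CIFAR-10 training
	 * images, learn features from them, extract and z-score normalise
	 * features for the training set, train a linear SVM, and then measure
	 * classification accuracy on the test set.
	 *
	 * @return the classification accuracy on the test set
	 * @throws IOException if the CIFAR-10 data cannot be read
	 */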
	public double run() throws IOException {
		// load training data
		final GroupedDataset<String, ListDataset<MBFImage>, MBFImage> trainingdata = CIFAR10Dataset
				.getTrainingImages(CIFAR10Dataset.MBFIMAGE_READER);

		// create random patches
		final RandomPatchSampler<MBFImage> sampler =
				new RandomPatchSampler<MBFImage>(trainingdata, patchSize, patchSize, numPatches);

		final double[][] patches = new double[numPatches][];
		int i = 0;
		for (final MBFImage p : sampler) {
			patches[i++] = p.getDoublePixelVector();

			if (i % 10000 == 0)
				System.out.format("Extracting patch %d / %d\n", i, numPatches);
		}

		// Perform feature learning
		learnFeatures(patches);

		// extract features
		final MBFImage[] trimages = new MBFImage[trainingdata.numInstances()];
		final String[] trclasses = new String[trainingdata.numInstances()];
		final double[][] trfeatures = new double[trainingdata.numInstances()][];
		i = 0;
		for (final String clz : trainingdata.getGroups()) {
			for (final MBFImage p : trainingdata.get(clz)) {
				trimages[i] = p;
				trclasses[i] = clz;
				i++;
			}
		}

		// extract the training features in parallel
		Parallel.forRange(0, trimages.length, 1, new Operation<Parallel.IntRange>() {
			// thread-safe progress counter shared by the worker threads
			private final AtomicInteger count = new AtomicInteger(0);

			@Override
			public void perform(Parallel.IntRange range) {
				for (int ii = range.start; ii < range.stop; ii++) {
					final int c = count.getAndIncrement();
					if (c % 100 == 0)
						System.out.format("Extracting features %d / %d\n", c, trainingdata.numInstances());
					trfeatures[ii] = extractFeatures(trimages[ii]);
				}
			}
		});

		// feature normalisation
		final ZScore z = new ZScore(0.01);
		z.train(trfeatures);
		final double[][] trfeaturesz = z.normalise(trfeatures);

		// train linear SVM
		final LiblinearAnnotator<double[], String> ann =
				new LiblinearAnnotator<double[], String>(new FeatureExtractor<DoubleFV, double[]>() {
					@Override
					public DoubleFV extractFeature(double[] object) {
						return new DoubleFV(object);
					}
				}, LiblinearAnnotator.Mode.MULTICLASS,
				SolverType.L2R_L2LOSS_SVC_DUAL, C, 0.1, 1 /* bias */, true);

		ann.train(AnnotatedObject.createList(trfeaturesz, trclasses));

		// load test data
		final GroupedDataset<String, ListDataset<MBFImage>, MBFImage> testdata = CIFAR10Dataset
				.getTestImages(CIFAR10Dataset.MBFIMAGE_READER);

		// extract test features
		final String[] teclasses = new String[testdata.numInstances()];
		double[][] tefeatures = new double[testdata.numInstances()][];
		i = 0;
		for (final String clz : testdata.getGroups()) {
			for (final MBFImage p : testdata.get(clz)) {
				tefeatures[i] = extractFeatures(p);
				teclasses[i] = clz;
				i++;
			}
		}

		// feature normalisation (using mean and stddev learned from training data)
		tefeatures = z.normalise(tefeatures);

		// perform classification and calculate accuracy
		double correct = 0, incorrect = 0;
		for (i = 0; i < tefeatures.length; i++) {
			final ClassificationResult<String> res = ann.classify(tefeatures[i]);

			if (res.getPredictedClasses().iterator().next().equals(teclasses[i]))
				correct++;
			else
				incorrect++;
		}
		final double acc = correct / (correct + incorrect);
		System.out.println("Test accuracy " + acc);
		return acc;
	}
}