1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package org.openimaj.workinprogress.featlearn.cifarexps;
31
32 import java.io.IOException;
33
34 import org.openimaj.data.dataset.GroupedDataset;
35 import org.openimaj.data.dataset.ListDataset;
36 import org.openimaj.experiment.evaluation.classification.ClassificationResult;
37 import org.openimaj.feature.DoubleFV;
38 import org.openimaj.feature.FeatureExtractor;
39 import org.openimaj.image.MBFImage;
40 import org.openimaj.image.annotation.evaluation.datasets.CIFAR10Dataset;
41 import org.openimaj.math.statistics.normalisation.ZScore;
42 import org.openimaj.ml.annotation.AnnotatedObject;
43 import org.openimaj.ml.annotation.linear.LiblinearAnnotator;
44 import org.openimaj.util.function.Operation;
45 import org.openimaj.util.parallel.Parallel;
46 import org.openimaj.workinprogress.featlearn.RandomPatchSampler;
47
48 import de.bwaldvogel.liblinear.SolverType;
49
50 public abstract class CIFARExperimentFramework {
51 protected final int patchSize = 6;
52 protected final int numPatches = 400000;
53 protected final int C = 1;
54
55 protected abstract void learnFeatures(double[][] patches);
56
57 protected abstract double[] extractFeatures(MBFImage image);
58
59 public double run() throws IOException {
60
61 final GroupedDataset<String, ListDataset<MBFImage>, MBFImage> trainingdata = CIFAR10Dataset
62 .getTrainingImages(CIFAR10Dataset.MBFIMAGE_READER);
63
64
65 final RandomPatchSampler<MBFImage> sampler =
66 new RandomPatchSampler<MBFImage>(trainingdata, patchSize, patchSize, numPatches);
67
68 final double[][] patches = new double[numPatches][];
69 int i = 0;
70 for (final MBFImage p : sampler) {
71 patches[i++] = p.getDoublePixelVector();
72
73 if (i % 10000 == 0)
74 System.out.format("Extracting patch %d / %d\n", i, numPatches);
75 }
76
77
78 learnFeatures(patches);
79
80
81 final MBFImage[] trimages = new MBFImage[trainingdata.numInstances()];
82 final String[] trclasses = new String[trainingdata.numInstances()];
83 final double[][] trfeatures = new double[trainingdata.numInstances()][];
84 i = 0;
85 for (final String clz : trainingdata.getGroups()) {
86 for (final MBFImage p : trainingdata.get(clz)) {
87
88 trimages[i] = p;
89 trclasses[i] = clz;
90 i++;
91 }
92 }
93
94
95
96
97
98
99 Parallel.forRange(0, trimages.length, 1, new Operation<Parallel.IntRange>() {
100 volatile int count = 0;
101
102 @Override
103 public void perform(Parallel.IntRange range) {
104 for (int ii = range.start; ii < range.stop; ii++) {
105 if (count % 100 == 0)
106 System.out.format("Extracting features %d / %d\n", count, trainingdata.numInstances());
107 trfeatures[ii] = extractFeatures(trimages[ii]);
108 count++;
109 }
110 }
111 });
112
113
114 final ZScore z = new ZScore(0.01);
115 z.train(trfeatures);
116 final double[][] trfeaturesz = z.normalise(trfeatures);
117
118
119 final LiblinearAnnotator<double[], String> ann =
120 new LiblinearAnnotator<double[], String>(new FeatureExtractor<DoubleFV, double[]>() {
121 @Override
122 public DoubleFV extractFeature(double[] object) {
123 return new DoubleFV(object);
124 }
125 }, LiblinearAnnotator.Mode.MULTICLASS,
126 SolverType.L2R_L2LOSS_SVC_DUAL, C, 0.1, 1
127
128 ann.train(AnnotatedObject.createList(trfeaturesz, trclasses));
129
130
131 final GroupedDataset<String, ListDataset<MBFImage>, MBFImage> testdata = CIFAR10Dataset
132 .getTestImages(CIFAR10Dataset.MBFIMAGE_READER);
133
134
135 final String[] teclasses = new String[testdata.numInstances()];
136 double[][] tefeatures = new double[testdata.numInstances()][];
137 i = 0;
138 for (final String clz : testdata.getGroups()) {
139 for (final MBFImage p : testdata.get(clz)) {
140 tefeatures[i] = extractFeatures(p);
141 teclasses[i] = clz;
142 i++;
143 }
144 }
145
146
147
148 tefeatures = z.normalise(tefeatures);
149
150
151 double correct = 0, incorrect = 0;
152 for (i = 0; i < tefeatures.length; i++) {
153 final ClassificationResult<String> res = ann.classify(tefeatures[i]);
154
155 if (res.getPredictedClasses().iterator().next().equals(teclasses[i]))
156 correct++;
157 else
158 incorrect++;
159 }
160 final double acc = correct / (correct + incorrect);
161 System.out.println("Test accuracy " + acc);
162 return acc;
163 }
164 }