1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package org.openimaj.experiment.gmm.retrieval;
31
32 import java.io.File;
33 import java.io.IOException;
34 import java.io.InputStream;
35 import java.net.URL;
36 import java.util.ArrayList;
37 import java.util.Collections;
38 import java.util.Comparator;
39 import java.util.HashMap;
40 import java.util.List;
41 import java.util.Map;
42 import java.util.concurrent.Executors;
43 import java.util.concurrent.ThreadPoolExecutor;
44
45 import org.apache.commons.vfs2.FileObject;
46 import org.apache.commons.vfs2.FileSystemException;
47 import org.kohsuke.args4j.CmdLineException;
48 import org.kohsuke.args4j.CmdLineParser;
49 import org.kohsuke.args4j.Option;
50 import org.openimaj.data.identity.Identifiable;
51 import org.openimaj.feature.CachingFeatureExtractor;
52 import org.openimaj.feature.DiskCachingFeatureExtractor;
53 import org.openimaj.feature.FeatureExtractor;
54 import org.openimaj.feature.FeatureVector;
55 import org.openimaj.feature.local.LocalFeature;
56 import org.openimaj.feature.local.list.LocalFeatureList;
57 import org.openimaj.image.FImage;
58 import org.openimaj.image.ImageUtilities;
59 import org.openimaj.image.processing.resize.ResizeProcessor;
60 import org.openimaj.io.ObjectReader;
61 import org.openimaj.math.statistics.distribution.MixtureOfGaussians;
62 import org.openimaj.math.statistics.distribution.metrics.SampledMultivariateDistanceComparator;
63 import org.openimaj.ml.gmm.GaussianMixtureModelEM.CovarianceType;
64 import org.openimaj.util.function.Function;
65 import org.openimaj.util.function.Operation;
66 import org.openimaj.util.pair.IndependentPair;
67 import org.openimaj.util.pair.IntDoublePair;
68 import org.openimaj.util.parallel.Parallel;
69 import org.openimaj.util.parallel.GlobalExecutorPool.DaemonThreadFactory;
70
71
72
73
74
75 public class UKBenchGMMExperiment {
76 private final class FImageFileObjectReader implements
77 ObjectReader<FImage, FileObject> {
78 @Override
79 public FImage read(FileObject source) throws IOException {
80 return ImageUtilities.FIMAGE_READER.read(source.getContent()
81 .getInputStream());
82 }
83
84 @Override
85 public boolean canRead(FileObject source, String name) {
86 InputStream inputStream = null;
87 try {
88 inputStream = source.getContent().getInputStream();
89 return ImageUtilities.FIMAGE_READER.canRead(inputStream, name);
90 } catch (FileSystemException e) {
91 } finally {
92 if (inputStream != null) {
93 try {
94 inputStream.close();
95 } catch (IOException e) {
96 throw new RuntimeException(e);
97 }
98 }
99 }
100 return false;
101 }
102 }
103
104 private final class URLFileObjectReader implements
105 ObjectReader<URL, FileObject> {
106 @Override
107 public URL read(FileObject source) throws IOException {
108 return source.getURL();
109 }
110
111 @Override
112 public boolean canRead(FileObject source, String name) {
113 try {
114 return (source.getURL() != null);
115 } catch (FileSystemException e) {
116 return false;
117 }
118 }
119 }
120
121 private static final class IRecordWrapper<A, B> implements
122 Function<UKBenchGMMExperiment.IRecord<A>, B> {
123 Function<A, B> inner;
124
125 public IRecordWrapper(Function<A, B> extract) {
126 this.inner = extract;
127 }
128
129 @Override
130 public B apply(IRecord<A> in) {
131 return inner.apply(in.image);
132 }
133
134 public static <A, B> Function<IRecord<A>, B> wrap(Function<A, B> extract) {
135 return new IRecordWrapper<A, B>(extract);
136 }
137 }
138
139 private static class IRecord<IMAGE> implements Identifiable {
140
141 private String id;
142 private IMAGE image;
143
144 public IRecord(String id, IMAGE image) {
145 this.id = id;
146 this.image = image;
147 }
148
149 @Override
150 public String getID() {
151 return this.id;
152 }
153
154 public static <A> IRecord<A> wrap(String id, A payload) {
155 return new IRecord<A>(id, payload);
156 }
157
158 }
159
160 private static final class IRecordReader<IMAGE> implements
161 ObjectReader<IRecord<IMAGE>, FileObject> {
162 ObjectReader<IMAGE, FileObject> reader;
163
164 public IRecordReader(ObjectReader<IMAGE, FileObject> reader) {
165 this.reader = reader;
166 }
167
168 @Override
169 public IRecord<IMAGE> read(FileObject source) throws IOException {
170 String name = source.getName().getBaseName();
171 IMAGE image = reader.read(source);
172 return new IRecord<IMAGE>(name, image);
173 }
174
175 @Override
176 public boolean canRead(FileObject source, String name) {
177 return reader.canRead(source, name);
178 }
179 }
180
181 private String ukbenchRoot = "/Users/ss/Experiments/ukbench";
182 private ResizeProcessor resize;
183 private UKBenchGroupDataset<IRecord<URL>> dataset;
184 private FeatureExtractor<MixtureOfGaussians,IRecord<URL>> gmmExtract;
185 final SampledMultivariateDistanceComparator comp = new SampledMultivariateDistanceComparator();
186
/**
 * Build the experiment against the default data root held in
 * {@code ukbenchRoot}.
 */
public UKBenchGMMExperiment() {
    this.setup();
}

/**
 * Build the experiment against an explicit data root.
 *
 * @param root
 *            directory that replaces the default data root
 */
public UKBenchGMMExperiment(String root) {
    ukbenchRoot = root;
    this.setup();
}
195
196 private void setup() {
197 this.dataset = new UKBenchGroupDataset<IRecord<URL>>(
198 ukbenchRoot + "/full",
199
200 new IRecordReader<URL>(new URLFileObjectReader()));
201
202 resize = new ResizeProcessor(640, 480);
203
204 Function<URL, MixtureOfGaussians> combined = new Function<URL, MixtureOfGaussians>() {
205
206 @Override
207 public MixtureOfGaussians apply(URL in) {
208
209 final DSiftFeatureExtractor feature = new DSiftFeatureExtractor();
210 final GMMFromFeatures gmmFunc = new GMMFromFeatures(3,CovarianceType.Diagonal);
211 System.out.println("... resize");
212 FImage process = null;
213 try {
214 process = ImageUtilities.readF(in).process(resize);
215 } catch (IOException e) {
216 throw new RuntimeException(e);
217 }
218 System.out.println("... dsift");
219 LocalFeatureList<? extends LocalFeature<?, ? extends FeatureVector>> apply = feature
220 .apply(process);
221 System.out.println("... gmm");
222 return gmmFunc.apply(apply);
223 }
224
225 };
226 this.gmmExtract = new CachingFeatureExtractor<MixtureOfGaussians, IRecord<URL>>(
227 new DiskCachingFeatureExtractor<MixtureOfGaussians, IRecord<URL>>(
228 new File(ukbenchRoot + "/gmm/dsift"),
229 FeatureExtractionFunction.wrap(IRecordWrapper.wrap(combined)))
230 );
231 }
232
233 static class UKBenchGMMExperimentOptions {
234 @Option(name = "--input", aliases = "-i", required = true, usage = "Input location", metaVar = "STRING")
235 String input = null;
236
237 @Option(name = "--pre-extract-all", aliases = "-a", required = false, usage = "Preextract all", metaVar = "BOOLEAN")
238 boolean preextract = false;
239
240 @Option(name = "--object", aliases = "-obj", required = false, usage = "Object", metaVar = "Integer")
241 int object = -1;
242
243 @Option(name = "--image", aliases = "-img", required = false, usage = "Image", metaVar = "Integer")
244 int image = -1;
245 }
246
247 static class ObjectRecord extends IndependentPair<Integer, IRecord<URL>> {
248
249 public ObjectRecord(Integer obj1, IRecord<URL> obj2) {
250 super(obj1, obj2);
251 }
252
253 }
254
255
256
257
258
259
260 public static void main(String[] args) throws IOException, CmdLineException {
261 UKBenchGMMExperimentOptions opts = new UKBenchGMMExperimentOptions();
262 final CmdLineParser parser = new CmdLineParser(opts);
263 parser.parseArgument(args);
264 final UKBenchGMMExperiment exp = new UKBenchGMMExperiment(opts.input);
265 if (opts.preextract){
266 System.out.println("Preloading all ukbench features...");
267 exp.extractGroupGaussians();
268 }
269
270 if(opts.object == -1 || opts.image == -1){
271 exp.applyToEachGroup(new Operation<UKBenchListDataset<IRecord<URL>>>() {
272
273 @Override
274 public void perform(UKBenchListDataset<IRecord<URL>> group) {
275 int object = group.getObject();
276 for (int i = 0; i < group.size(); i++) {
277 double score = exp.score(object, i);
278 System.out.printf("Object %d, image %d, score: %2.2f\n",object,i,score);
279 }
280 }
281 });
282 } else {
283 double score = exp.score(opts.object, opts.image);
284 System.out.printf("Object %d, image %d, score: %2.2f\n",opts.object,opts.image,score);
285 }
286 }
287
288 protected MixtureOfGaussians extract(IRecord<URL> item) {
289 return this.gmmExtract.extractFeature(item);
290 }
291
292 private void applyToEachGroup(Operation<UKBenchListDataset<IRecord<URL>>> operation) {
293 for (int i = 0; i < this.dataset.size(); i++) {
294 operation.perform(this.dataset.get(i));
295 }
296
297 }
298
299 private void applyToEachImage(Operation<ObjectRecord> operation) {
300 for (int i = 0; i < this.dataset.size(); i++) {
301 UKBenchListDataset<IRecord<URL>> ukBenchListDataset = this.dataset.get(i);
302 for (IRecord<URL> iRecord : ukBenchListDataset) {
303 operation.perform(new ObjectRecord(i, iRecord));
304 }
305 }
306 }
307
308 public double score(int object, int image) {
309 System.out.printf("Scoring Object %d, Image %d\n",object,image);
310 IRecord<URL> item = this.dataset.get(object).get(image);
311 final MixtureOfGaussians thisGMM = extract(item);
312 final List<IntDoublePair> scored = new ArrayList<IntDoublePair>();
313 applyToEachImage(new Operation<UKBenchGMMExperiment.ObjectRecord>() {
314
315 @Override
316 public void perform(ObjectRecord object) {
317 MixtureOfGaussians otherGMM = extract(object.getSecondObject());
318
319 double distance = comp.compare(thisGMM, otherGMM);
320 scored.add(IntDoublePair.pair(object.firstObject(), distance));
321 if(scored.size() % 200 == 0){
322 System.out.printf("Loaded: %2.1f%%\n", 100 * (float)scored.size() / (dataset.size()*4));
323 }
324 }
325 });
326
327 Collections.sort(scored, new Comparator<IntDoublePair>(){
328
329 @Override
330 public int compare(IntDoublePair o1, IntDoublePair o2) {
331 return -Double.compare(o1.second, o2.second);
332 }
333
334 });
335 double good = 0;
336 for (int i = 0; i < 4; i++) {
337 if(scored.get(i).first == object) good+=1;
338 }
339 return good/4f;
340 }
341
342
343
344
345 public Map<Integer, List<MixtureOfGaussians>> extractGroupGaussians() {
346 final Map<Integer, List<MixtureOfGaussians>> groups = new HashMap<Integer, List<MixtureOfGaussians>>();
347 ThreadPoolExecutor pool = (ThreadPoolExecutor) Executors
348 .newFixedThreadPool(1,
349 new DaemonThreadFactory());
350 final double TOTAL = this.dataset.size() * 4;
351 Parallel.forIndex(0, this.dataset.size(), 1, new Operation<Integer>() {
352
353 @Override
354 public void perform(Integer i) {
355 groups.put(i, extractGroupGaussians(i));
356 if(groups.size() % 200 == 0){
357 System.out.printf("Loaded: %2.1f%%\n", 100 * groups.size() * 4 / TOTAL);
358 }
359 }
360 }, pool);
361
362 return groups;
363 }
364
365 public List<MixtureOfGaussians> extractGroupGaussians(int i) {
366 return this.extractGroupGaussians(this.dataset.get(i));
367 }
368
369 public List<MixtureOfGaussians> extractGroupGaussians( UKBenchListDataset<IRecord<URL>> ukbenchObject) {
370 List<MixtureOfGaussians> gaussians = new ArrayList<MixtureOfGaussians>();
371 int i = 0;
372 for (IRecord<URL> imageURL : ukbenchObject) {
373 MixtureOfGaussians gmm = gmmExtract.extractFeature(imageURL);
374 gaussians.add(gmm);
375 }
376 return gaussians;
377 }
378
379 }