View Javadoc

1   /**
2    * Copyright (c) 2011, The University of Southampton and the individual contributors.
3    * All rights reserved.
4    *
5    * Redistribution and use in source and binary forms, with or without modification,
6    * are permitted provided that the following conditions are met:
7    *
8    *   * 	Redistributions of source code must retain the above copyright notice,
9    * 	this list of conditions and the following disclaimer.
10   *
11   *   *	Redistributions in binary form must reproduce the above copyright notice,
12   * 	this list of conditions and the following disclaimer in the documentation
13   * 	and/or other materials provided with the distribution.
14   *
15   *   *	Neither the name of the University of Southampton nor the names of its
16   * 	contributors may be used to endorse or promote products derived from this
17   * 	software without specific prior written permission.
18   *
19   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21   * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22   * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
23   * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24   * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26   * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27   * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28   * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29   */
30  package org.openimaj.experiment.gmm.retrieval;
31  
32  import java.io.File;
33  import java.io.IOException;
34  import java.io.InputStream;
35  import java.net.URL;
36  import java.util.ArrayList;
37  import java.util.Collections;
38  import java.util.Comparator;
39  import java.util.HashMap;
40  import java.util.List;
41  import java.util.Map;
42  import java.util.concurrent.Executors;
43  import java.util.concurrent.ThreadPoolExecutor;
44  
45  import org.apache.commons.vfs2.FileObject;
46  import org.apache.commons.vfs2.FileSystemException;
47  import org.kohsuke.args4j.CmdLineException;
48  import org.kohsuke.args4j.CmdLineParser;
49  import org.kohsuke.args4j.Option;
50  import org.openimaj.data.identity.Identifiable;
51  import org.openimaj.feature.CachingFeatureExtractor;
52  import org.openimaj.feature.DiskCachingFeatureExtractor;
53  import org.openimaj.feature.FeatureExtractor;
54  import org.openimaj.feature.FeatureVector;
55  import org.openimaj.feature.local.LocalFeature;
56  import org.openimaj.feature.local.list.LocalFeatureList;
57  import org.openimaj.image.FImage;
58  import org.openimaj.image.ImageUtilities;
59  import org.openimaj.image.processing.resize.ResizeProcessor;
60  import org.openimaj.io.ObjectReader;
61  import org.openimaj.math.statistics.distribution.MixtureOfGaussians;
62  import org.openimaj.math.statistics.distribution.metrics.SampledMultivariateDistanceComparator;
63  import org.openimaj.ml.gmm.GaussianMixtureModelEM.CovarianceType;
64  import org.openimaj.util.function.Function;
65  import org.openimaj.util.function.Operation;
66  import org.openimaj.util.pair.IndependentPair;
67  import org.openimaj.util.pair.IntDoublePair;
68  import org.openimaj.util.parallel.Parallel;
69  import org.openimaj.util.parallel.GlobalExecutorPool.DaemonThreadFactory;
70  
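/**
 * Image retrieval experiment over the UKBench dataset. Each image is
 * described by a mixture of Gaussians fitted to its dense SIFT features, and
 * a query image is scored by how many of the four top-ranked images belong to
 * the same object.
 *
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 */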
75  public class UKBenchGMMExperiment {
76  	private final class FImageFileObjectReader implements
77  			ObjectReader<FImage, FileObject> {
78  		@Override
79  		public FImage read(FileObject source) throws IOException {
80  			return ImageUtilities.FIMAGE_READER.read(source.getContent()
81  					.getInputStream());
82  		}
83  
84  		@Override
85  		public boolean canRead(FileObject source, String name) {
86  			InputStream inputStream = null;
87  			try {
88  				inputStream = source.getContent().getInputStream();
89  				return ImageUtilities.FIMAGE_READER.canRead(inputStream, name);
90  			} catch (FileSystemException e) {
91  			} finally {
92  				if (inputStream != null) {
93  					try {
94  						inputStream.close();
95  					} catch (IOException e) {
96  						throw new RuntimeException(e);
97  					}
98  				}
99  			}
100 			return false;
101 		}
102 	}
103 
104 	private final class URLFileObjectReader implements
105 			ObjectReader<URL, FileObject> {
106 		@Override
107 		public URL read(FileObject source) throws IOException {
108 			return source.getURL();
109 		}
110 
111 		@Override
112 		public boolean canRead(FileObject source, String name) {
113 			try {
114 				return (source.getURL() != null);
115 			} catch (FileSystemException e) {
116 				return false;
117 			}
118 		}
119 	}
120 
121 	private static final class IRecordWrapper<A, B> implements
122 			Function<UKBenchGMMExperiment.IRecord<A>, B> {
123 		Function<A, B> inner;
124 
125 		public IRecordWrapper(Function<A, B> extract) {
126 			this.inner = extract;
127 		}
128 
129 		@Override
130 		public B apply(IRecord<A> in) {
131 			return inner.apply(in.image);
132 		}
133 
134 		public static <A, B> Function<IRecord<A>, B> wrap(Function<A, B> extract) {
135 			return new IRecordWrapper<A, B>(extract);
136 		}
137 	}
138 
139 	private static class IRecord<IMAGE> implements Identifiable {
140 
141 		private String id;
142 		private IMAGE image;
143 
144 		public IRecord(String id, IMAGE image) {
145 			this.id = id;
146 			this.image = image;
147 		}
148 
149 		@Override
150 		public String getID() {
151 			return this.id;
152 		}
153 
154 		public static <A> IRecord<A> wrap(String id, A payload) {
155 			return new IRecord<A>(id, payload);
156 		}
157 
158 	}
159 
160 	private static final class IRecordReader<IMAGE> implements
161 			ObjectReader<IRecord<IMAGE>, FileObject> {
162 		ObjectReader<IMAGE, FileObject> reader;
163 
164 		public IRecordReader(ObjectReader<IMAGE, FileObject> reader) {
165 			this.reader = reader;
166 		}
167 
168 		@Override
169 		public IRecord<IMAGE> read(FileObject source) throws IOException {
170 			String name = source.getName().getBaseName();
171 			IMAGE image = reader.read(source);
172 			return new IRecord<IMAGE>(name, image);
173 		}
174 
175 		@Override
176 		public boolean canRead(FileObject source, String name) {
177 			return reader.canRead(source, name);
178 		}
179 	}
180 
181 	private String ukbenchRoot = "/Users/ss/Experiments/ukbench";
182 	private ResizeProcessor resize;
183 	private UKBenchGroupDataset<IRecord<URL>> dataset;
184 	private FeatureExtractor<MixtureOfGaussians,IRecord<URL>> gmmExtract;
185 	final SampledMultivariateDistanceComparator comp = new SampledMultivariateDistanceComparator();
186 
187 	public UKBenchGMMExperiment() {
188 		setup();
189 	}
190 
191 	public UKBenchGMMExperiment(String root) {
192 		this.ukbenchRoot = root;
193 		setup();
194 	}
195 
196 	private void setup() {
197 		this.dataset = new UKBenchGroupDataset<IRecord<URL>>(
198 				ukbenchRoot + "/full",
199 				// new IRecordReader<FImage>(new FImageFileObjectReader())
200 				new IRecordReader<URL>(new URLFileObjectReader()));
201 
202 		resize = new ResizeProcessor(640, 480);
203 
204 		Function<URL, MixtureOfGaussians> combined = new Function<URL, MixtureOfGaussians>() {
205 
206 			@Override
207 			public MixtureOfGaussians apply(URL in) {
208 				
209 				final DSiftFeatureExtractor feature = new DSiftFeatureExtractor();
210 				final GMMFromFeatures gmmFunc = new GMMFromFeatures(3,CovarianceType.Diagonal);
211 				System.out.println("... resize");
212 				FImage process = null;
213 				try {
214 					process = ImageUtilities.readF(in).process(resize);
215 				} catch (IOException e) {
216 					throw new RuntimeException(e);
217 				}
218 				System.out.println("... dsift");
219 				LocalFeatureList<? extends LocalFeature<?, ? extends FeatureVector>> apply = feature
220 						.apply(process);
221 				System.out.println("... gmm");
222 				return gmmFunc.apply(apply);
223 			}
224 
225 		};
226 		this.gmmExtract = new CachingFeatureExtractor<MixtureOfGaussians, IRecord<URL>>(
227 				new DiskCachingFeatureExtractor<MixtureOfGaussians, IRecord<URL>>(
228 						new File(ukbenchRoot + "/gmm/dsift"),
229 						FeatureExtractionFunction.wrap(IRecordWrapper.wrap(combined)))
230 		);
231 	}
232 
233 	static class UKBenchGMMExperimentOptions {
234 		@Option(name = "--input", aliases = "-i", required = true, usage = "Input location", metaVar = "STRING")
235 		String input = null;
236 
237 		@Option(name = "--pre-extract-all", aliases = "-a", required = false, usage = "Preextract all", metaVar = "BOOLEAN")
238 		boolean preextract = false;
239 		
240 		@Option(name = "--object", aliases = "-obj", required = false, usage = "Object", metaVar = "Integer")
241 		int object = -1;
242 		
243 		@Option(name = "--image", aliases = "-img", required = false, usage = "Image", metaVar = "Integer")
244 		int image = -1;
245 	}
246 
247 	static class ObjectRecord extends IndependentPair<Integer, IRecord<URL>> {
248 
249 		public ObjectRecord(Integer obj1, IRecord<URL> obj2) {
250 			super(obj1, obj2);
251 		}
252 
253 	}
254 
255 	/**
256 	 * @param args
257 	 * @throws IOException
258 	 * @throws CmdLineException
259 	 */
260 	public static void main(String[] args) throws IOException, CmdLineException {
261 		UKBenchGMMExperimentOptions opts = new UKBenchGMMExperimentOptions();
262 		final CmdLineParser parser = new CmdLineParser(opts);
263 		parser.parseArgument(args);
264 		final UKBenchGMMExperiment exp = new UKBenchGMMExperiment(opts.input);
265 		if (opts.preextract){
266 			System.out.println("Preloading all ukbench features...");
267 			exp.extractGroupGaussians();			
268 		}
269 		
270 		if(opts.object == -1 || opts.image == -1){			
271 			exp.applyToEachGroup(new Operation<UKBenchListDataset<IRecord<URL>>>() {
272 				
273 				@Override
274 				public void perform(UKBenchListDataset<IRecord<URL>> group) {
275 					int object = group.getObject();
276 					for (int i = 0; i < group.size(); i++) {
277 						double score = exp.score(object, i);
278 						System.out.printf("Object %d, image %d, score: %2.2f\n",object,i,score);
279 					}
280 				}
281 			});
282 		} else {
283 			double score = exp.score(opts.object, opts.image);
284 			System.out.printf("Object %d, image %d, score: %2.2f\n",opts.object,opts.image,score);
285 		}
286 	}
287 
288 	protected MixtureOfGaussians extract(IRecord<URL> item) {
289 		return this.gmmExtract.extractFeature(item);
290 	}
291 
292 	private void applyToEachGroup(Operation<UKBenchListDataset<IRecord<URL>>> operation) {
293 		for (int i = 0; i < this.dataset.size(); i++) {
294 			operation.perform(this.dataset.get(i));
295 		}
296 
297 	}
298 
299 	private void applyToEachImage(Operation<ObjectRecord> operation) {
300 		for (int i = 0; i < this.dataset.size(); i++) {
301 			UKBenchListDataset<IRecord<URL>> ukBenchListDataset = this.dataset.get(i);
302 			for (IRecord<URL> iRecord : ukBenchListDataset) {
303 				operation.perform(new ObjectRecord(i, iRecord));
304 			}
305 		}
306 	}
307 	
308 	public double score(int object, int image) {
309 		System.out.printf("Scoring Object %d, Image %d\n",object,image);
310 		IRecord<URL> item = this.dataset.get(object).get(image);
311 		final MixtureOfGaussians thisGMM = extract(item);
312 		final List<IntDoublePair> scored = new ArrayList<IntDoublePair>();
313 		applyToEachImage(new Operation<UKBenchGMMExperiment.ObjectRecord>() {
314 
315 			@Override
316 			public void perform(ObjectRecord object) {
317 				MixtureOfGaussians otherGMM = extract(object.getSecondObject());
318 				
319 				double distance = comp.compare(thisGMM, otherGMM);
320 				scored.add(IntDoublePair.pair(object.firstObject(), distance));
321 				if(scored.size() % 200 == 0){
322 					System.out.printf("Loaded: %2.1f%%\n", 100 * (float)scored.size() / (dataset.size()*4));
323 				}
324 			}
325 		});
326 		
327 		Collections.sort(scored, new Comparator<IntDoublePair>(){
328 
329 			@Override
330 			public int compare(IntDoublePair o1, IntDoublePair o2) {
331 				return -Double.compare(o1.second, o2.second);
332 			}
333 			
334 		});
335 		double good = 0;
336 		for (int i = 0; i < 4; i++) {
337 			if(scored.get(i).first == object) good+=1; 
338 		}
339 		return good/4f;
340 	}
341 
342 	/**
343 	 * @return the mixture of gaussians for each group
344 	 */
345 	public Map<Integer, List<MixtureOfGaussians>> extractGroupGaussians() {
346 		final Map<Integer, List<MixtureOfGaussians>> groups = new HashMap<Integer, List<MixtureOfGaussians>>();
347 		ThreadPoolExecutor pool = (ThreadPoolExecutor) Executors
348 				.newFixedThreadPool(1,
349 						new DaemonThreadFactory());
350 		final double TOTAL = this.dataset.size() * 4;
351 		Parallel.forIndex(0, this.dataset.size(), 1, new Operation<Integer>() {
352 
353 			@Override
354 			public void perform(Integer i) {
355 				groups.put(i, extractGroupGaussians(i));
356 				if(groups.size() % 200 == 0){
357 					System.out.printf("Loaded: %2.1f%%\n", 100 * groups.size() * 4 / TOTAL);
358 				}
359 			}
360 		}, pool);
361 
362 		return groups;
363 	}
364 
365 	public List<MixtureOfGaussians> extractGroupGaussians(int i) {
366 		return this.extractGroupGaussians(this.dataset.get(i));
367 	}
368 
369 	public List<MixtureOfGaussians> extractGroupGaussians( UKBenchListDataset<IRecord<URL>> ukbenchObject) {
370 		List<MixtureOfGaussians> gaussians = new ArrayList<MixtureOfGaussians>();
371 		int i = 0;
372 		for (IRecord<URL> imageURL : ukbenchObject) {
373 			MixtureOfGaussians gmm = gmmExtract.extractFeature(imageURL);
374 			gaussians.add(gmm);
375 		}
376 		return gaussians;
377 	}
378 
379 }