/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   *	Redistributions of source code must retain the above copyright notice,
 * 	this list of conditions and the following disclaimer.
 *
 *   *	Redistributions in binary form must reproduce the above copyright notice,
 * 	this list of conditions and the following disclaimer in the documentation
 * 	and/or other materials provided with the distribution.
 *
 *   *	Neither the name of the University of Southampton nor the names of its
 * 	contributors may be used to endorse or promote products derived from this
 * 	software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */ 030package org.openimaj.experiment.gmm.retrieval; 031 032import java.io.File; 033import java.io.IOException; 034import java.io.InputStream; 035import java.net.URL; 036import java.util.ArrayList; 037import java.util.Collections; 038import java.util.Comparator; 039import java.util.HashMap; 040import java.util.List; 041import java.util.Map; 042import java.util.concurrent.Executors; 043import java.util.concurrent.ThreadPoolExecutor; 044 045import org.apache.commons.vfs2.FileObject; 046import org.apache.commons.vfs2.FileSystemException; 047import org.kohsuke.args4j.CmdLineException; 048import org.kohsuke.args4j.CmdLineParser; 049import org.kohsuke.args4j.Option; 050import org.openimaj.data.identity.Identifiable; 051import org.openimaj.feature.CachingFeatureExtractor; 052import org.openimaj.feature.DiskCachingFeatureExtractor; 053import org.openimaj.feature.FeatureExtractor; 054import org.openimaj.feature.FeatureVector; 055import org.openimaj.feature.local.LocalFeature; 056import org.openimaj.feature.local.list.LocalFeatureList; 057import org.openimaj.image.FImage; 058import org.openimaj.image.ImageUtilities; 059import org.openimaj.image.processing.resize.ResizeProcessor; 060import org.openimaj.io.ObjectReader; 061import org.openimaj.math.statistics.distribution.MixtureOfGaussians; 062import org.openimaj.math.statistics.distribution.metrics.SampledMultivariateDistanceComparator; 063import org.openimaj.ml.gmm.GaussianMixtureModelEM.CovarianceType; 064import org.openimaj.util.function.Function; 065import org.openimaj.util.function.Operation; 066import org.openimaj.util.pair.IndependentPair; 067import org.openimaj.util.pair.IntDoublePair; 068import org.openimaj.util.parallel.Parallel; 069import org.openimaj.util.parallel.GlobalExecutorPool.DaemonThreadFactory; 070 071/** 072 * 073 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 074 */ 075public class UKBenchGMMExperiment { 076 private final class FImageFileObjectReader implements 077 ObjectReader<FImage, FileObject> { 
078 @Override 079 public FImage read(FileObject source) throws IOException { 080 return ImageUtilities.FIMAGE_READER.read(source.getContent() 081 .getInputStream()); 082 } 083 084 @Override 085 public boolean canRead(FileObject source, String name) { 086 InputStream inputStream = null; 087 try { 088 inputStream = source.getContent().getInputStream(); 089 return ImageUtilities.FIMAGE_READER.canRead(inputStream, name); 090 } catch (FileSystemException e) { 091 } finally { 092 if (inputStream != null) { 093 try { 094 inputStream.close(); 095 } catch (IOException e) { 096 throw new RuntimeException(e); 097 } 098 } 099 } 100 return false; 101 } 102 } 103 104 private final class URLFileObjectReader implements 105 ObjectReader<URL, FileObject> { 106 @Override 107 public URL read(FileObject source) throws IOException { 108 return source.getURL(); 109 } 110 111 @Override 112 public boolean canRead(FileObject source, String name) { 113 try { 114 return (source.getURL() != null); 115 } catch (FileSystemException e) { 116 return false; 117 } 118 } 119 } 120 121 private static final class IRecordWrapper<A, B> implements 122 Function<UKBenchGMMExperiment.IRecord<A>, B> { 123 Function<A, B> inner; 124 125 public IRecordWrapper(Function<A, B> extract) { 126 this.inner = extract; 127 } 128 129 @Override 130 public B apply(IRecord<A> in) { 131 return inner.apply(in.image); 132 } 133 134 public static <A, B> Function<IRecord<A>, B> wrap(Function<A, B> extract) { 135 return new IRecordWrapper<A, B>(extract); 136 } 137 } 138 139 private static class IRecord<IMAGE> implements Identifiable { 140 141 private String id; 142 private IMAGE image; 143 144 public IRecord(String id, IMAGE image) { 145 this.id = id; 146 this.image = image; 147 } 148 149 @Override 150 public String getID() { 151 return this.id; 152 } 153 154 public static <A> IRecord<A> wrap(String id, A payload) { 155 return new IRecord<A>(id, payload); 156 } 157 158 } 159 160 private static final class IRecordReader<IMAGE> 
implements 161 ObjectReader<IRecord<IMAGE>, FileObject> { 162 ObjectReader<IMAGE, FileObject> reader; 163 164 public IRecordReader(ObjectReader<IMAGE, FileObject> reader) { 165 this.reader = reader; 166 } 167 168 @Override 169 public IRecord<IMAGE> read(FileObject source) throws IOException { 170 String name = source.getName().getBaseName(); 171 IMAGE image = reader.read(source); 172 return new IRecord<IMAGE>(name, image); 173 } 174 175 @Override 176 public boolean canRead(FileObject source, String name) { 177 return reader.canRead(source, name); 178 } 179 } 180 181 private String ukbenchRoot = "/Users/ss/Experiments/ukbench"; 182 private ResizeProcessor resize; 183 private UKBenchGroupDataset<IRecord<URL>> dataset; 184 private FeatureExtractor<MixtureOfGaussians,IRecord<URL>> gmmExtract; 185 final SampledMultivariateDistanceComparator comp = new SampledMultivariateDistanceComparator(); 186 187 public UKBenchGMMExperiment() { 188 setup(); 189 } 190 191 public UKBenchGMMExperiment(String root) { 192 this.ukbenchRoot = root; 193 setup(); 194 } 195 196 private void setup() { 197 this.dataset = new UKBenchGroupDataset<IRecord<URL>>( 198 ukbenchRoot + "/full", 199 // new IRecordReader<FImage>(new FImageFileObjectReader()) 200 new IRecordReader<URL>(new URLFileObjectReader())); 201 202 resize = new ResizeProcessor(640, 480); 203 204 Function<URL, MixtureOfGaussians> combined = new Function<URL, MixtureOfGaussians>() { 205 206 @Override 207 public MixtureOfGaussians apply(URL in) { 208 209 final DSiftFeatureExtractor feature = new DSiftFeatureExtractor(); 210 final GMMFromFeatures gmmFunc = new GMMFromFeatures(3,CovarianceType.Diagonal); 211 System.out.println("... resize"); 212 FImage process = null; 213 try { 214 process = ImageUtilities.readF(in).process(resize); 215 } catch (IOException e) { 216 throw new RuntimeException(e); 217 } 218 System.out.println("... dsift"); 219 LocalFeatureList<? extends LocalFeature<?, ? 
extends FeatureVector>> apply = feature 220 .apply(process); 221 System.out.println("... gmm"); 222 return gmmFunc.apply(apply); 223 } 224 225 }; 226 this.gmmExtract = new CachingFeatureExtractor<MixtureOfGaussians, IRecord<URL>>( 227 new DiskCachingFeatureExtractor<MixtureOfGaussians, IRecord<URL>>( 228 new File(ukbenchRoot + "/gmm/dsift"), 229 FeatureExtractionFunction.wrap(IRecordWrapper.wrap(combined))) 230 ); 231 } 232 233 static class UKBenchGMMExperimentOptions { 234 @Option(name = "--input", aliases = "-i", required = true, usage = "Input location", metaVar = "STRING") 235 String input = null; 236 237 @Option(name = "--pre-extract-all", aliases = "-a", required = false, usage = "Preextract all", metaVar = "BOOLEAN") 238 boolean preextract = false; 239 240 @Option(name = "--object", aliases = "-obj", required = false, usage = "Object", metaVar = "Integer") 241 int object = -1; 242 243 @Option(name = "--image", aliases = "-img", required = false, usage = "Image", metaVar = "Integer") 244 int image = -1; 245 } 246 247 static class ObjectRecord extends IndependentPair<Integer, IRecord<URL>> { 248 249 public ObjectRecord(Integer obj1, IRecord<URL> obj2) { 250 super(obj1, obj2); 251 } 252 253 } 254 255 /** 256 * @param args 257 * @throws IOException 258 * @throws CmdLineException 259 */ 260 public static void main(String[] args) throws IOException, CmdLineException { 261 UKBenchGMMExperimentOptions opts = new UKBenchGMMExperimentOptions(); 262 final CmdLineParser parser = new CmdLineParser(opts); 263 parser.parseArgument(args); 264 final UKBenchGMMExperiment exp = new UKBenchGMMExperiment(opts.input); 265 if (opts.preextract){ 266 System.out.println("Preloading all ukbench features..."); 267 exp.extractGroupGaussians(); 268 } 269 270 if(opts.object == -1 || opts.image == -1){ 271 exp.applyToEachGroup(new Operation<UKBenchListDataset<IRecord<URL>>>() { 272 273 @Override 274 public void perform(UKBenchListDataset<IRecord<URL>> group) { 275 int object = 
group.getObject(); 276 for (int i = 0; i < group.size(); i++) { 277 double score = exp.score(object, i); 278 System.out.printf("Object %d, image %d, score: %2.2f\n",object,i,score); 279 } 280 } 281 }); 282 } else { 283 double score = exp.score(opts.object, opts.image); 284 System.out.printf("Object %d, image %d, score: %2.2f\n",opts.object,opts.image,score); 285 } 286 } 287 288 protected MixtureOfGaussians extract(IRecord<URL> item) { 289 return this.gmmExtract.extractFeature(item); 290 } 291 292 private void applyToEachGroup(Operation<UKBenchListDataset<IRecord<URL>>> operation) { 293 for (int i = 0; i < this.dataset.size(); i++) { 294 operation.perform(this.dataset.get(i)); 295 } 296 297 } 298 299 private void applyToEachImage(Operation<ObjectRecord> operation) { 300 for (int i = 0; i < this.dataset.size(); i++) { 301 UKBenchListDataset<IRecord<URL>> ukBenchListDataset = this.dataset.get(i); 302 for (IRecord<URL> iRecord : ukBenchListDataset) { 303 operation.perform(new ObjectRecord(i, iRecord)); 304 } 305 } 306 } 307 308 public double score(int object, int image) { 309 System.out.printf("Scoring Object %d, Image %d\n",object,image); 310 IRecord<URL> item = this.dataset.get(object).get(image); 311 final MixtureOfGaussians thisGMM = extract(item); 312 final List<IntDoublePair> scored = new ArrayList<IntDoublePair>(); 313 applyToEachImage(new Operation<UKBenchGMMExperiment.ObjectRecord>() { 314 315 @Override 316 public void perform(ObjectRecord object) { 317 MixtureOfGaussians otherGMM = extract(object.getSecondObject()); 318 319 double distance = comp.compare(thisGMM, otherGMM); 320 scored.add(IntDoublePair.pair(object.firstObject(), distance)); 321 if(scored.size() % 200 == 0){ 322 System.out.printf("Loaded: %2.1f%%\n", 100 * (float)scored.size() / (dataset.size()*4)); 323 } 324 } 325 }); 326 327 Collections.sort(scored, new Comparator<IntDoublePair>(){ 328 329 @Override 330 public int compare(IntDoublePair o1, IntDoublePair o2) { 331 return 
-Double.compare(o1.second, o2.second); 332 } 333 334 }); 335 double good = 0; 336 for (int i = 0; i < 4; i++) { 337 if(scored.get(i).first == object) good+=1; 338 } 339 return good/4f; 340 } 341 342 /** 343 * @return the mixture of gaussians for each group 344 */ 345 public Map<Integer, List<MixtureOfGaussians>> extractGroupGaussians() { 346 final Map<Integer, List<MixtureOfGaussians>> groups = new HashMap<Integer, List<MixtureOfGaussians>>(); 347 ThreadPoolExecutor pool = (ThreadPoolExecutor) Executors 348 .newFixedThreadPool(1, 349 new DaemonThreadFactory()); 350 final double TOTAL = this.dataset.size() * 4; 351 Parallel.forIndex(0, this.dataset.size(), 1, new Operation<Integer>() { 352 353 @Override 354 public void perform(Integer i) { 355 groups.put(i, extractGroupGaussians(i)); 356 if(groups.size() % 200 == 0){ 357 System.out.printf("Loaded: %2.1f%%\n", 100 * groups.size() * 4 / TOTAL); 358 } 359 } 360 }, pool); 361 362 return groups; 363 } 364 365 public List<MixtureOfGaussians> extractGroupGaussians(int i) { 366 return this.extractGroupGaussians(this.dataset.get(i)); 367 } 368 369 public List<MixtureOfGaussians> extractGroupGaussians( UKBenchListDataset<IRecord<URL>> ukbenchObject) { 370 List<MixtureOfGaussians> gaussians = new ArrayList<MixtureOfGaussians>(); 371 int i = 0; 372 for (IRecord<URL> imageURL : ukbenchObject) { 373 MixtureOfGaussians gmm = gmmExtract.extractFeature(imageURL); 374 gaussians.add(gmm); 375 } 376 return gaussians; 377 } 378 379}