001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.experiment.gmm.retrieval;
031
032import java.io.File;
033import java.io.IOException;
034import java.io.InputStream;
035import java.net.URL;
036import java.util.ArrayList;
037import java.util.Collections;
038import java.util.Comparator;
039import java.util.HashMap;
040import java.util.List;
041import java.util.Map;
042import java.util.concurrent.Executors;
043import java.util.concurrent.ThreadPoolExecutor;
044
045import org.apache.commons.vfs2.FileObject;
046import org.apache.commons.vfs2.FileSystemException;
047import org.kohsuke.args4j.CmdLineException;
048import org.kohsuke.args4j.CmdLineParser;
049import org.kohsuke.args4j.Option;
050import org.openimaj.data.identity.Identifiable;
051import org.openimaj.feature.CachingFeatureExtractor;
052import org.openimaj.feature.DiskCachingFeatureExtractor;
053import org.openimaj.feature.FeatureExtractor;
054import org.openimaj.feature.FeatureVector;
055import org.openimaj.feature.local.LocalFeature;
056import org.openimaj.feature.local.list.LocalFeatureList;
057import org.openimaj.image.FImage;
058import org.openimaj.image.ImageUtilities;
059import org.openimaj.image.processing.resize.ResizeProcessor;
060import org.openimaj.io.ObjectReader;
061import org.openimaj.math.statistics.distribution.MixtureOfGaussians;
062import org.openimaj.math.statistics.distribution.metrics.SampledMultivariateDistanceComparator;
063import org.openimaj.ml.gmm.GaussianMixtureModelEM.CovarianceType;
064import org.openimaj.util.function.Function;
065import org.openimaj.util.function.Operation;
066import org.openimaj.util.pair.IndependentPair;
067import org.openimaj.util.pair.IntDoublePair;
068import org.openimaj.util.parallel.Parallel;
069import org.openimaj.util.parallel.GlobalExecutorPool.DaemonThreadFactory;
070
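/**
 * Retrieval experiment over the ukbench dataset. Each image is resized,
 * passed through a {@code DSiftFeatureExtractor} and summarised as a
 * {@link MixtureOfGaussians}; an image is then scored by comparing its
 * mixture against the mixture of every image in the dataset and counting how
 * many of the four top-ranked images belong to the same object.
 *
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 */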
075public class UKBenchGMMExperiment {
076        private final class FImageFileObjectReader implements
077                        ObjectReader<FImage, FileObject> {
078                @Override
079                public FImage read(FileObject source) throws IOException {
080                        return ImageUtilities.FIMAGE_READER.read(source.getContent()
081                                        .getInputStream());
082                }
083
084                @Override
085                public boolean canRead(FileObject source, String name) {
086                        InputStream inputStream = null;
087                        try {
088                                inputStream = source.getContent().getInputStream();
089                                return ImageUtilities.FIMAGE_READER.canRead(inputStream, name);
090                        } catch (FileSystemException e) {
091                        } finally {
092                                if (inputStream != null) {
093                                        try {
094                                                inputStream.close();
095                                        } catch (IOException e) {
096                                                throw new RuntimeException(e);
097                                        }
098                                }
099                        }
100                        return false;
101                }
102        }
103
104        private final class URLFileObjectReader implements
105                        ObjectReader<URL, FileObject> {
106                @Override
107                public URL read(FileObject source) throws IOException {
108                        return source.getURL();
109                }
110
111                @Override
112                public boolean canRead(FileObject source, String name) {
113                        try {
114                                return (source.getURL() != null);
115                        } catch (FileSystemException e) {
116                                return false;
117                        }
118                }
119        }
120
121        private static final class IRecordWrapper<A, B> implements
122                        Function<UKBenchGMMExperiment.IRecord<A>, B> {
123                Function<A, B> inner;
124
125                public IRecordWrapper(Function<A, B> extract) {
126                        this.inner = extract;
127                }
128
129                @Override
130                public B apply(IRecord<A> in) {
131                        return inner.apply(in.image);
132                }
133
134                public static <A, B> Function<IRecord<A>, B> wrap(Function<A, B> extract) {
135                        return new IRecordWrapper<A, B>(extract);
136                }
137        }
138
139        private static class IRecord<IMAGE> implements Identifiable {
140
141                private String id;
142                private IMAGE image;
143
144                public IRecord(String id, IMAGE image) {
145                        this.id = id;
146                        this.image = image;
147                }
148
149                @Override
150                public String getID() {
151                        return this.id;
152                }
153
154                public static <A> IRecord<A> wrap(String id, A payload) {
155                        return new IRecord<A>(id, payload);
156                }
157
158        }
159
160        private static final class IRecordReader<IMAGE> implements
161                        ObjectReader<IRecord<IMAGE>, FileObject> {
162                ObjectReader<IMAGE, FileObject> reader;
163
164                public IRecordReader(ObjectReader<IMAGE, FileObject> reader) {
165                        this.reader = reader;
166                }
167
168                @Override
169                public IRecord<IMAGE> read(FileObject source) throws IOException {
170                        String name = source.getName().getBaseName();
171                        IMAGE image = reader.read(source);
172                        return new IRecord<IMAGE>(name, image);
173                }
174
175                @Override
176                public boolean canRead(FileObject source, String name) {
177                        return reader.canRead(source, name);
178                }
179        }
180
181        private String ukbenchRoot = "/Users/ss/Experiments/ukbench";
182        private ResizeProcessor resize;
183        private UKBenchGroupDataset<IRecord<URL>> dataset;
184        private FeatureExtractor<MixtureOfGaussians,IRecord<URL>> gmmExtract;
185        final SampledMultivariateDistanceComparator comp = new SampledMultivariateDistanceComparator();
186
187        public UKBenchGMMExperiment() {
188                setup();
189        }
190
191        public UKBenchGMMExperiment(String root) {
192                this.ukbenchRoot = root;
193                setup();
194        }
195
196        private void setup() {
197                this.dataset = new UKBenchGroupDataset<IRecord<URL>>(
198                                ukbenchRoot + "/full",
199                                // new IRecordReader<FImage>(new FImageFileObjectReader())
200                                new IRecordReader<URL>(new URLFileObjectReader()));
201
202                resize = new ResizeProcessor(640, 480);
203
204                Function<URL, MixtureOfGaussians> combined = new Function<URL, MixtureOfGaussians>() {
205
206                        @Override
207                        public MixtureOfGaussians apply(URL in) {
208                                
209                                final DSiftFeatureExtractor feature = new DSiftFeatureExtractor();
210                                final GMMFromFeatures gmmFunc = new GMMFromFeatures(3,CovarianceType.Diagonal);
211                                System.out.println("... resize");
212                                FImage process = null;
213                                try {
214                                        process = ImageUtilities.readF(in).process(resize);
215                                } catch (IOException e) {
216                                        throw new RuntimeException(e);
217                                }
218                                System.out.println("... dsift");
219                                LocalFeatureList<? extends LocalFeature<?, ? extends FeatureVector>> apply = feature
220                                                .apply(process);
221                                System.out.println("... gmm");
222                                return gmmFunc.apply(apply);
223                        }
224
225                };
226                this.gmmExtract = new CachingFeatureExtractor<MixtureOfGaussians, IRecord<URL>>(
227                                new DiskCachingFeatureExtractor<MixtureOfGaussians, IRecord<URL>>(
228                                                new File(ukbenchRoot + "/gmm/dsift"),
229                                                FeatureExtractionFunction.wrap(IRecordWrapper.wrap(combined)))
230                );
231        }
232
233        static class UKBenchGMMExperimentOptions {
234                @Option(name = "--input", aliases = "-i", required = true, usage = "Input location", metaVar = "STRING")
235                String input = null;
236
237                @Option(name = "--pre-extract-all", aliases = "-a", required = false, usage = "Preextract all", metaVar = "BOOLEAN")
238                boolean preextract = false;
239                
240                @Option(name = "--object", aliases = "-obj", required = false, usage = "Object", metaVar = "Integer")
241                int object = -1;
242                
243                @Option(name = "--image", aliases = "-img", required = false, usage = "Image", metaVar = "Integer")
244                int image = -1;
245        }
246
247        static class ObjectRecord extends IndependentPair<Integer, IRecord<URL>> {
248
249                public ObjectRecord(Integer obj1, IRecord<URL> obj2) {
250                        super(obj1, obj2);
251                }
252
253        }
254
255        /**
256         * @param args
257         * @throws IOException
258         * @throws CmdLineException
259         */
260        public static void main(String[] args) throws IOException, CmdLineException {
261                UKBenchGMMExperimentOptions opts = new UKBenchGMMExperimentOptions();
262                final CmdLineParser parser = new CmdLineParser(opts);
263                parser.parseArgument(args);
264                final UKBenchGMMExperiment exp = new UKBenchGMMExperiment(opts.input);
265                if (opts.preextract){
266                        System.out.println("Preloading all ukbench features...");
267                        exp.extractGroupGaussians();                    
268                }
269                
270                if(opts.object == -1 || opts.image == -1){                      
271                        exp.applyToEachGroup(new Operation<UKBenchListDataset<IRecord<URL>>>() {
272                                
273                                @Override
274                                public void perform(UKBenchListDataset<IRecord<URL>> group) {
275                                        int object = group.getObject();
276                                        for (int i = 0; i < group.size(); i++) {
277                                                double score = exp.score(object, i);
278                                                System.out.printf("Object %d, image %d, score: %2.2f\n",object,i,score);
279                                        }
280                                }
281                        });
282                } else {
283                        double score = exp.score(opts.object, opts.image);
284                        System.out.printf("Object %d, image %d, score: %2.2f\n",opts.object,opts.image,score);
285                }
286        }
287
288        protected MixtureOfGaussians extract(IRecord<URL> item) {
289                return this.gmmExtract.extractFeature(item);
290        }
291
292        private void applyToEachGroup(Operation<UKBenchListDataset<IRecord<URL>>> operation) {
293                for (int i = 0; i < this.dataset.size(); i++) {
294                        operation.perform(this.dataset.get(i));
295                }
296
297        }
298
299        private void applyToEachImage(Operation<ObjectRecord> operation) {
300                for (int i = 0; i < this.dataset.size(); i++) {
301                        UKBenchListDataset<IRecord<URL>> ukBenchListDataset = this.dataset.get(i);
302                        for (IRecord<URL> iRecord : ukBenchListDataset) {
303                                operation.perform(new ObjectRecord(i, iRecord));
304                        }
305                }
306        }
307        
308        public double score(int object, int image) {
309                System.out.printf("Scoring Object %d, Image %d\n",object,image);
310                IRecord<URL> item = this.dataset.get(object).get(image);
311                final MixtureOfGaussians thisGMM = extract(item);
312                final List<IntDoublePair> scored = new ArrayList<IntDoublePair>();
313                applyToEachImage(new Operation<UKBenchGMMExperiment.ObjectRecord>() {
314
315                        @Override
316                        public void perform(ObjectRecord object) {
317                                MixtureOfGaussians otherGMM = extract(object.getSecondObject());
318                                
319                                double distance = comp.compare(thisGMM, otherGMM);
320                                scored.add(IntDoublePair.pair(object.firstObject(), distance));
321                                if(scored.size() % 200 == 0){
322                                        System.out.printf("Loaded: %2.1f%%\n", 100 * (float)scored.size() / (dataset.size()*4));
323                                }
324                        }
325                });
326                
327                Collections.sort(scored, new Comparator<IntDoublePair>(){
328
329                        @Override
330                        public int compare(IntDoublePair o1, IntDoublePair o2) {
331                                return -Double.compare(o1.second, o2.second);
332                        }
333                        
334                });
335                double good = 0;
336                for (int i = 0; i < 4; i++) {
337                        if(scored.get(i).first == object) good+=1; 
338                }
339                return good/4f;
340        }
341
342        /**
343         * @return the mixture of gaussians for each group
344         */
345        public Map<Integer, List<MixtureOfGaussians>> extractGroupGaussians() {
346                final Map<Integer, List<MixtureOfGaussians>> groups = new HashMap<Integer, List<MixtureOfGaussians>>();
347                ThreadPoolExecutor pool = (ThreadPoolExecutor) Executors
348                                .newFixedThreadPool(1,
349                                                new DaemonThreadFactory());
350                final double TOTAL = this.dataset.size() * 4;
351                Parallel.forIndex(0, this.dataset.size(), 1, new Operation<Integer>() {
352
353                        @Override
354                        public void perform(Integer i) {
355                                groups.put(i, extractGroupGaussians(i));
356                                if(groups.size() % 200 == 0){
357                                        System.out.printf("Loaded: %2.1f%%\n", 100 * groups.size() * 4 / TOTAL);
358                                }
359                        }
360                }, pool);
361
362                return groups;
363        }
364
365        public List<MixtureOfGaussians> extractGroupGaussians(int i) {
366                return this.extractGroupGaussians(this.dataset.get(i));
367        }
368
369        public List<MixtureOfGaussians> extractGroupGaussians( UKBenchListDataset<IRecord<URL>> ukbenchObject) {
370                List<MixtureOfGaussians> gaussians = new ArrayList<MixtureOfGaussians>();
371                int i = 0;
372                for (IRecord<URL> imageURL : ukbenchObject) {
373                        MixtureOfGaussians gmm = gmmExtract.extractFeature(imageURL);
374                        gaussians.add(gmm);
375                }
376                return gaussians;
377        }
378
379}