001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.demos.sandbox.image.vlad;
031
032import java.io.DataInputStream;
033import java.io.EOFException;
034import java.io.File;
035import java.io.FileInputStream;
036import java.io.IOException;
037import java.nio.ByteBuffer;
038import java.nio.ByteOrder;
039import java.util.ArrayList;
040import java.util.Comparator;
041import java.util.HashMap;
042import java.util.HashSet;
043import java.util.List;
044import java.util.Map;
045import java.util.Map.Entry;
046import java.util.Set;
047
048import org.openimaj.data.identity.Identifiable;
049import org.openimaj.experiment.evaluation.retrieval.RetrievalEvaluator;
050import org.openimaj.experiment.evaluation.retrieval.analysers.IREvalAnalyser;
051import org.openimaj.experiment.evaluation.retrieval.analysers.IREvalResult;
052import org.openimaj.feature.FVComparator;
053import org.openimaj.feature.FeatureVector;
054import org.openimaj.feature.FloatFV;
055import org.openimaj.feature.FloatFVComparison;
056import org.openimaj.feature.MultidimensionalFloatFV;
057import org.openimaj.feature.SparseFloatFV;
058import org.openimaj.feature.SparseFloatFVComparison;
059import org.openimaj.feature.SparseIntFV;
060import org.openimaj.feature.local.FloatLocalFeatureAdaptor;
061import org.openimaj.feature.local.list.LocalFeatureList;
062import org.openimaj.image.feature.local.aggregate.BagOfVisualWords;
063import org.openimaj.image.feature.local.aggregate.VLAD;
064import org.openimaj.image.feature.local.keypoints.SIFTGeoKeypoint;
065import org.openimaj.image.feature.local.keypoints.SIFTGeoKeypoint.SIFTGeoLocation;
066import org.openimaj.ml.clustering.FloatCentroidsResult;
067import org.openimaj.ml.clustering.assignment.hard.ExactFloatAssigner;
068import org.openimaj.util.pair.DoubleObjectPair;
069import org.openimaj.util.queue.BoundedPriorityQueue;
070
071public class VLADHolidays {
072        static class Document implements Identifiable {
073                String name;
074
075                public Document(String name) {
076                        this.name = name;
077                }
078
079                @Override
080                public String getID() {
081                        return name;
082                }
083
084                @Override
085                public int hashCode() {
086                        return name.hashCode();
087                }
088
089                @Override
090                public boolean equals(Object obj) {
091                        if (obj instanceof Document)
092                                return name.equals(((Document) obj).name);
093                        return false;
094                }
095
096                @Override
097                public String toString() {
098                        return name;
099                }
100        }
101
102        @SuppressWarnings("unchecked")
103        public static <T extends FeatureVector> void main(String[] args) throws IOException {
104                final Map<Document, T> data;
105                final FVComparator<T> comp;
106
107                final boolean vladMode = false;
108                if (vladMode) {
109                        data = (Map<Document, T>) getVLADFeatures();
110                        comp = (FVComparator<T>) FloatFVComparison.EUCLIDEAN;
111                } else {
112                        data = (Map<Document, T>) getBoVWFeatures();
113                        comp = (FVComparator<T>) SparseFloatFVComparison.EUCLIDEAN;
114                }
115
116                // Perform experiment
117                final Map<Integer, Set<Document>> groundTruth = new HashMap<Integer, Set<Document>>();
118                final Map<Integer, List<Document>> ret = new HashMap<Integer, List<Document>>();
119                for (final Document k : data.keySet()) {
120                        final int i = Integer.parseInt(k.name.replace(".siftgeo", ""));
121                        final int q = i - (i % 100);
122
123                        if (q != i) {
124                                if (!groundTruth.containsKey(q))
125                                        groundTruth.put(q, new HashSet<Document>());
126
127                                groundTruth.get(q).add(k);
128                        } else {
129                                final List<DoubleObjectPair<Document>> res = search(k, data, comp, 1000);
130                                ret.put(q, DoubleObjectPair.getSecond(res));
131                        }
132                }
133
134                final RetrievalEvaluator<IREvalResult, Document, Integer> eval = new RetrievalEvaluator<IREvalResult, VLADHolidays.Document, Integer>(
135                                ret, groundTruth, new IREvalAnalyser<Integer, Document>());
136
137                final Map<Integer, List<Document>> evalRes = eval.evaluate();
138                final IREvalResult finalRes = eval.analyse(evalRes);
139
140                System.out.println(finalRes.getSummaryReport());
141        }
142
143        /**
144         * Load the raw features and create VLAD representations.
145         * 
146         * @return
147         * @throws IOException
148         */
149        private static Map<Document, FloatFV> getVLADFeatures() throws IOException {
150                final Map<Document, FloatFV> vladData = new HashMap<Document, FloatFV>();
151
152                final FloatCentroidsResult centroids = readFvecs(new File(
153                                "/Users/jsh2/Downloads/cvpr2010/data/clust_k64.fvecs"));
154
155                final ExactFloatAssigner assigner = new ExactFloatAssigner(centroids);
156
157                final VLAD<float[]> vlad = new VLAD<float[]>(assigner, centroids, true);
158
159                for (final File f : new File("/Users/jsh2/Downloads/siftgeo/").listFiles()) {
160                        System.out.println("Loading " + f.getName());
161                        final LocalFeatureList<SIFTGeoKeypoint> keys = SIFTGeoKeypoint.read(f);
162
163                        final List<FloatLocalFeatureAdaptor<SIFTGeoLocation>> fkeys = FloatLocalFeatureAdaptor.wrap(keys);
164
165                        final MultidimensionalFloatFV fv = vlad.aggregate(fkeys);
166
167                        vladData.put(new Document(f.getName()), fv);
168                }
169
170                return vladData;
171        }
172
173        /**
174         * Load the raw features and create BoVW representations.
175         * 
176         * @return
177         * @throws IOException
178         */
179        protected static Map<Document, SparseFloatFV> getBoVWFeatures() throws IOException {
180                final Map<Document, SparseFloatFV> vladData = new HashMap<Document, SparseFloatFV>();
181
182                final FloatCentroidsResult centroids = readFvecs(new File(
183                                "/Users/jsh2/Downloads/clust/clust_flickr60_k20000.fvecs"));
184                // "/Users/jsh2/Downloads/cvpr2010/data/clust_k64.fvecs"));
185
186                final ExactFloatAssigner assigner = new ExactFloatAssigner(centroids);
187                // final ApproximateFloatEuclideanAssigner assigner = new
188                // ApproximateFloatEuclideanAssigner(centroids);
189
190                final BagOfVisualWords<float[]> bovw = new BagOfVisualWords<float[]>(assigner);
191
192                int N = 0;
193                final float[] n = new float[assigner.size()];
194
195                for (final File f : new File("/Users/jsh2/Downloads/siftgeo/").listFiles()) {
196                        N++;
197                        System.out.println("Loading " + f.getName());
198                        final LocalFeatureList<SIFTGeoKeypoint> keys = SIFTGeoKeypoint.read(f);
199
200                        final List<FloatLocalFeatureAdaptor<SIFTGeoLocation>> fkeys = FloatLocalFeatureAdaptor.wrap(keys);
201
202                        final SparseIntFV fv = bovw.aggregate(fkeys);
203
204                        final SparseFloatFV fv2 = new SparseFloatFV(fv.length());
205                        float sum = 0;
206                        for (final org.openimaj.util.array.SparseIntArray.Entry i : fv.values.entries()) {
207                                sum += (i.value * i.value);
208                                fv2.values.set(i.index, i.value);
209                                n[i.index] += i.value;
210                        }
211                        sum = (float) Math.sqrt(sum);
212                        for (final org.openimaj.util.array.SparseFloatArray.Entry i : fv2.values.entries()) {
213                                fv2.values.set(i.index, i.value / sum);
214                        }
215
216                        vladData.put(new Document(f.getName()), fv2);
217                }
218
219                for (int i = 0; i < n.length; i++) {
220                        n[i] = (float) Math.log(N / n[i]);
221                }
222
223                for (final SparseFloatFV fv : vladData.values()) {
224                        for (final org.openimaj.util.array.SparseFloatArray.Entry i : fv.values.entries()) {
225                                fv.values.set(i.index, i.value * n[i.index]);
226                        }
227                }
228
229                return vladData;
230        }
231
232        /**
233         * The search function. This computes the distance between the query and
234         * every other document, and stores the best results. The query document is
235         * omitted from the results list.
236         * 
237         * @param queryDoc
238         *            the query identifier (assumed to be in the set of docs to be
239         *            searched)
240         * @param features
241         *            the features to search
242         * @param comp
243         *            the comparator to use
244         * @param limit
245         *            the number of top-matching docs to retain
246         * @return the ranked list of results
247         */
248        private static <T extends FeatureVector> List<DoubleObjectPair<Document>> search(Document queryDoc,
249                        Map<Document, T> features, FVComparator<T> comp, int limit)
250        {
251                final BoundedPriorityQueue<DoubleObjectPair<Document>> queue = new BoundedPriorityQueue<DoubleObjectPair<Document>>(
252                                limit,
253                                new Comparator<DoubleObjectPair<Document>>() {
254
255                                        @Override
256                                        public int compare(DoubleObjectPair<Document> o1, DoubleObjectPair<Document> o2) {
257                                                return Double.compare(o1.first, o2.first);
258                                        }
259                                });
260
261                final T query = features.get(queryDoc);
262                if (query != null) {
263                        for (final Entry<Document, T> e : features.entrySet()) {
264                                if (e.getValue() == query || e.getValue() == null)
265                                        continue;
266
267                                final T that = e.getValue();
268                                final double distance = comp.compare(query, that);
269
270                                queue.add(new DoubleObjectPair<Document>(distance, e.getKey()));
271                        }
272                }
273
274                return queue.toOrderedListDestructive();
275        }
276
277        /**
278         * Function to read an fvecs file. Because the local features are stored as
279         * signed bytes (-127..128) we offset the elements by -128 to make the
280         * ranges compatible.
281         * 
282         * @param file
283         *            the .fvecs file to read
284         * @return the centroids from the .fvecs file
285         * @throws IOException
286         */
287        private static FloatCentroidsResult readFvecs(File file) throws IOException {
288                final DataInputStream dis = new DataInputStream(new FileInputStream(file));
289
290                final List<float[]> data = new ArrayList<float[]>();
291                final byte[] tmpArray = new byte[516];
292                final ByteBuffer buffer = ByteBuffer.wrap(tmpArray);
293                buffer.order(ByteOrder.LITTLE_ENDIAN);
294
295                while (true) {
296                        try {
297                                dis.readFully(tmpArray);
298                                buffer.rewind();
299
300                                if (buffer.getInt() != 128) {
301                                        throw new IOException("Unexpected length");
302                                }
303
304                                final float[] array = new float[128];
305                                for (int i = 0; i < 128; i++) {
306                                        array[i] = buffer.getFloat() - 128;
307                                }
308
309                                data.add(array);
310                        } catch (final EOFException e) {
311                                final FloatCentroidsResult f = new FloatCentroidsResult();
312                                f.centroids = data.toArray(new float[data.size()][]);
313                                dis.close();
314                                return f;
315                        }
316                }
317        }
318}