001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.demos.sandbox;
031
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import gnu.trove.procedure.TObjectIntProcedure;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;

import org.jgrapht.alg.ConnectivityInspector;
import org.jgrapht.graph.DefaultWeightedEdge;
import org.jgrapht.graph.SimpleWeightedGraph;
import org.openimaj.feature.local.filter.ByteEntropyFilter;
import org.openimaj.image.FImage;
import org.openimaj.image.ImageUtilities;
import org.openimaj.image.feature.local.engine.DoGSIFTEngine;
import org.openimaj.image.feature.local.keypoints.Keypoint;
import org.openimaj.image.processing.resize.ResizeProcessor;
import org.openimaj.lsh.functions.DoubleGaussianFactory;
import org.openimaj.lsh.sketch.IntLSHSketcher;
import org.openimaj.util.filter.FilterUtils;
import org.openimaj.util.function.Operation;
import org.openimaj.util.hash.HashFunction;
import org.openimaj.util.hash.HashFunctionFactory;
import org.openimaj.util.hash.modifier.LSBModifier;
import org.openimaj.util.pair.IntObjectPair;
import org.openimaj.util.parallel.Parallel;

import cern.jet.random.engine.MersenneTwister;
065
066public class HashingTest {
067        final int ndims = 128;
068        final double w = 6.0;
069        final int nbits = 128;
070        final float LOG_BASE = 0.001f;
071
072        IntLSHSketcher<double[]> sketcher;
073        List<TIntObjectHashMap<Set<String>>> database;
074        TObjectIntHashMap<String> counts = new TObjectIntHashMap<String>();
075
076        public HashingTest() {
077                final MersenneTwister rng = new MersenneTwister();
078
079                final DoubleGaussianFactory gauss = new DoubleGaussianFactory(ndims, rng, w);
080                final HashFunctionFactory<double[]> factory = new HashFunctionFactory<double[]>() {
081                        @Override
082                        public HashFunction<double[]> create() {
083                                return new LSBModifier<double[]>(gauss.create());
084                        }
085                };
086
087                sketcher = new IntLSHSketcher<double[]>(factory, nbits);
088                database = new ArrayList<TIntObjectHashMap<Set<String>>>(sketcher.arrayLength());
089
090                for (int i = 0; i < sketcher.arrayLength(); i++)
091                        database.add(new TIntObjectHashMap<Set<String>>());
092        }
093
094        static double[] logScale(byte[] v, float l) {
095                final double[] dfv = new double[v.length];
096                final double s = -Math.log(l);
097
098                for (int i = 0; i < v.length; i++) {
099                        double d = (v[i] + 128.0) / 256.0;
100
101                        if (d < l)
102                                d = l;
103                        d = (Math.log(d) + s) / s;
104                        if (d > 1.0)
105                                d = 1.0;
106
107                        dfv[i] = d;
108                }
109                return dfv;
110        }
111
112        private void indexImage(File imageFile) throws IOException {
113                final List<Keypoint> features = extractFeatures(imageFile);
114                for (final Keypoint k : features) {
115                        final int[] sketch = sketcher.createSketch(logScale(k.ivec, LOG_BASE));
116
117                        for (int i = 0; i < sketch.length; i++) {
118                                final int sk = sketch[i];
119                                synchronized (database) {
120                                        Set<String> s = database.get(i).get(sk);
121                                        if (s == null)
122                                                database.get(i).put(sk, s = new HashSet<String>());
123                                        s.add(imageFile.toString());
124                                }
125                        }
126                }
127
128                counts.put(imageFile.toString(), features.size());
129        }
130
131        List<Keypoint> extractFeatures(File imageFile) throws IOException {
132                final DoGSIFTEngine engine = new DoGSIFTEngine();
133                engine.getOptions().setDoubleInitialImage(false);
134                final ByteEntropyFilter filter = new ByteEntropyFilter();
135
136                final FImage image = ResizeProcessor.resizeMax(ImageUtilities.readF(imageFile), 150);
137
138                final List<Keypoint> features = engine.findFeatures(image);
139                return FilterUtils.filter(features, filter);
140        }
141
142        List<IntObjectPair<String>> search(File imageFile) throws IOException {
143                final TObjectIntHashMap<String> results = new TObjectIntHashMap<String>();
144
145                for (final Keypoint k : extractFeatures(imageFile)) {
146                        final int[] sketch = sketcher.createSketch(logScale(k.ivec, LOG_BASE));
147
148                        final TObjectIntHashMap<String> featResults = new TObjectIntHashMap<String>();
149
150                        for (int i = 0; i < sketch.length; i++) {
151                                final int sk = sketch[i];
152
153                                final Set<String> r = database.get(i).get(sk);
154                                if (r != null) {
155                                        for (final String file : r) {
156                                                featResults.adjustOrPutValue(file, 1, 1);
157                                                // results.adjustOrPutValue(file, 1, 1);
158                                        }
159                                }
160                        }
161
162                        featResults.forEachEntry(new TObjectIntProcedure<String>() {
163                                @Override
164                                public boolean execute(String a, int b) {
165                                        if (b >= 1)
166                                                results.adjustOrPutValue(a, b, b);
167                                        return true;
168                                }
169                        });
170                }
171
172                final List<IntObjectPair<String>> list = new ArrayList<IntObjectPair<String>>();
173
174                for (final String k : results.keys(new String[results.size()])) {
175                        list.add(new IntObjectPair<String>(results.get(k), k));
176                }
177
178                Collections.sort(list, new Comparator<IntObjectPair<String>>() {
179                        @Override
180                        public int compare(IntObjectPair<String> paramT1, IntObjectPair<String> paramT2) {
181                                final int v1 = paramT1.first;
182                                final int v2 = paramT2.first;
183
184                                if (v1 == v2)
185                                        return 0;
186                                return v1 < v2 ? 1 : 0;
187                        }
188                });
189
190                return list;
191        }
192
193        public static void main(String[] args) throws IOException {
194                final HashingTest test = new HashingTest();
195                final int nImages = 10200;
196
197                Parallel.forIndex(0, nImages, 1, new Operation<Integer>() {
198                        volatile int count = 0;
199
200                        @Override
201                        public void perform(Integer i) {
202                                try {
203                                        final File file = new File(String.format("/Users/jsh2/Data/ukbench/full/ukbench0%04d.jpg", i));
204                                        System.out.println(file);
205                                        test.indexImage(file);
206                                        count++;
207                                        System.out.println(count);
208                                } catch (final IOException e) {
209                                }
210                        }
211                });
212                System.out.println("done");
213
214                final SimpleWeightedGraph<String, DefaultWeightedEdge> graph = new SimpleWeightedGraph<String, DefaultWeightedEdge>(
215                                DefaultWeightedEdge.class);
216
217                for (int i = 0; i < nImages; i++) {
218                        final File filename = new File(String.format("/Users/jsh2/Data/ukbench/full/ukbench0%04d.jpg", i));
219
220                        graph.addVertex(filename.toString());
221                }
222
223                for (int i = 0; i < nImages; i++) {
224                        System.out.println("Query : " + i);
225                        final File filename = new File(String.format("/Users/jsh2/Data/ukbench/full/ukbench0%04d.jpg", i));
226                        final List<IntObjectPair<String>> res = test.search(filename);
227
228                        if (res.size() > 1) {
229                                for (final IntObjectPair<String> k : res) {
230                                        if (k.second.toString().equals(filename.toString()))
231                                                continue;
232
233                                        final DefaultWeightedEdge edge = graph.addEdge(filename.toString(), k.second);
234                                        if (edge != null)
235                                                graph.setEdgeWeight(edge, k.first);
236                                }
237                        }
238                }
239
240                final ConnectivityInspector<String, DefaultWeightedEdge> conn = new ConnectivityInspector<String, DefaultWeightedEdge>(
241                                graph);
242                final List<Set<String>> sets = conn.connectedSets();
243
244                for (final Set<String> s : sets) {
245                        System.out.println(s);
246                }
247        }
248}