001package org.openimaj.picslurper.client;
002
003import gnu.trove.map.hash.TIntObjectHashMap;
004import gnu.trove.procedure.TIntObjectProcedure;
005
006import java.io.File;
007import java.io.IOException;
008import java.util.ArrayList;
009import java.util.Collections;
010import java.util.Comparator;
011import java.util.HashSet;
012import java.util.List;
013import java.util.Set;
014import java.util.TreeSet;
015
016import org.jgrapht.alg.ConnectivityInspector;
017import org.jgrapht.graph.DefaultWeightedEdge;
018import org.jgrapht.graph.SimpleWeightedGraph;
019import org.openimaj.feature.FeatureVector;
020import org.openimaj.feature.FeatureVectorProvider;
021import org.openimaj.lsh.functions.DoubleGaussianFactory;
022import org.openimaj.lsh.sketch.IntLSHSketcher;
023import org.openimaj.picslurper.output.WriteableImageOutput;
024import org.openimaj.util.hash.HashFunction;
025import org.openimaj.util.hash.HashFunctionFactory;
026import org.openimaj.util.hash.modifier.LSBModifier;
027import org.openimaj.util.pair.LongObjectPair;
028
029import cern.jet.random.engine.MersenneTwister;
030
031/**
032 * A trend detector indexes new images and is able to tell you the n highest
033 * trending sets of near duplicate images
034 * 
035 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
036 * 
037 */
038public class TrendDetector {
039        final int ndims = 128;
040        final double w = 6.0;
041        final int nbits = 128;
042        final float LOG_BASE = 0.001f;
043
044        TrendDetectorFeatureExtractor extractor = new SIFTTrendFeatureMode();
045
046        IntLSHSketcher<double[]> sketcher;
047        List<TIntObjectHashMap<Set<WriteableImageOutput>>> database;
048        Set<LongObjectPair<WriteableImageOutputHashes>> imagesByTime = new TreeSet<LongObjectPair<WriteableImageOutputHashes>>(
049                        new Comparator<LongObjectPair<WriteableImageOutputHashes>>() {
050
051                                @Override
052                                public int compare(
053                                                LongObjectPair<WriteableImageOutputHashes> o1,
054                                                LongObjectPair<WriteableImageOutputHashes> o2)
055                                {
056                                        return ((Long) o1.first).compareTo(o2.first);
057                                }
058
059                        });
060
061        /**
062         * instantiate the LSH
063         */
064        public TrendDetector() {
065
066                this.setFeatureExtractor(new SIFTTrendFeatureMode());
067        }
068
069        static double[] logScale(double[] v, float l) {
070                final double[] dfv = new double[v.length];
071                final double s = -Math.log(l);
072
073                for (int i = 0; i < v.length; i++) {
074                        double d = (v[i] + 128.0) / 256.0;
075
076                        if (d < l)
077                                d = l;
078                        d = (Math.log(d) + s) / s;
079                        if (d > 1.0)
080                                d = 1.0;
081
082                        dfv[i] = d;
083                }
084                return dfv;
085        }
086
087        /**
088         * @param count
089         * @return the top count list of sets from all hashtable bins
090         */
091        public synchronized List<Set<WriteableImageOutput>> trending(int count) {
092                final SimpleWeightedGraph<WriteableImageOutput, DefaultWeightedEdge> graph = new SimpleWeightedGraph<WriteableImageOutput, DefaultWeightedEdge>(
093                                DefaultWeightedEdge.class);
094                for (final TIntObjectHashMap<Set<WriteableImageOutput>> set : this.database) {
095                        set.forEachEntry(new TIntObjectProcedure<Set<WriteableImageOutput>>() {
096                                @Override
097                                public boolean execute(int hashIndex, Set<WriteableImageOutput> itemSet) {
098                                        for (final WriteableImageOutput item : itemSet) {
099                                                if (!graph.containsVertex(item)) {
100                                                        graph.addVertex(item);
101                                                }
102                                        }
103                                        final List<WriteableImageOutput> itemList = new ArrayList<WriteableImageOutput>();
104                                        itemList.addAll(itemSet);
105                                        for (int i = 0; i < itemList.size(); i++) {
106                                                final WriteableImageOutput itemA = itemList.get(i);
107                                                for (int j = i + 1; j < itemList.size(); j++) {
108                                                        final WriteableImageOutput itemB = itemList.get(j);
109                                                        DefaultWeightedEdge edge = graph.getEdge(itemA, itemB);
110                                                        if (edge == null) {
111                                                                edge = graph.addEdge(itemA, itemB);
112                                                                graph.setEdgeWeight(edge, 1);
113                                                        }
114                                                        else {
115                                                                graph.setEdgeWeight(edge, graph.getEdgeWeight(edge) + 1);
116                                                        }
117
118                                                }
119                                        }
120                                        return true;
121                                }
122                        });
123                }
124
125                final Set<DefaultWeightedEdge> edges = new HashSet<DefaultWeightedEdge>(graph.edgeSet());
126                for (final DefaultWeightedEdge e : edges) {
127                        if (graph.getEdgeWeight(e) < 10)
128                                graph.removeEdge(e);
129                }
130
131                final ConnectivityInspector<WriteableImageOutput, DefaultWeightedEdge> conn = new ConnectivityInspector<WriteableImageOutput, DefaultWeightedEdge>(
132                                graph);
133                final List<Set<WriteableImageOutput>> retList = conn.connectedSets();
134                Collections.sort(retList, new Comparator<Set<WriteableImageOutput>>() {
135
136                        @Override
137                        public int compare(Set<WriteableImageOutput> o1, Set<WriteableImageOutput> o2) {
138                                return -1 * ((Integer) o1.size()).compareTo(o2.size());
139                        }
140
141                });
142                return retList.subList(0, count < retList.size() ? count : retList.size());
143        }
144
145        /**
146         * index a new image
147         * 
148         * @param instance
149         * @throws IOException
150         */
151        public synchronized void indexImage(WriteableImageOutput instance) throws IOException {
152                for (final File imageFile : instance.listImageFiles("/")) {
153                        WriteableImageOutput iclone = null;
154                        try {
155                                iclone = instance.clone();
156                                iclone.file = imageFile;
157                        } catch (final CloneNotSupportedException e) {
158
159                        }
160                        final List<? extends FeatureVectorProvider<? extends FeatureVector>> features = extractor
161                                        .extractFeatures(imageFile);
162                        final WriteableImageOutputHashes imageHashes = new WriteableImageOutputHashes(iclone);
163
164                        for (final FeatureVectorProvider<? extends FeatureVector> k : features) {
165                                double[] fv = k.getFeatureVector().asDoubleVector();
166                                if (extractor.logScale()) {
167                                        fv = logScale(fv, LOG_BASE);
168                                }
169                                final int[] sketch = sketcher.createSketch(fv);
170                                imageHashes.hashes.add(sketch);
171
172                                for (int i = 0; i < sketch.length; i++) {
173                                        final int sk = sketch[i];
174                                        synchronized (database) {
175                                                Set<WriteableImageOutput> s = database.get(i).get(sk);
176                                                if (s == null)
177                                                        database.get(i).put(sk, s = new HashSet<WriteableImageOutput>());
178
179                                                s.add(iclone);
180
181                                        }
182                                }
183                        }
184                        final long time = System.currentTimeMillis();
185                        synchronized (this.imagesByTime) {
186                                this.imagesByTime
187                                                .add(new LongObjectPair<WriteableImageOutputHashes>(
188                                                                time, imageHashes));
189                        }
190                }
191        }
192
193        public void setFeatureExtractor(TrendDetectorFeatureExtractor fe) {
194                this.extractor = fe;
195                final MersenneTwister rng = new MersenneTwister();
196
197                final DoubleGaussianFactory gauss = new DoubleGaussianFactory(fe.nDimensions(), rng, w);
198                final HashFunctionFactory<double[]> factory = new HashFunctionFactory<double[]>() {
199                        @Override
200                        public HashFunction<double[]> create() {
201                                return new LSBModifier<double[]>(gauss.create());
202                        }
203                };
204
205                sketcher = new IntLSHSketcher<double[]>(factory, nbits);
206                database = new ArrayList<TIntObjectHashMap<Set<WriteableImageOutput>>>(
207                                sketcher.arrayLength());
208
209                for (int i = 0; i < sketcher.arrayLength(); i++)
210                        database.add(new TIntObjectHashMap<Set<WriteableImageOutput>>());
211
212        }
213
214}