package org.openimaj.picslurper.client;

import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.procedure.TIntObjectProcedure;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

import org.jgrapht.alg.ConnectivityInspector;
import org.jgrapht.graph.DefaultWeightedEdge;
import org.jgrapht.graph.SimpleWeightedGraph;
import org.openimaj.feature.FeatureVector;
import org.openimaj.feature.FeatureVectorProvider;
import org.openimaj.lsh.functions.DoubleGaussianFactory;
import org.openimaj.lsh.sketch.IntLSHSketcher;
import org.openimaj.picslurper.output.WriteableImageOutput;
import org.openimaj.util.hash.HashFunction;
import org.openimaj.util.hash.HashFunctionFactory;
import org.openimaj.util.hash.modifier.LSBModifier;
import org.openimaj.util.pair.LongObjectPair;

import cern.jet.random.engine.MersenneTwister;

/**
 * A trend detector indexes new images and is able to tell you the n highest
 * trending sets of near duplicate images.
 * <p>
 * Each feature of an indexed image is sketched into an array of ints; entry
 * {@code i} of the sketch selects a bucket in hash table {@code i} of
 * {@link #database}. {@link #trending(int)} links images that share buckets
 * and reports the connected groups, largest first.
 *
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 *
 */
public class TrendDetector {
	/**
	 * Minimum number of shared buckets two images need for their link to
	 * survive pruning in {@link #trending(int)}.
	 */
	private static final double MIN_EDGE_WEIGHT = 10;

	/** Not read anywhere in this class; the extractor supplies the dimensionality. */
	final int ndims = 128;
	/** Third argument handed to the {@link DoubleGaussianFactory}. */
	final double w = 6.0;
	/** Number of bits requested from the {@link IntLSHSketcher}. */
	final int nbits = 128;
	/** Floor passed to {@link #logScale(double[], float)} when the extractor asks for log scaling. */
	final float LOG_BASE = 0.001f;

	/** Feature extractor in use; assigned by {@link #setFeatureExtractor}, which the constructor calls. */
	TrendDetectorFeatureExtractor extractor;

	/** Sketcher matching the current extractor's dimensionality. */
	IntLSHSketcher<double[]> sketcher;
	/** One hash table per sketch array entry, mapping a sketch value to the images that produced it. */
	List<TIntObjectHashMap<Set<WriteableImageOutput>>> database;
	/** Indexed images ordered by the time they were indexed. */
	Set<LongObjectPair<WriteableImageOutputHashes>> imagesByTime = new TreeSet<LongObjectPair<WriteableImageOutputHashes>>(
			new ByTimeComparator());

	/**
	 * Orders pairs by timestamp. Distinct pairs carrying the same timestamp
	 * are separated by identity hash so that the sorted set does not treat
	 * them as duplicates and drop one of them.
	 */
	private static final class ByTimeComparator implements Comparator<LongObjectPair<WriteableImageOutputHashes>> {
		@Override
		public int compare(
				LongObjectPair<WriteableImageOutputHashes> o1,
				LongObjectPair<WriteableImageOutputHashes> o2)
		{
			if (o1.first != o2.first) {
				return o1.first < o2.first ? -1 : 1;
			}
			if (o1 == o2) {
				return 0;
			}
			// Same millisecond, different entries: fall back to identity so
			// both are kept by the TreeSet.
			final int h1 = System.identityHashCode(o1);
			final int h2 = System.identityHashCode(o2);
			return h1 < h2 ? -1 : (h1 > h2 ? 1 : 0);
		}
	}

	/**
	 * instantiate the LSH using a {@link SIFTTrendFeatureMode} extractor
	 */
	public TrendDetector() {
		this.setFeatureExtractor(new SIFTTrendFeatureMode());
	}

	/**
	 * Rescale a vector component-wise onto a logarithmic scale.
	 * <p>
	 * Each component is first mapped with {@code (v + 128) / 256}, clamped
	 * from below to {@code l}, transformed with
	 * {@code (log(d) - log(l)) / -log(l)} and finally clamped from above to
	 * 1.0. The input array is not modified.
	 *
	 * @param v
	 *            the vector to scale (presumably byte-ranged values in
	 *            [-128, 127]; verify against the extractors)
	 * @param l
	 *            the lower clamp applied before taking the logarithm
	 * @return a new array of the same length holding the scaled values
	 */
	static double[] logScale(double[] v, float l) {
		final double[] dfv = new double[v.length];
		final double s = -Math.log(l);

		for (int i = 0; i < v.length; i++) {
			double d = (v[i] + 128.0) / 256.0;

			if (d < l)
				d = l;
			d = (Math.log(d) + s) / s;
			if (d > 1.0)
				d = 1.0;

			dfv[i] = d;
		}
		return dfv;
	}

	/**
	 * Find the largest groups of images that repeatedly share hash buckets.
	 * <p>
	 * A weighted graph is built in which every indexed image is a vertex and
	 * the weight of an edge counts the buckets the two images share. Edges
	 * lighter than {@link #MIN_EDGE_WEIGHT} are removed and the remaining
	 * connected components are returned, largest first.
	 *
	 * @param count
	 *            maximum number of sets to return; values below zero are
	 *            treated as zero
	 * @return the top count list of sets from all hashtable bins
	 */
	public synchronized List<Set<WriteableImageOutput>> trending(int count) {
		final SimpleWeightedGraph<WriteableImageOutput, DefaultWeightedEdge> graph = new SimpleWeightedGraph<WriteableImageOutput, DefaultWeightedEdge>(
				DefaultWeightedEdge.class);

		for (final TIntObjectHashMap<Set<WriteableImageOutput>> table : this.database) {
			table.forEachEntry(new TIntObjectProcedure<Set<WriteableImageOutput>>() {
				@Override
				public boolean execute(int hashIndex, Set<WriteableImageOutput> itemSet) {
					addBucketToGraph(graph, itemSet);
					// keep iterating over the remaining buckets
					return true;
				}
			});
		}

		// Iterate over a copy: removing edges while walking the graph's own
		// edge set would modify the collection being iterated.
		final Set<DefaultWeightedEdge> edges = new HashSet<DefaultWeightedEdge>(graph.edgeSet());
		for (final DefaultWeightedEdge e : edges) {
			if (graph.getEdgeWeight(e) < MIN_EDGE_WEIGHT)
				graph.removeEdge(e);
		}

		final ConnectivityInspector<WriteableImageOutput, DefaultWeightedEdge> conn = new ConnectivityInspector<WriteableImageOutput, DefaultWeightedEdge>(
				graph);
		final List<Set<WriteableImageOutput>> retList = conn.connectedSets();
		Collections.sort(retList, new Comparator<Set<WriteableImageOutput>>() {
			@Override
			public int compare(Set<WriteableImageOutput> o1, Set<WriteableImageOutput> o2) {
				// descending by size
				final int s1 = o1.size();
				final int s2 = o2.size();
				return s1 > s2 ? -1 : (s1 < s2 ? 1 : 0);
			}
		});

		final int limit = Math.max(0, Math.min(count, retList.size()));
		return retList.subList(0, limit);
	}

	/**
	 * Add every image of one bucket as a vertex and raise the weight of the
	 * edge between each pair of images in the bucket by one, creating the
	 * edge with weight one when it does not exist yet.
	 */
	private static void addBucketToGraph(
			SimpleWeightedGraph<WriteableImageOutput, DefaultWeightedEdge> graph,
			Set<WriteableImageOutput> bucket)
	{
		final List<WriteableImageOutput> items = new ArrayList<WriteableImageOutput>(bucket);

		for (final WriteableImageOutput item : items) {
			// addVertex leaves the graph untouched when the vertex is already present
			graph.addVertex(item);
		}

		for (int i = 0; i < items.size(); i++) {
			final WriteableImageOutput itemA = items.get(i);
			for (int j = i + 1; j < items.size(); j++) {
				final WriteableImageOutput itemB = items.get(j);
				DefaultWeightedEdge edge = graph.getEdge(itemA, itemB);
				if (edge == null) {
					edge = graph.addEdge(itemA, itemB);
					graph.setEdgeWeight(edge, 1);
				}
				else {
					graph.setEdgeWeight(edge, graph.getEdgeWeight(edge) + 1);
				}
			}
		}
	}

	/**
	 * index a new image
	 * <p>
	 * Every image file listed by the instance is indexed separately against a
	 * clone of the instance whose {@code file} field points at that image
	 * file.
	 *
	 * @param instance
	 *            the image output whose image files should be indexed
	 * @throws IOException
	 *             if the instance cannot be cloned, or if a called method
	 *             declared to throw {@link IOException} fails
	 */
	public synchronized void indexImage(WriteableImageOutput instance) throws IOException {
		for (final File imageFile : instance.listImageFiles("/")) {
			final WriteableImageOutput iclone;
			try {
				iclone = instance.clone();
			} catch (final CloneNotSupportedException e) {
				// Without a clone there is nothing valid to index; carrying on
				// would put null entries into the hash tables.
				throw new IOException("Unable to clone image output for " + imageFile, e);
			}
			iclone.file = imageFile;

			final List<? extends FeatureVectorProvider<? extends FeatureVector>> features = extractor
					.extractFeatures(imageFile);
			final WriteableImageOutputHashes imageHashes = new WriteableImageOutputHashes(iclone);

			for (final FeatureVectorProvider<? extends FeatureVector> k : features) {
				double[] fv = k.getFeatureVector().asDoubleVector();
				if (extractor.logScale()) {
					fv = logScale(fv, LOG_BASE);
				}
				final int[] sketch = sketcher.createSketch(fv);
				imageHashes.hashes.add(sketch);

				synchronized (database) {
					for (int i = 0; i < sketch.length; i++) {
						final int sk = sketch[i];
						final TIntObjectHashMap<Set<WriteableImageOutput>> table = database.get(i);
						Set<WriteableImageOutput> s = table.get(sk);
						if (s == null) {
							s = new HashSet<WriteableImageOutput>();
							table.put(sk, s);
						}
						s.add(iclone);
					}
				}
			}

			final long time = System.currentTimeMillis();
			synchronized (this.imagesByTime) {
				this.imagesByTime
						.add(new LongObjectPair<WriteableImageOutputHashes>(
								time, imageHashes));
			}
		}
	}

	/**
	 * Switch to a different feature extractor.
	 * <p>
	 * A new sketcher is built for the extractor's dimensionality and the hash
	 * tables are replaced with empty ones, so images indexed before the call
	 * no longer take part in {@link #trending(int)}. {@link #imagesByTime} is
	 * not reset. The new state is only installed once it has been fully
	 * constructed.
	 *
	 * @param fe
	 *            the feature extractor to use from now on
	 */
	public synchronized void setFeatureExtractor(TrendDetectorFeatureExtractor fe) {
		final MersenneTwister rng = new MersenneTwister();

		final DoubleGaussianFactory gauss = new DoubleGaussianFactory(fe.nDimensions(), rng, w);
		final HashFunctionFactory<double[]> factory = new HashFunctionFactory<double[]>() {
			@Override
			public HashFunction<double[]> create() {
				return new LSBModifier<double[]>(gauss.create());
			}
		};

		final IntLSHSketcher<double[]> newSketcher = new IntLSHSketcher<double[]>(factory, nbits);
		final int tableCount = newSketcher.arrayLength();
		final List<TIntObjectHashMap<Set<WriteableImageOutput>>> newDatabase = new ArrayList<TIntObjectHashMap<Set<WriteableImageOutput>>>(
				tableCount);
		for (int i = 0; i < tableCount; i++)
			newDatabase.add(new TIntObjectHashMap<Set<WriteableImageOutput>>());

		// Swap everything together so extractor, sketcher and tables always agree.
		this.extractor = fe;
		this.sketcher = newSketcher;
		this.database = newDatabase;
	}

}