001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.image.indexing.vlad;
031
032import java.io.File;
033import java.io.IOException;
034import java.util.ArrayList;
035import java.util.List;
036
037import org.apache.commons.math.random.MersenneTwister;
038import org.openimaj.data.RandomData;
039import org.openimaj.feature.DoubleFV;
040import org.openimaj.feature.MultidimensionalFloatFV;
041import org.openimaj.feature.local.FloatLocalFeatureAdaptor;
042import org.openimaj.feature.local.LocalFeature;
043import org.openimaj.feature.local.LocalFeatureExtractor;
044import org.openimaj.feature.local.list.MemoryLocalFeatureList;
045import org.openimaj.feature.normalisation.HellingerNormaliser;
046import org.openimaj.image.MBFImage;
047import org.openimaj.image.feature.local.aggregate.VLAD;
048import org.openimaj.knn.pq.FloatProductQuantiser;
049import org.openimaj.knn.pq.FloatProductQuantiserUtilities;
050import org.openimaj.math.matrix.algorithm.pca.ThinSvdPrincipalComponentAnalysis;
051import org.openimaj.ml.clustering.FloatCentroidsResult;
052import org.openimaj.ml.clustering.assignment.hard.ExactFloatAssigner;
053import org.openimaj.ml.clustering.kmeans.FloatKMeans;
054import org.openimaj.ml.pca.FeatureVectorPCA;
055import org.openimaj.util.array.ArrayUtils;
056import org.openimaj.util.function.Function;
057import org.openimaj.util.function.Operation;
058import org.openimaj.util.list.AcceptingListView;
059import org.openimaj.util.parallel.Parallel;
060
061import Jama.Matrix;
062
/**
 * Class for learning the data required to efficiently index images using VLAD
 * with PCA and product quantisers.
 * <p>
 * Given a set of files of pre-extracted local features (one file per image),
 * this builder: clusters a sample of the features to form the VLAD
 * vocabulary; aggregates VLAD vectors for a sample of the images; learns a
 * randomly-whitened PCA projection of those vectors; and finally trains
 * product quantisers on the projected vectors.
 * 
 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
 */
public class VLADIndexerDataBuilder {
	/**
	 * Feature post-processing options
	 * 
	 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
	 * 
	 */
	public enum StandardPostProcesses
			implements
			Function<List<? extends LocalFeature<?, ?>>, List<FloatLocalFeatureAdaptor<?>>>
	{
		/**
		 * Do nothing, other than convert to float is required
		 * 
		 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
		 * 
		 */
		NONE {
			@Override
			public List<FloatLocalFeatureAdaptor<?>> apply(List<? extends LocalFeature<?, ?>> in) {
				// Adapt each input feature to a float-valued view without any
				// additional normalisation.
				return FloatLocalFeatureAdaptor.wrapUntyped(in);
			}
		},
		/**
		 * Apply Hellinger normalisation to the converted float features
		 * 
		 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
		 * 
		 */
		HELLINGER {
			// Single normaliser instance shared across all apply() calls of
			// this enum constant (constructed with offset 0).
			private HellingerNormaliser hell = new HellingerNormaliser(0);

			@Override
			public List<FloatLocalFeatureAdaptor<?>> apply(List<? extends LocalFeature<?, ?>> in) {
				// Same float adaptation as NONE, but with the Hellinger
				// normaliser applied to each feature vector.
				return FloatLocalFeatureAdaptor.wrapUntyped(in, hell);
			}
		}
	}

	// Extractor carried through into the built VLADIndexerData so features can
	// be computed for new images at query time.
	private LocalFeatureExtractor<LocalFeature<?, ?>, MBFImage> extractor;
	// Locations of the pre-extracted local-feature files; one file per image.
	private List<File> localFeatures;
	// Whether the aggregated VLAD vectors should be l2-normalised.
	private boolean normalise = false;
	// Number of VLAD vocabulary centroids (k for the k-means below).
	private int numVladCentroids = 64;
	// Number of k-means iterations for learning the VLAD vocabulary.
	private int numIterations = 100;
	// Dimensionality of the PCA-projected VLAD vectors.
	private int numPcaDims = 128;
	// Number of k-means iterations used when training each product quantiser.
	private int numPqIterations = 100;
	// Number of sub-quantisers in the product quantiser.
	private int numPqAssigners = 16;
	// Proportion of local features (per file) sampled for vocabulary learning.
	private float sampleProp = 0.1f;
	// Proportion of images sampled for learning the PCA basis; note: no
	// default — only set via the constructor.
	private float pcaSampleProp;
	// Post-processing applied to raw features before aggregation; never null
	// (defaults to NONE).
	private Function<List<? extends LocalFeature<?, ?>>, List<FloatLocalFeatureAdaptor<?>>> postProcess = StandardPostProcesses.NONE;

	/**
	 * Construct a {@link VLADIndexerDataBuilder} with the given parameters
	 * 
	 * @param extractor
	 *            the local feature extractor used to generate the input
	 *            features
	 * @param localFeatures
	 *            a list of file locations of the files containing the input
	 *            local features (one per image)
	 * @param normalise
	 *            should the resultant VLAD features be l2 normalised?
	 * @param numVladCentroids
	 *            the number of centroids for VLAD (~64)
	 * @param numIterations
	 *            the number of clustering iterations (~100)
	 * @param numPcaDims
	 *            the number of dimensions to project down to using PCA (~128
	 *            for normal SIFT)
	 * @param numPqIterations
	 *            the number of iterations for clustering the product quantisers
	 *            (~100)
	 * @param numPqAssigners
	 *            the number of product quantiser assigners (~16)
	 * @param sampleProp
	 *            the proportion of features to sample for the clustering the
	 *            VLAD centroids
	 * @param pcaSampleProp
	 *            the proportion of images to sample for computing the PCA basis
	 * @param postProcess
	 *            the post-processing to apply to the raw features before input
	 *            to VLAD; may be null, in which case
	 *            {@link StandardPostProcesses#NONE} is used
	 */
	public VLADIndexerDataBuilder(LocalFeatureExtractor<LocalFeature<?, ?>, MBFImage> extractor,
			List<File> localFeatures, boolean normalise, int numVladCentroids, int numIterations, int numPcaDims,
			int numPqIterations, int numPqAssigners, float sampleProp, float pcaSampleProp,
			Function<List<? extends LocalFeature<?, ?>>, List<FloatLocalFeatureAdaptor<?>>> postProcess)
	{
		super();
		this.extractor = extractor;
		this.localFeatures = localFeatures;
		this.normalise = normalise;
		this.numVladCentroids = numVladCentroids;
		this.numIterations = numIterations;
		this.numPcaDims = numPcaDims;
		this.numPqIterations = numPqIterations;
		this.numPqAssigners = numPqAssigners;
		this.sampleProp = sampleProp;
		this.pcaSampleProp = pcaSampleProp;
		// Guard against null so the rest of the class can apply postProcess
		// unconditionally.
		this.postProcess = postProcess == null ? StandardPostProcesses.NONE : postProcess;
	}

	/**
	 * Build the {@link VLADIndexerData} using the information provided at
	 * construction time. The following steps are taken:
	 * <p>
	 * <ol>
	 * <li>A sample of the features is loaded
	 * <li>The sample is clustered using k-means
	 * <li>VLAD representations are then built for all the input images
	 * <li>PCA is performed on the VLAD features
	 * <li>Whitening is applied to the PCA basis
	 * <li>The VLAD features are projected by the basis
	 * <li>Product quantisers are learned
	 * <li>The final {@link VLADIndexerData} object is created
	 * </ol>
	 * 
	 * @return a newly learned {@link VLADIndexerData} object
	 * @throws IOException
	 *             if there is a problem reading the local feature files
	 */
	public VLADIndexerData buildIndexerData() throws IOException {
		final VLAD<float[]> vlad = buildVLAD();

		// Aggregate VLAD vectors for a pcaSampleProp-sized sample of images
		final List<MultidimensionalFloatFV> vlads = computeVLADs(vlad);

		// learn PCA basis
		System.out.println("Learning PCA basis");
		final FeatureVectorPCA pca = new FeatureVectorPCA(new ThinSvdPrincipalComponentAnalysis(numPcaDims));
		pca.learnBasis(vlads);

		// perform whitening to balance variance; roll into pca basis
		// (the basis matrix is mutated in-place, so this must happen after
		// learnBasis and before any projection)
		System.out.println("Apply random whitening to normalise variances");
		final Matrix whitening = createRandomWhitening(numPcaDims);
		pca.getBasis().setMatrix(0, numPcaDims - 1, 0, numPcaDims - 1, pca.getBasis().times(whitening));

		// project features
		System.out.println("Projecting with PCA");
		final float[][] pcaVlads = projectFeatures(pca, vlads);

		// learn PQs
		System.out.println("Learning Product Quantiser Parameters");
		final FloatProductQuantiser pq = FloatProductQuantiserUtilities.train(pcaVlads, numPqAssigners, numPqIterations);

		return new VLADIndexerData(vlad, pca, pq, extractor, postProcess);
	}

	/**
	 * Build a {@link VLAD} using the information provided at construction time.
	 * The following steps are taken:
	 * <p>
	 * <ol>
	 * <li>A sample of the features is loaded
	 * <li>The sample is clustered using k-means
	 * </ol>
	 * 
	 * @return the {@link VLAD}
	 */
	public VLAD<float[]> buildVLAD() {
		// Load the data and normalise
		System.out.println("Loading Data from " + localFeatures.size() + " files");
		final List<FloatLocalFeatureAdaptor<?>> samples = loadSample();

		// cluster
		System.out.println("Clustering " + samples.size() + " Data Points");
		final FloatCentroidsResult centroids = cluster(samples);

		// build vlads
		System.out.println("Building VLADs");
		return new VLAD<float[]>(new ExactFloatAssigner(centroids), centroids, normalise);
	}

	/**
	 * Create an ndims x ndims matrix of Gaussian random values in which each
	 * row is scaled to unit length. Note that the rows are normalised but not
	 * orthogonalised against each other.
	 * <p>
	 * NOTE(review): the {@link MersenneTwister} is constructed without an
	 * explicit seed, so the whitening matrix (and therefore the learned
	 * basis) differs between runs — confirm that reproducibility is not
	 * required here.
	 * 
	 * @param ndims
	 *            the number of rows/columns of the matrix
	 * @return the random whitening matrix
	 */
	private Matrix createRandomWhitening(final int ndims) {
		final Matrix m = new Matrix(ndims, ndims);
		final double[][] a = m.getArray();
		// Accumulated squared l2-norm of each row
		final double[] norms = new double[ndims];

		final MersenneTwister mt = new MersenneTwister();

		// Fill with N(0,1) samples, accumulating the squared norm per row
		for (int r = 0; r < ndims; r++) {
			for (int c = 0; c < ndims; c++) {
				a[r][c] = mt.nextGaussian();
				norms[r] += (a[r][c] * a[r][c]);
			}
		}

		// Scale each row to unit length
		for (int r = 0; r < ndims; r++) {
			final double norm = Math.sqrt(norms[r]);

			for (int c = 0; c < ndims; c++) {
				a[r][c] /= norm;
			}
		}

		return m;
	}

	/**
	 * Project each VLAD vector through the (whitened) PCA basis and
	 * l2-normalise the result, converting to float.
	 * <p>
	 * Projection is performed in parallel; results are appended under a lock,
	 * so the order of rows in the returned array is not guaranteed to match
	 * the order of the input list.
	 * 
	 * @param pca
	 *            the learned PCA basis to project with
	 * @param vlads
	 *            the VLAD vectors to project
	 * @return the projected, l2-normalised vectors; one row per input vector
	 */
	private float[][] projectFeatures(final FeatureVectorPCA pca, List<MultidimensionalFloatFV> vlads) {
		final List<float[]> pcaVlads = new ArrayList<float[]>();
		Parallel.forEach(vlads, new Operation<MultidimensionalFloatFV>() {
			@Override
			public void perform(MultidimensionalFloatFV vector) {
				// project, then l2 (p=2) normalise
				final DoubleFV result = pca.project(vector).normaliseFV(2);
				final float[] fresult = ArrayUtils.convertToFloat(result.values);

				// guard the shared accumulator against concurrent mutation
				synchronized (pcaVlads) {
					pcaVlads.add(fresult);
				}
			}
		});

		return pcaVlads.toArray(new float[pcaVlads.size()][]);
	}

	/**
	 * Compute VLAD aggregates for a random sample of pcaSampleProp of the
	 * input feature files, in parallel.
	 * <p>
	 * Files whose aggregation yields null are skipped; files that fail to
	 * read have their {@link IOException} stack trace printed and are
	 * silently skipped (best-effort behaviour). The order of the returned
	 * list is not guaranteed to match the file order.
	 * 
	 * @param vlad
	 *            the VLAD aggregator to use
	 * @return the aggregated VLAD vectors for the sampled images
	 */
	private List<MultidimensionalFloatFV> computeVLADs(final VLAD<float[]> vlad) {
		final List<MultidimensionalFloatFV> vlads = new ArrayList<MultidimensionalFloatFV>();

		// Pick a random pcaSampleProp-fraction of the feature files
		final int[] indices = RandomData.getUniqueRandomInts((int) (localFeatures.size() * pcaSampleProp), 0,
				localFeatures.size());
		final List<File> selectedLocalFeatures = new AcceptingListView<File>(localFeatures, indices);

		Parallel.forEach(selectedLocalFeatures, new Operation<File>() {
			@Override
			public void perform(File file) {
				try {
					final List<FloatLocalFeatureAdaptor<?>> fkeys = readFeatures(file);
					final MultidimensionalFloatFV feature = vlad.aggregate(fkeys);

					synchronized (vlads) {
						// aggregate may return null (e.g. nothing to
						// aggregate); only keep real features
						if (feature != null)
							vlads.add(feature);
					}
				} catch (final IOException e) {
					// best-effort: log and skip unreadable files
					e.printStackTrace();
				}
			}
		});

		return vlads;
	}

	/**
	 * Read the local features from the given file (using the feature class
	 * reported by the extractor) and apply the configured post-processing.
	 * 
	 * @param file
	 *            the file of local features to read
	 * @return the post-processed float features
	 * @throws IOException
	 *             if the file cannot be read
	 */
	private List<FloatLocalFeatureAdaptor<?>> readFeatures(File file) throws IOException {
		final List<? extends LocalFeature<?, ?>> keys = MemoryLocalFeatureList.read(file, extractor.getFeatureClass());

		return postProcess.apply(keys);
	}

	/**
	 * Load a random sample of sampleProp of the features from every input
	 * file, in parallel, for learning the VLAD vocabulary.
	 * <p>
	 * Files that fail to read have their {@link IOException} stack trace
	 * printed and are silently skipped (best-effort behaviour).
	 * 
	 * @return the sampled features pooled across all files
	 */
	private List<FloatLocalFeatureAdaptor<?>> loadSample() {
		final List<FloatLocalFeatureAdaptor<?>> samples = new ArrayList<FloatLocalFeatureAdaptor<?>>();

		Parallel.forEach(localFeatures, new Operation<File>() {
			@Override
			public void perform(File file) {
				try {
					final List<FloatLocalFeatureAdaptor<?>> fkeys = readFeatures(file);

					// Randomly keep sampleProp of this file's features
					final int[] indices = RandomData.getUniqueRandomInts((int) (fkeys.size() * sampleProp), 0,
							fkeys.size());
					final AcceptingListView<FloatLocalFeatureAdaptor<?>> filtered = new AcceptingListView<FloatLocalFeatureAdaptor<?>>(
							fkeys, indices);

					// guard the shared accumulator against concurrent mutation
					synchronized (samples) {
						samples.addAll(filtered);
					}
				} catch (final IOException e) {
					// best-effort: log and skip unreadable files
					e.printStackTrace();
				}
			}
		});

		return samples;
	}

	/**
	 * Cluster the sampled features into numVladCentroids centroids using
	 * exact k-means with numIterations iterations.
	 * 
	 * @param rawData
	 *            the sampled features to cluster
	 * @return the learned centroids
	 */
	private FloatCentroidsResult cluster(List<FloatLocalFeatureAdaptor<?>> rawData) {
		// build full data array; note this references (not copies) the
		// underlying feature-vector arrays
		final float[][] vectors = new float[rawData.size()][];
		for (int i = 0; i < vectors.length; i++) {
			vectors[i] = rawData.get(i).getFeatureVector().values;
		}

		// Perform clustering
		final FloatKMeans kmeans = FloatKMeans.createExact(numVladCentroids, numIterations);
		final FloatCentroidsResult centroids = kmeans.cluster(vectors);

		return centroids;
	}
}