/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   *  Redistributions of source code must retain the above copyright notice,
 *      this list of conditions and the following disclaimer.
 *
 *   *  Redistributions in binary form must reproduce the above copyright notice,
 *      this list of conditions and the following disclaimer in the documentation
 *      and/or other materials provided with the distribution.
 *
 *   *  Neither the name of the University of Southampton nor the names of its
 *      contributors may be used to endorse or promote products derived from this
 *      software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.linear.projection;

import gnu.trove.list.array.TDoubleArrayList;

import org.openimaj.math.matrix.MatrixUtils;
import org.openimaj.math.matrix.algorithm.pca.ThinSvdPrincipalComponentAnalysis;

import Jama.Matrix;
/**
 * {@link LargeMarginDimensionalityReduction} is a technique for compressing
 * high-dimensional features into a lower-dimensional representation using a
 * learned linear projection. Supervised learning is used to learn the
 * projection such that the squared Euclidean distance between two
 * low-dimensional vectors is less than a threshold if they correspond to the
 * same object, and greater otherwise. In addition, this condition is required
 * to hold with a margin of at least one.
 * <p>
 * In essence, the Euclidean distance in the low-dimensional space produced by
 * this technique can be seen as a low-rank Mahalanobis metric in the original
 * space: ||W(x - y)||<sup>2</sup> = (x - y)<sup>T</sup>W<sup>T</sup>W(x - y),
 * where the Mahalanobis matrix W<sup>T</sup>W has rank equal to the number of
 * dimensions of the smaller space.
 * <p>
 * This class implements the technique using stochastic sub-gradient descent.
 * As the objective function is not convex, initialisation is important;
 * initial conditions are generated by selecting the largest PCA dimensions
 * and then whitening them so they have equal magnitude. In addition, the
 * projection matrix is not regularised explicitly; instead the algorithm is
 * just stopped after a fixed number of iterations.
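 * <p>
 * A minimal usage sketch (the parallel arrays {@code firstVecs},
 * {@code secondVecs} and {@code sameLabels} are hypothetical training data,
 * not part of this class; pairs are assumed to alternate between true and
 * false, as required by {@link #step(double[], double[], boolean)}):
 * 
 * <pre>
 * {@code
 * LargeMarginDimensionalityReduction lmdr = new LargeMarginDimensionalityReduction(128);
 * lmdr.initialise(firstVecs, secondVecs, sameLabels);
 * for (int i = 0; i < numIterations; i++) {
 *     int j = i % sameLabels.length;
 *     lmdr.step(firstVecs[j], secondVecs[j], sameLabels[j]);
 * }
 * boolean match = lmdr.classify(queryA, queryB);
 * }
 * </pre>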
 * 
 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
 */
public class LargeMarginDimensionalityReduction {
    protected int ndims;
    protected double wLearnRate = 0.25; // gamma: learning rate for W
    protected double bLearnRate = 1; // learning rate for the bias, b
    protected Matrix W;
    protected double b;

    /**
     * Construct with the given target dimensionality and default values for
     * the other parameters (learning rate of 0.25 for W; learning rate of 1
     * for the bias).
     * 
     * @param ndims
     *            target number of dimensions
     */
    public LargeMarginDimensionalityReduction(int ndims) {
        this.ndims = ndims;
    }

    /**
     * Construct with the given target dimensionality and learning rates.
     * 
     * @param ndims
     *            target number of dimensions
     * @param wLearnRate
     *            learning rate for the transform matrix, W.
     * @param bLearnRate
     *            learning rate for the bias, b.
     */
    public LargeMarginDimensionalityReduction(int ndims, double wLearnRate, double bLearnRate) {
        this.ndims = ndims;
        this.wLearnRate = wLearnRate;
        this.bLearnRate = bLearnRate;
    }

    /**
     * Initialise the LMDR with the given data in three parallel arrays. The
     * first two arrays together store pairs of vectors, and the third array
     * indicates whether the vectors in each pair represent the same class or
     * different classes. All three arrays must have the same length, and
     * should contain approximately equal numbers of true and false pairs.
     * <p>
     * Internally, this method gathers together all the vectors and performs
     * {@link ThinSvdPrincipalComponentAnalysis} to get a starting estimate
     * for the transform matrix. The distance in the projected space between
     * all true pairs and false pairs is then computed and used to select an
     * optimal initial threshold.
     * 
     * @param datai
     *            array of the first vectors of the pairs
     * @param dataj
     *            array of the second vectors of the pairs
     * @param same
     *            array indicating whether each pair is true (from the same
     *            class) or false (from different classes)
     */
    public void initialise(double[][] datai, double[][] dataj, boolean[] same) {
        // interleave the pairs into a single dataset for PCA
        final double[][] data = new double[2 * datai.length][];
        for (int i = 0; i < datai.length; i++) {
            data[2 * i] = datai[i];
            data[2 * i + 1] = dataj[i];
        }

        final ThinSvdPrincipalComponentAnalysis pca = new ThinSvdPrincipalComponentAnalysis(ndims);
        pca.learnBasis(data);

        // whiten: scale each principal direction by the inverse of its
        // standard deviation so all dimensions have equal magnitude
        final double[] evs = pca.getEigenValues();
        final double[] invStdDev = new double[ndims];
        for (int i = 0; i < ndims; i++)
            invStdDev[i] = 1.0 / Math.sqrt(evs[i]);

        W = MatrixUtils.diag(invStdDev).times(pca.getBasis().transpose());

        recomputeBias(datai, dataj, same);
    }

    /**
     * Recompute the bias (threshold), b, by choosing the threshold that
     * maximises classification accuracy over the given pairs under the
     * current transform.
     * 
     * @param datai
     *            array of the first vectors of the pairs
     * @param dataj
     *            array of the second vectors of the pairs
     * @param same
     *            array indicating whether each pair is true (from the same
     *            class) or false (from different classes)
     */
    public void recomputeBias(double[][] datai, double[][] dataj, boolean[] same) {
        final TDoubleArrayList posDistances = new TDoubleArrayList();
        final TDoubleArrayList negDistances = new TDoubleArrayList();
        for (int i = 0; i < datai.length; i++) {
            final Matrix diff = diff(datai[i], dataj[i]);
            final Matrix diffProj = W.times(diff);
            final double dist = sumsq(diffProj);

            if (same[i]) {
                posDistances.add(dist);
            } else {
                negDistances.add(dist);
            }
        }

        b = computeOptimal(posDistances, negDistances);
    }

    private double computeOptimal(TDoubleArrayList posDistances, TDoubleArrayList negDistances) {
        double bestAcc = 0;
        double bestThresh = -Double.MAX_VALUE;
        for (int i = 0; i < posDistances.size(); i++) {
            final double thresh = posDistances.get(i);

            final double acc = computeAccuracy(posDistances, negDistances, thresh);

            if (acc > bestAcc) {
                bestAcc = acc;
                bestThresh = thresh;
            }
        }

        for (int i = 0; i < negDistances.size(); i++) {
            final double thresh = negDistances.get(i);

            final double acc = computeAccuracy(posDistances, negDistances, thresh);

            if (acc > bestAcc) {
                bestAcc = acc;
                bestThresh = thresh;
            }
        }

        return bestThresh;
    }

    private double computeAccuracy(TDoubleArrayList posDistances, TDoubleArrayList negDistances, double thresh) {
        int correct = 0;
        for (int i = 0; i < posDistances.size(); i++) {
            if (posDistances.get(i) < thresh)
                correct++;
        }

        for (int i = 0; i < negDistances.size(); i++) {
            if (negDistances.get(i) >= thresh)
                correct++;
        }

        return (double) correct / (double) (posDistances.size() + negDistances.size());
    }

    private Matrix diff(double[] phii, double[] phij) {
        final Matrix diff = new Matrix(phii.length, 1);
        final double[][] diffv = diff.getArray();

        for (int i = 0; i < phii.length; i++) {
            diffv[i][0] = phii[i] - phij[i];
        }
        return diff;
    }

    private double sumsq(Matrix diffProj) {
        final double[][] v = diffProj.getArray();

        double sumsq = 0;
        for (int i = 0; i < v.length; i++) {
            sumsq += v[i][0] * v[i][0];
        }

        return sumsq;
    }

    /**
     * Perform a single update step of the SGD optimisation. Alternate calls
     * to this method should swap between true and false pairs.
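     * <p>
     * A sketch of a typical training loop (the {@code samplePositivePair()}
     * and {@code sampleNegativePair()} helpers are hypothetical; they stand
     * for whatever pair-sampling scheme is in use):
     * 
     * <pre>
     * {@code
     * for (int i = 0; i < numIterations; i++) {
     *     boolean same = (i % 2 == 0); // alternate true and false pairs
     *     double[][] pair = same ? samplePositivePair() : sampleNegativePair();
     *     lmdr.step(pair[0], pair[1], same);
     * }
     * }
     * </pre>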
     * 
     * @param phii
     *            first vector
     * @param phij
     *            second vector
     * @param same
     *            true if the vectors are from the same class; false otherwise
     * @return true if the transform matrix was changed; false otherwise
     */
    public boolean step(double[] phii, double[] phij, boolean same) {
        final int yij = same ? 1 : -1;

        final Matrix diff = diff(phii, phij);
        final Matrix diffProj = W.times(diff);
        final double sumsq = sumsq(diffProj);

        // no update is needed if the pair is already correctly classified
        // with a margin of at least one
        if (yij * (b - sumsq) > 1)
            return false;

        // equivalent to: W = W - (wLearnRate * yij) * diffProj * diff'
        fastUpdate(diffProj, wLearnRate * yij, diff);

        b += yij * bLearnRate;

        return true;
    }

    /**
     * Efficiently apply the rank-1 update W -= weight * diffProj * diff' in
     * place, in a single pass, without creating temporary matrices.
     * 
     * @param diffProj
     *            the projected difference vector, W * diff
     * @param weight
     *            the signed learning rate, wLearnRate * yij
     * @param diff
     *            the difference vector between the pair
     */
    private void fastUpdate(Matrix diffProj, double weight, Matrix diff) {
        // equivalent to:
        // final Matrix updateW = diffProj.times(weight).times(diff.transpose());
        // W.minusEquals(updateW);

        final double[][] dp = diffProj.getArray();
        final double[][] d = diff.getArray();
        final double[][] Wdata = W.getArray();
        for (int r = 0; r < Wdata.length; r++) {
            // W is ndims x d (not square), so iterate over the columns of
            // the current row rather than the number of rows
            for (int c = 0; c < Wdata[r].length; c++) {
                Wdata[r][c] -= weight * dp[r][0] * d[c][0];
            }
        }
    }

    /**
     * Get the transform matrix, W.
     * 
     * @return the transform matrix
     */
    public Matrix getTransform() {
        return W;
    }

    /**
     * Get the bias, b.
     * 
     * @return the bias
     */
    public double getBias() {
        return b;
    }

    /**
     * Compute the matching score between a pair of (high-dimensional)
     * features. Scores of at least zero indicate a matching pair.
     * 
     * @param phii
     *            first vector
     * @param phij
     *            second vector
     * @return the matching score, b - ||W(phii - phij)||^2
     */
    public double score(double[] phii, double[] phij) {
        final Matrix diff = diff(phii, phij);
        final Matrix diffProj = W.times(diff);

        return b - sumsq(diffProj);
    }

    /**
     * Determine if two features are from the same class or different classes.
     * 
     * @param phii
     *            first vector
     * @param phij
     *            second vector
     * @return the classification (true if same class; false if different)
     */
    public boolean classify(double[] phii, double[] phij) {
        return score(phii, phij) >= 0;
    }

    /**
     * Compute the low-dimensional projection of the given vector.
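     * <p>
     * An illustrative equivalence (an observation about this implementation,
     * useful when projections are cached for repeated comparisons; the
     * {@code lmdr}, {@code a} and {@code b} names are hypothetical):
     * 
     * <pre>
     * {@code
     * // for any pair of vectors a and b, score(a, b) equals getBias()
     * // minus the squared distance between project(a) and project(b)
     * double[] pa = lmdr.project(a);
     * double[] pb = lmdr.project(b);
     * double d2 = 0;
     * for (int i = 0; i < pa.length; i++)
     *     d2 += (pa[i] - pb[i]) * (pa[i] - pb[i]);
     * boolean match = (lmdr.getBias() - d2) >= 0; // same as lmdr.classify(a, b)
     * }
     * </pre>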
     * 
     * @param in
     *            the vector
     * @return the low-rank projection of the vector
     */
    public double[] project(double[] in) {
        return W.times(new Matrix(new double[][] { in }).transpose()).getColumnPackedCopy();
    }

    /**
     * Set the bias, b.
     * 
     * @param d
     *            the new bias
     */
    public void setBias(double d) {
        this.b = d;
    }

    /**
     * Set the transform matrix, W.
     * 
     * @param proj
     *            the new transform matrix
     */
    public void setTransform(Matrix proj) {
        this.W = proj;
    }
}