/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   *  Redistributions of source code must retain the above copyright notice,
 *      this list of conditions and the following disclaimer.
 *
 *   *  Redistributions in binary form must reproduce the above copyright notice,
 *      this list of conditions and the following disclaimer in the documentation
 *      and/or other materials provided with the distribution.
 *
 *   *  Neither the name of the University of Southampton nor the names of its
 *      contributors may be used to endorse or promote products derived from this
 *      software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.math.matrix.algorithm;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import org.openimaj.citation.annotation.Reference;
import org.openimaj.citation.annotation.ReferenceType;
import org.openimaj.math.matrix.GeneralisedEigenvalueProblem;
import org.openimaj.math.matrix.MatrixUtils;
import org.openimaj.util.array.ArrayUtils;
import org.openimaj.util.pair.IndependentPair;

import Jama.Matrix;
/**
 * Implementation of Multiclass Linear Discriminant Analysis.
 * 
 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
 *
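 * <p>
 * A minimal usage sketch with made-up toy data (two classes of 2D feature
 * vectors, grouped by class), projecting onto a single discriminant direction:
 * <pre>{@code
 * List<double[][]> classes = new ArrayList<double[][]>();
 * classes.add(new double[][] { { 1.0, 2.0 }, { 1.1, 2.1 } }); // class 0
 * classes.add(new double[][] { { 5.0, 6.0 }, { 5.1, 6.2 } }); // class 1
 *
 * LinearDiscriminantAnalysis lda = new LinearDiscriminantAnalysis(1);
 * lda.learnBasis(classes);
 * double[] projected = lda.project(new double[] { 1.05, 2.05 });
 * }</pre>
 *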
 */
@Reference(
                type = ReferenceType.Article,
                author = { "Fisher, Ronald A." },
                title = "{The use of multiple measurements in taxonomic problems}",
                year = "1936",
                journal = "Annals Eugen.",
                pages = { "179", "", "188" },
                volume = "7",
                customData = {
                        "citeulike-article-id", "764226",
                        "keywords", "classification",
                        "posted-at", "2006-09-18 14:06:16",
                        "priority", "2"
                }
        )
public class LinearDiscriminantAnalysis {
        private static class MeanData {
                double[] overallMean;
                double[][] classMeans;
                int numInstances;
        }
        
        protected int numComponents;
        protected Matrix eigenvectors;
        protected double[] eigenvalues;
        protected double[] mean;
        
        /**
         * Construct with the given number of components.
         * @param numComponents the number of components.
         */
        public LinearDiscriminantAnalysis(int numComponents) {
                this.numComponents = numComponents;
        }
        
        private MeanData computeMeans(List<double[][]> data) {
                final int cols = data.get(0)[0].length;
                final int numClasses = data.size();
                
                MeanData md = new MeanData();
                md.overallMean = new double[cols];
                md.classMeans = new double[numClasses][];
                md.numInstances = 0;
                
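                // Sum each class and count the total number of instances; the class sums
                // are turned into class means, and the overall mean is accumulated from
                // the class sums and normalised by the total count below.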
                for (int i=0; i<numClasses; i++) {
                        final double[][] classData = data.get(i);
                        final int classSize = classData.length;
                        
                        md.classMeans[i] = computeSum(classData);
                        md.numInstances += classSize;
                        
                        for (int j=0; j<cols; j++) {
                                md.overallMean[j] += md.classMeans[i][j];
                                md.classMeans[i][j] /= classSize;
                        }
                }
                
                for (int i=0; i<cols; i++) {
                        md.overallMean[i] /= (double)md.numInstances;
                }
                
                return md;
        }
        
        private double[] computeSum(double[][] data) {
                double[] sum = new double[data[0].length];
                
                for (int j=0; j<data.length; j++) {
                        for (int i=0; i<sum.length; i++) {
                                sum[i] += data[j][i];
                        }
                }
                
                return sum;
        }
        
        /**
         * Learn the LDA basis.
         * @param data data grouped by class
         */
        public void learnBasisIP(List<? extends IndependentPair<?, double[]>> data) {
                Map<Object, List<double[]>> mapData = new HashMap<Object, List<double[]>>();
                
                for (IndependentPair<?, double[]> item : data) {
                        List<double[]> fvs = mapData.get(item.firstObject());
                        if (fvs == null) mapData.put(item.firstObject(), fvs = new ArrayList<double[]>());
                        
                        fvs.add(item.getSecondObject());
                }
                learnBasisML(mapData);
        }
        
        /**
         * Learn the LDA basis.
         * @param data data grouped by class
         */
        public void learnBasisML(Map<?, List<double[]>> data) {
                List<double[][]> list = new ArrayList<double[][]>();
                for (Entry<?, List<double[]>> e : data.entrySet()) {
                        list.add(e.getValue().toArray(new double[e.getValue().size()][]));
                }
                learnBasis(list);
        }
        
        /**
         * Learn the LDA basis.
         * @param data data grouped by class
         */
        public void learnBasisLL(List<List<double[]>> data) {
                List<double[][]> list = new ArrayList<double[][]>();
                for (List<double[]> e : data) {
                        list.add(e.toArray(new double[e.size()][]));
                }
                learnBasis(list);
        }
        
        /**
         * Learn the LDA basis.
         * @param data data grouped by class
         */
        public void learnBasis(Map<?, double[][]> data) {
                List<double[][]> list = new ArrayList<double[][]>();
                for (Entry<?, double[][]> e : data.entrySet()) {
                        list.add(e.getValue());
                }
                learnBasis(list);
        }
        
        /**
         * Learn the LDA basis.
         * @param data data grouped by class
         */
        public void learnBasis(List<double[][]> data) {
                int c = data.size();
                
                // LDA can produce at most c-1 non-trivial components for c classes
                if (numComponents <= 0 || numComponents >= c) 
                        numComponents = c - 1;
                
                MeanData meanData = computeMeans(data);
                mean = meanData.overallMean;
                final double[][] classMeans = meanData.classMeans;
                
                final Matrix Sw = new Matrix(mean.length, mean.length);
                final Matrix Sb = new Matrix(mean.length, mean.length);
                
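                // accumulate the within-class scatter Sw (from the mean-centred class data)
                // and the between-class scatter Sb (from the class-mean offsets), class by class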
                for (int i=0; i<c; i++) {
                        final Matrix classData = new Matrix(data.get(i));
                        final double[] classMean = classMeans[i];
                        
                        Matrix zeroCentred = MatrixUtils.minusRow(classData, classMean);
                        MatrixUtils.plusEquals(Sw, zeroCentred.transpose().times(zeroCentred));
                        
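                        // form (classMean - overallMean) in place, then add its weighted
                        // outer product to the between-class scatter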
                        ArrayUtils.subtract(classMean, mean);
                        Matrix diff = new Matrix(new double[][]{ classMean });
                        MatrixUtils.plusEquals(Sb, MatrixUtils.times(diff.transpose().times(diff), meanData.numInstances));
                }
                
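                // solve the generalised eigenproblem Sb v = lambda Sw v, keeping the
                // top numComponents eigenvectors and their eigenvalues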
                IndependentPair<Matrix, double[]> evs = GeneralisedEigenvalueProblem.symmetricGeneralisedEigenvectorsSorted(Sb, Sw, numComponents);
                this.eigenvectors = evs.firstObject();
                this.eigenvalues = evs.secondObject();
        }
        
        /**
         * Get the basis (the LDA eigenvectors)
         * 
         * @return the eigenvectors
         */
        public Matrix getBasis() {
                return eigenvectors;
        }
        
        /**
         * Get a specific basis vector as
         * a double array. The returned array contains a
         * copy of the data.
         * 
         * @param index the index of the vector
         * 
         * @return the eigenvector
         */
        public double[] getBasisVector(int index) {
                double[] pc = new double[eigenvectors.getRowDimension()];
                double[][] data = eigenvectors.getArray();
                
                for (int r=0; r<pc.length; r++)
                        pc[r] = data[r][index];
                
                return pc;
        }
        
        /**
         * Get the basis eigenvectors. Each column vector of the returned
         * matrix is an eigenvector.
         * 
         * Syntactic sugar for {@link #getBasis()}
         * @return the eigenvectors
         */
        public Matrix getEigenVectors() {
                return eigenvectors;
        }
        
        /**
         * @return the eigenvalues corresponding to the LDA basis vectors
         */
        public double [] getEigenValues() {
                return eigenvalues;
        }
        
        /**
         * Get the eigenvalue corresponding to the ith LDA basis vector.
         * @param i the index of the basis vector
         * @return the corresponding eigenvalue
         */
        public double getEigenValue(int i) {
                return eigenvalues[i];
        }
        
        /**
         * @return The mean values
         */
        public double[] getMean() {
                return mean;
        }
        
        /**
         * Generate a new "observation" as a linear combination of
         * the eigenvectors (ev): mean + ev * scaling.
         * <p>
         * If the scaling vector is shorter than the number of
         * components, it will be zero-padded. If it is longer,
         * it will be truncated.
         * 
         * @param scalings the weighting for each eigenvector
         * @return generated observation
         */
        public double [] generate(double[] scalings) {
                Matrix scale = new Matrix(this.eigenvalues.length, 1);
                
                for (int i=0; i<Math.min(eigenvalues.length, scalings.length); i++)
                        scale.set(i, 0, scalings[i]);
                
                Matrix meanMatrix = new Matrix(new double[][]{mean}).transpose();
                
                return meanMatrix.plus(eigenvectors.times(scale)).getColumnPackedCopy();
        }
        
        /**
         * Project a matrix of row vectors by the basis.
         * The vectors are normalised by subtracting the mean and
         * then multiplied by the basis. The returned matrix
         * has a row for each vector.
         * @param m the matrix of row vectors to project
         * @return projected vectors
         */
        public Matrix project(Matrix m) {
                Matrix vec = m.copy();
                
                final int rows = vec.getRowDimension();
                final int cols = vec.getColumnDimension();
                final double[][] vecarr = vec.getArray();
                
                for (int r=0; r<rows; r++)
                        for (int c=0; c<cols; c++)
                                vecarr[r][c] -= mean[c];
                
                //T = (Vt.Dt)^T == Dt.Vt
                return vec.times(eigenvectors);
        }
        
        /**
         * Project a vector by the basis. The vector
         * is normalised by subtracting the mean and
         * then multiplied by the basis.
         * @param vector the vector to project
         * @return projected vector
         */
        public double[] project(double [] vector) {
                Matrix vec = new Matrix(1, vector.length);
                final double[][] vecarr = vec.getArray();
                
                for (int i=0; i<vector.length; i++)
                        vecarr[0][i] = vector[i] - mean[i];
                
                return vec.times(eigenvectors).getColumnPackedCopy();
        }
}