001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.math.matrix.algorithm; 031 032import java.util.ArrayList; 033import java.util.HashMap; 034import java.util.List; 035import java.util.Map; 036import java.util.Map.Entry; 037 038import org.openimaj.citation.annotation.Reference; 039import org.openimaj.citation.annotation.ReferenceType; 040import org.openimaj.math.matrix.GeneralisedEigenvalueProblem; 041import org.openimaj.math.matrix.MatrixUtils; 042import org.openimaj.util.array.ArrayUtils; 043import org.openimaj.util.pair.IndependentPair; 044 045import Jama.Matrix; 046 047/** 048 * Implementation of Multiclass Linear Discriminant Analysis. 049 * 050 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 051 * 052 */ 053@Reference( 054 type = ReferenceType.Article, 055 author = { "Fisher, Ronald A." }, 056 title = "{The use of multiple measurements in taxonomic problems}", 057 year = "1936", 058 journal = "Annals Eugen.", 059 pages = { "179", "", "188" }, 060 volume = "7", 061 customData = { 062 "citeulike-article-id", "764226", 063 "keywords", "classification", 064 "posted-at", "2006-09-18 14:06:16", 065 "priority", "2" 066 } 067 ) 068public class LinearDiscriminantAnalysis { 069 private static class MeanData { 070 double[] overallMean; 071 double[][] classMeans; 072 int numInstances; 073 } 074 075 protected int numComponents; 076 protected Matrix eigenvectors; 077 protected double[] eigenvalues; 078 protected double[] mean; 079 080 /** 081 * Construct with the given number of components. 082 * @param numComponents the number of components. 083 */ 084 public LinearDiscriminantAnalysis(int numComponents) { 085 this.numComponents = numComponents; 086 } 087 088 private MeanData computeMeans(List<double[][]> data) { 089 final int cols = data.get(0)[0].length; 090 final int numClasses = data.size(); 091 092 MeanData md = new MeanData(); 093 md.overallMean = new double[cols]; 094 md.classMeans = new double[numClasses][]; 095 md.numInstances = 0; 096 097 for (int i=0; i<numClasses; i++) { 098 final double[][] classData = data.get(i); 099 final int classSize = classData.length; 100 101 md.classMeans[i] = computeSum(classData); 102 md.numInstances += classSize; 103 104 for (int j=0; j<cols; j++) { 105 md.overallMean[j] += md.classMeans[i][j]; 106 md.classMeans[i][j] /= classSize; 107 } 108 } 109 110 for (int i=0; i<cols; i++) { 111 md.overallMean[i] /= (double)md.numInstances; 112 } 113 114 return md; 115 } 116 117 private double[] computeSum(double[][] data) { 118 double[] sum = new double[data[0].length]; 119 120 for (int j=0; j<data.length; j++) { 121 for (int i=0; i<sum.length; i++) { 122 sum[i] += data[j][i]; 123 } 124 } 125 126 return sum; 127 } 128 129 /** 130 * Learn the LDA basis. 131 * @param data data grouped by class 132 */ 133 public void learnBasisIP(List<? extends IndependentPair<?, double[]>> data) { 134 Map<Object, List<double[]>> mapData = new HashMap<Object, List<double[]>>(); 135 136 for (IndependentPair<?, double[]> item : data) { 137 List<double[]> fvs = mapData.get(item.firstObject()); 138 if (fvs == null) mapData.put(item.firstObject(), fvs = new ArrayList<double[]>()); 139 140 141 fvs.add(item.getSecondObject()); 142 } 143 learnBasisML(mapData); 144 } 145 146 /** 147 * Learn the LDA basis. 148 * @param data data grouped by class 149 */ 150 public void learnBasisML(Map<?, List<double[]>> data) { 151 List<double[][]> list = new ArrayList<double[][]>(); 152 for (Entry<?, List<double[]>> e : data.entrySet()) { 153 list.add(e.getValue().toArray(new double[e.getValue().size()][])); 154 } 155 learnBasis(list); 156 } 157 158 /** 159 * Learn the LDA basis. 160 * @param data data grouped by class 161 */ 162 public void learnBasisLL(List<List<double[]>> data) { 163 List<double[][]> list = new ArrayList<double[][]>(); 164 for (List<double[]> e : data) { 165 list.add(e.toArray(new double[e.size()][])); 166 } 167 learnBasis(list); 168 } 169 170 /** 171 * Learn the LDA basis. 172 * @param data data grouped by class 173 */ 174 public void learnBasis(Map<?, double[][]> data) { 175 List<double[][]> list = new ArrayList<double[][]>(); 176 for (Entry<?, double[][]> e : data.entrySet()) { 177 list.add(e.getValue()); 178 } 179 learnBasis(data); 180 } 181 182 /** 183 * Learn the LDA basis. 184 * @param data data grouped by class 185 */ 186 public void learnBasis(List<double[][]> data) { 187 int c = data.size(); 188 189 if (c < 0 || numComponents >= c) 190 numComponents = c - 1; 191 192 MeanData meanData = computeMeans(data); 193 mean = meanData.overallMean; 194 final double[][] classMeans = meanData.classMeans; 195 196 final Matrix Sw = new Matrix(mean.length, mean.length); 197 final Matrix Sb = new Matrix(mean.length, mean.length); 198 199 for (int i=0; i<c; i++) { 200 final Matrix classData = new Matrix(data.get(i)); 201 final double[] classMean = classMeans[i]; 202 203 Matrix zeroCentred = MatrixUtils.minusRow(classData, classMean); 204 MatrixUtils.plusEquals(Sw, zeroCentred.transpose().times(zeroCentred)); 205 206 ArrayUtils.subtract(classMean, mean); 207 Matrix diff = new Matrix(new double[][]{ classMean }); 208 MatrixUtils.plusEquals(Sb, MatrixUtils.times(diff.transpose().times(diff), meanData.numInstances)); 209 } 210 211 IndependentPair<Matrix, double[]> evs = GeneralisedEigenvalueProblem.symmetricGeneralisedEigenvectorsSorted(Sb, Sw, numComponents); 212 this.eigenvectors = evs.firstObject(); 213 this.eigenvalues = evs.secondObject(); 214 } 215 216 /** 217 * Get the basis (the LDA eigenvectors) 218 * 219 * @return the eigenvectors 220 */ 221 public Matrix getBasis() { 222 return eigenvectors; 223 } 224 225 /** 226 * Get a specific basis vector as 227 * a double array. The returned array contains a 228 * copy of the data. 229 * 230 * @param index the index of the vector 231 * 232 * @return the eigenvector 233 */ 234 public double[] getBasisVector(int index) { 235 double[] pc = new double[eigenvectors.getRowDimension()]; 236 double[][] data = eigenvectors.getArray(); 237 238 for (int r=0; r<pc.length; r++) 239 pc[r] = data[r][index]; 240 241 return pc; 242 } 243 244 /** 245 * Get the basis eigenvectors. Each of column vector of the returned 246 * matrix is an eigenvector. 247 * 248 * Syntactic sugar for {@link #getBasis()} 249 * @return the eigenvectors 250 */ 251 public Matrix getEigenVectors() { 252 return eigenvectors; 253 } 254 255 /** 256 * @return the eigen values corresponding to the principal components 257 */ 258 public double [] getEigenValues() { 259 return eigenvalues; 260 } 261 262 /** 263 * Get the eigen value corresponding to the ith principal component. 264 * @param i the index of the component 265 * @return the eigen value corresponding to the principal component 266 */ 267 public double getEigenValue(int i) { 268 return eigenvalues[i]; 269 } 270 271 /** 272 * @return The mean values 273 */ 274 public double[] getMean() { 275 return mean; 276 } 277 278 /** 279 * Generate a new "observation" as a linear combination of 280 * the eigenvectors (ev): mean + ev * scaling. 281 * <p> 282 * If the scaling vector is shorter than the number of 283 * components, it will be zero-padded. If it is longer, 284 * it will be truncated. 285 * 286 * @param scalings the weighting for each eigenvector 287 * @return generated observation 288 */ 289 public double [] generate(double[] scalings) { 290 Matrix scale = new Matrix(this.eigenvalues.length, 1); 291 292 for (int i=0; i<Math.min(eigenvalues.length, scalings.length); i++) 293 scale.set(i, 0, scalings[i]); 294 295 Matrix meanMatrix = new Matrix(new double[][]{mean}).transpose(); 296 297 return meanMatrix.plus(eigenvectors.times(scale)).getColumnPackedCopy(); 298 } 299 300 /** 301 * Project a matrix of row vectors by the basis. 302 * The vectors are normalised by subtracting the mean and 303 * then multiplied by the basis. The returned matrix 304 * has a row for each vector. 305 * @param m the vector to project 306 * @return projected vectors 307 */ 308 public Matrix project(Matrix m) { 309 Matrix vec = m.copy(); 310 311 final int rows = vec.getRowDimension(); 312 final int cols = vec.getColumnDimension(); 313 final double[][] vecarr = vec.getArray(); 314 315 for (int r=0; r<rows; r++) 316 for (int c=0; c<cols; c++) 317 vecarr[r][c] -= mean[c]; 318 319 //T = (Vt.Dt)^T == Dt.Vt 320 return vec.times(eigenvectors); 321 } 322 323 /** 324 * Project a vector by the basis. The vector 325 * is normalised by subtracting the mean and 326 * then multiplied by the basis. 327 * @param vector the vector to project 328 * @return projected vector 329 */ 330 public double[] project(double [] vector) { 331 Matrix vec = new Matrix(1, vector.length); 332 final double[][] vecarr = vec.getArray(); 333 334 for (int i=0; i<vector.length; i++) 335 vecarr[0][i] = vector[i] - mean[i]; 336 337 return vec.times(eigenvectors).getColumnPackedCopy(); 338 } 339}