001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.ml.gmm; 031 032import java.util.Arrays; 033import java.util.EnumSet; 034 035import org.apache.commons.math.util.MathUtils; 036import org.openimaj.math.matrix.MatrixUtils; 037import org.openimaj.math.statistics.MeanAndCovariance; 038import org.openimaj.math.statistics.distribution.AbstractMultivariateGaussian; 039import org.openimaj.math.statistics.distribution.DiagonalMultivariateGaussian; 040import org.openimaj.math.statistics.distribution.FullMultivariateGaussian; 041import org.openimaj.math.statistics.distribution.MixtureOfGaussians; 042import org.openimaj.math.statistics.distribution.MultivariateGaussian; 043import org.openimaj.math.statistics.distribution.SphericalMultivariateGaussian; 044import org.openimaj.ml.clustering.DoubleCentroidsResult; 045import org.openimaj.ml.clustering.kmeans.DoubleKMeans; 046import org.openimaj.util.array.ArrayUtils; 047import org.openimaj.util.pair.IndependentPair; 048 049import Jama.Matrix; 050import gnu.trove.list.array.TDoubleArrayList; 051 052/** 053 * Gaussian mixture model learning using the EM algorithm. Supports a range of 054 * different shapes Gaussian through different covariance matrix forms. An 055 * initialisation step is used to learn the initial means using K-Means, 056 * although this can be disabled in the constructor. 057 * <p> 058 * Implementation was originally inspired by the SciPy's "gmm.py". 059 * 060 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 061 */ 062public class GaussianMixtureModelEM { 063 /** 064 * Different forms of covariance matrix supported by the 065 * {@link GaussianMixtureModelEM}. 066 * 067 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 068 */ 069 public static enum CovarianceType { 070 /** 071 * Spherical Gaussians: variance is the same along all axes and zero 072 * across-axes. 073 */ 074 Spherical { 075 @Override 076 protected void setCovariances(MultivariateGaussian[] gaussians, Matrix cv) { 077 double mean = 0; 078 079 for (int i = 0; i < cv.getRowDimension(); i++) 080 for (int j = 0; j < cv.getColumnDimension(); j++) 081 mean += cv.get(i, j); 082 mean /= (cv.getColumnDimension() * cv.getRowDimension()); 083 084 for (final MultivariateGaussian mg : gaussians) { 085 ((SphericalMultivariateGaussian) mg).variance = mean; 086 } 087 } 088 089 @Override 090 protected MultivariateGaussian[] createGaussians(int ngauss, int ndims) { 091 final MultivariateGaussian[] arr = new MultivariateGaussian[ngauss]; 092 for (int i = 0; i < ngauss; i++) { 093 arr[i] = new SphericalMultivariateGaussian(ndims); 094 } 095 096 return arr; 097 } 098 099 @Override 100 protected void mstep(EMGMM gmm, GaussianMixtureModelEM learner, Matrix X, Matrix responsibilities, 101 Matrix weightedXsum, 102 double[] norm) 103 { 104 final Matrix avgX2uw = responsibilities.transpose().times(X.arrayTimes(X)); 105 106 for (int i = 0; i < gmm.gaussians.length; i++) { 107 final Matrix weightedXsumi = new Matrix(new double[][] { weightedXsum.getArray()[i] }); 108 final Matrix avgX2uwi = new Matrix(new double[][] { avgX2uw.getArray()[i] }); 109 110 final Matrix avgX2 = avgX2uwi.times(norm[i]); 111 final Matrix mu = ((AbstractMultivariateGaussian) gmm.gaussians[i]).mean; 112 final Matrix avgMeans2 = MatrixUtils.pow(mu, 2); 113 final Matrix avgXmeans = mu.arrayTimes(weightedXsumi).times(norm[i]); 114 final Matrix covar = MatrixUtils.plus(avgX2.minus(avgXmeans.times(2)).plus(avgMeans2), 115 learner.minCovar); 116 117 ((SphericalMultivariateGaussian) gmm.gaussians[i]).variance = MatrixUtils.sum(covar) 118 / X.getColumnDimension(); 119 } 120 } 121 }, 122 /** 123 * Gaussians with diagonal covariance matrices. 124 */ 125 Diagonal { 126 @Override 127 protected void setCovariances(MultivariateGaussian[] gaussians, Matrix cv) { 128 for (final MultivariateGaussian mg : gaussians) { 129 ((DiagonalMultivariateGaussian) mg).variance = MatrixUtils.diagVector(cv); 130 } 131 } 132 133 @Override 134 protected MultivariateGaussian[] createGaussians(int ngauss, int ndims) { 135 final MultivariateGaussian[] arr = new MultivariateGaussian[ngauss]; 136 for (int i = 0; i < ngauss; i++) { 137 arr[i] = new DiagonalMultivariateGaussian(ndims); 138 } 139 140 return arr; 141 } 142 143 @Override 144 protected void mstep(EMGMM gmm, GaussianMixtureModelEM learner, Matrix X, Matrix responsibilities, 145 Matrix weightedXsum, 146 double[] norm) 147 { 148 final Matrix avgX2uw = responsibilities.transpose().times(X.arrayTimes(X)); 149 150 for (int i = 0; i < gmm.gaussians.length; i++) { 151 final Matrix weightedXsumi = new Matrix(new double[][] { weightedXsum.getArray()[i] }); 152 final Matrix avgX2uwi = new Matrix(new double[][] { avgX2uw.getArray()[i] }); 153 154 final Matrix avgX2 = avgX2uwi.times(norm[i]); 155 final Matrix mu = ((AbstractMultivariateGaussian) gmm.gaussians[i]).mean; 156 final Matrix avgMeans2 = MatrixUtils.pow(mu, 2); 157 final Matrix avgXmeans = mu.arrayTimes(weightedXsumi).times(norm[i]); 158 159 final Matrix covar = MatrixUtils.plus(avgX2.minus(avgXmeans.times(2)).plus(avgMeans2), 160 learner.minCovar); 161 162 ((DiagonalMultivariateGaussian) gmm.gaussians[i]).variance = covar.getArray()[0]; 163 } 164 } 165 }, 166 /** 167 * Gaussians with full covariance 168 */ 169 Full { 170 @Override 171 protected MultivariateGaussian[] createGaussians(int ngauss, int ndims) { 172 final MultivariateGaussian[] arr = new MultivariateGaussian[ngauss]; 173 for (int i = 0; i < ngauss; i++) { 174 arr[i] = new FullMultivariateGaussian(ndims); 175 } 176 177 return arr; 178 } 179 180 @Override 181 protected void setCovariances(MultivariateGaussian[] gaussians, Matrix cv) { 182 for (final MultivariateGaussian mg : gaussians) { 183 ((FullMultivariateGaussian) mg).covar = cv.copy(); 184 } 185 } 186 187 @Override 188 protected void mstep(EMGMM gmm, GaussianMixtureModelEM learner, Matrix X, Matrix responsibilities, 189 Matrix weightedXsum, 190 double[] norm) 191 { 192 // Eq. 12 from K. Murphy, 193 // "Fitting a Conditional Linear Gaussian Distribution" 194 final int nfeatures = X.getColumnDimension(); 195 for (int c = 0; c < learner.nComponents; c++) { 196 final Matrix post = responsibilities.getMatrix(0, X.getRowDimension() - 1, c, c).transpose(); 197 198 final double factor = 1.0 / (ArrayUtils.sumValues(post.getArray()) + 10 * MathUtils.EPSILON); 199 200 final Matrix pXt = X.transpose(); 201 for (int i = 0; i < pXt.getRowDimension(); i++) 202 for (int j = 0; j < pXt.getColumnDimension(); j++) 203 pXt.set(i, j, pXt.get(i, j) * post.get(0, j)); 204 205 final Matrix argcv = pXt.times(X).times(factor); 206 final Matrix mu = ((FullMultivariateGaussian) gmm.gaussians[c]).mean; 207 208 ((FullMultivariateGaussian) gmm.gaussians[c]).covar = argcv.minusEquals(mu.transpose().times(mu)) 209 .plusEquals(Matrix.identity(nfeatures, nfeatures).times(learner.minCovar)); 210 } 211 } 212 }, 213 /** 214 * Gaussians with a tied covariance matrix; the same covariance matrix 215 * is shared by all the gaussians. 216 */ 217 Tied { 218 // @Override 219 // protected double[][] logProbability(double[][] x, 220 // MultivariateGaussian[] gaussians) 221 // { 222 // final int ndim = x[0].length; 223 // final int nmix = gaussians.length; 224 // final int nsamples = x.length; 225 // final Matrix X = new Matrix(x); 226 // 227 // final double[][] logProb = new double[nsamples][nmix]; 228 // final Matrix cv = ((FullMultivariateGaussian) 229 // gaussians[0]).covar; 230 // 231 // final CholeskyDecomposition chol = cv.chol(); 232 // Matrix cvChol; 233 // if (chol.isSPD()) { 234 // cvChol = chol.getL(); 235 // } else { 236 // // covar probably doesn't have enough samples, so 237 // // recondition it 238 // final Matrix m = cv.plus(Matrix.identity(ndim, ndim).timesEquals( 239 // MixtureOfGaussians.MIN_COVAR_RECONDITION)); 240 // cvChol = m.chol().getL(); 241 // } 242 // 243 // double cvLogDet = 0; 244 // final double[][] cvCholD = cvChol.getArray(); 245 // for (int j = 0; j < ndim; j++) { 246 // cvLogDet += Math.log(cvCholD[j][j]); 247 // } 248 // cvLogDet *= 2; 249 // 250 // for (int i = 0; i < nmix; i++) { 251 // final Matrix mu = ((FullMultivariateGaussian) gaussians[i]).mean; 252 // final Matrix cvSol = cvChol.solve(MatrixUtils.minusRow(X, 253 // mu.getArray()[0]).transpose()) 254 // .transpose(); 255 // for (int k = 0; k < nsamples; k++) { 256 // double sum = 0; 257 // for (int j = 0; j < ndim; j++) { 258 // sum += cvSol.get(k, j) * cvSol.get(k, j); 259 // } 260 // 261 // logProb[k][i] = -0.5 * (sum + cvLogDet + ndim * Math.log(2 * 262 // Math.PI)); 263 // } 264 // } 265 // 266 // return logProb; 267 // } 268 269 @Override 270 protected void setCovariances(MultivariateGaussian[] gaussians, 271 Matrix cv) 272 { 273 for (final MultivariateGaussian mg : gaussians) { 274 ((FullMultivariateGaussian) mg).covar = cv; 275 } 276 } 277 278 @Override 279 protected MultivariateGaussian[] createGaussians(int ngauss, int ndims) { 280 final MultivariateGaussian[] arr = new MultivariateGaussian[ngauss]; 281 final Matrix covar = new Matrix(ndims, ndims); 282 283 for (int i = 0; i < ngauss; i++) { 284 arr[i] = new FullMultivariateGaussian(new Matrix(1, ndims), covar); 285 } 286 287 return arr; 288 } 289 290 @Override 291 protected void mstep(EMGMM gmm, GaussianMixtureModelEM learner, Matrix X, Matrix responsibilities, 292 Matrix weightedXsum, double[] norm) 293 { 294 // Eq. 15 from K. Murphy, "Fitting a Conditional Linear Gaussian 295 final int nfeatures = X.getColumnDimension(); 296 297 final Matrix avgX2 = X.transpose().times(X); 298 final double[][] mudata = new double[gmm.gaussians.length][]; 299 for (int i = 0; i < mudata.length; i++) 300 mudata[i] = ((FullMultivariateGaussian) gmm.gaussians[i]).mean.getArray()[0]; 301 final Matrix mu = new Matrix(mudata); 302 303 final Matrix avgMeans2 = mu.transpose().times(weightedXsum); 304 final Matrix covar = avgX2.minus(avgMeans2) 305 .plus(Matrix.identity(nfeatures, nfeatures).times(learner.minCovar)) 306 .times(1.0 / X.getRowDimension()); 307 308 for (int i = 0; i < learner.nComponents; i++) 309 ((FullMultivariateGaussian) gmm.gaussians[i]).covar = covar; 310 } 311 }; 312 313 protected abstract MultivariateGaussian[] createGaussians(int ngauss, int ndims); 314 315 protected abstract void setCovariances(MultivariateGaussian[] gaussians, Matrix cv); 316 317 /** 318 * Mode specific maximisation-step. Implementors should use the state to 319 * update the covariance of each of the 320 * {@link GaussianMixtureModelEM#gaussians}. 321 * 322 * @param gmm 323 * the mixture model being learned 324 * @param X 325 * the data 326 * @param responsibilities 327 * matrix with the same number of rows as X where each col is 328 * the amount that the data point belongs to each gaussian 329 * @param weightedXsum 330 * responsibilities.T * X 331 * @param inverseWeights 332 * 1/weights 333 */ 334 protected abstract void mstep(EMGMM gmm, GaussianMixtureModelEM learner, Matrix X, 335 Matrix responsibilities, Matrix weightedXsum, double[] inverseWeights); 336 } 337 338 /** 339 * Options for controlling what gets updated during the initialisation 340 * and/or iterations. 341 * 342 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 343 */ 344 public static enum UpdateOptions { 345 /** 346 * Update the means 347 */ 348 Means, 349 /** 350 * Update the weights 351 */ 352 Weights, 353 /** 354 * Update the covariances 355 */ 356 Covariances 357 } 358 359 protected static class EMGMM extends MixtureOfGaussians { 360 EMGMM(int nComponents) { 361 super(null, null); 362 363 this.weights = new double[nComponents]; 364 Arrays.fill(this.weights, 1.0 / nComponents); 365 } 366 } 367 368 private static final double DEFAULT_THRESH = 1e-2; 369 private static final double DEFAULT_MIN_COVAR = 1e-3; 370 private static final int DEFAULT_NITERS = 100; 371 private static final int DEFAULT_NINIT = 1; 372 373 CovarianceType ctype; 374 int nComponents; 375 private double thresh; 376 private double minCovar; 377 private int nIters; 378 private int nInit; 379 380 private boolean converged = false; 381 private EnumSet<UpdateOptions> initOpts; 382 private EnumSet<UpdateOptions> iterOpts; 383 384 /** 385 * Construct with the given arguments. 386 * 387 * @param nComponents 388 * the number of gaussian components 389 * @param ctype 390 * the form of the covariance matrices 391 * @param thresh 392 * the threshold at which to stop iterating 393 * @param minCovar 394 * the minimum value allowed in the diagonal of the estimated 395 * covariance matrices to prevent overfitting 396 * @param nIters 397 * the maximum number of iterations 398 * @param nInit 399 * the number of runs of the algorithm to perform; the best 400 * result will be kept. 401 * @param iterOpts 402 * options controlling what is updated during iteration 403 * @param initOpts 404 * options controlling what is updated during initialisation. 405 * Enabling the {@link UpdateOptions#Means} option will cause 406 * K-Means to be used to generate initial starting points for the 407 * means. 408 */ 409 public GaussianMixtureModelEM(int nComponents, CovarianceType ctype, double thresh, double minCovar, 410 int nIters, int nInit, EnumSet<UpdateOptions> iterOpts, EnumSet<UpdateOptions> initOpts) 411 { 412 this.ctype = ctype; 413 this.nComponents = nComponents; 414 this.thresh = thresh; 415 this.minCovar = minCovar; 416 this.nIters = nIters; 417 this.nInit = nInit; 418 this.iterOpts = iterOpts; 419 this.initOpts = initOpts; 420 421 if (nInit < 1) { 422 throw new IllegalArgumentException("GMM estimation requires at least one run"); 423 } 424 this.converged = false; 425 } 426 427 /** 428 * Construct with the given arguments. 429 * 430 * @param nComponents 431 * the number of gaussian components 432 * @param ctype 433 * the form of the covariance matrices 434 */ 435 public GaussianMixtureModelEM(int nComponents, CovarianceType ctype) { 436 this(nComponents, ctype, DEFAULT_THRESH, DEFAULT_MIN_COVAR, DEFAULT_NITERS, DEFAULT_NINIT, EnumSet 437 .allOf(UpdateOptions.class), EnumSet.allOf(UpdateOptions.class)); 438 } 439 440 /** 441 * Get's the convergence state of the algorithm. Will return false if 442 * {@link #estimate(double[][])} has not been called, or if the last call to 443 * {@link #estimate(double[][])} failed to reach convergence before running 444 * out of iterations. 445 * 446 * @return true if the last call to {@link #estimate(double[][])} reached 447 * convergence; false otherwise 448 */ 449 public boolean hasConverged() { 450 return converged; 451 } 452 453 /** 454 * Estimate a new {@link MixtureOfGaussians} from the given data. Use 455 * {@link #hasConverged()} to check whether the EM algorithm reached 456 * convergence in the estimation of the returned model. 457 * 458 * @param X 459 * the data matrix. 460 * @return the generated GMM. 461 */ 462 public MixtureOfGaussians estimate(Matrix X) { 463 return estimate(X.getArray()); 464 } 465 466 /** 467 * Estimate a new {@link MixtureOfGaussians} from the given data. Use 468 * {@link #hasConverged()} to check whether the EM algorithm reached 469 * convergence in the estimation of the returned model. 470 * 471 * @param X 472 * the data array. 473 * @return the generated GMM. 474 */ 475 public MixtureOfGaussians estimate(double[][] X) { 476 final EMGMM gmm = new EMGMM(nComponents); 477 478 if (X.length < nComponents) 479 throw new IllegalArgumentException(String.format( 480 "GMM estimation with %d components, but got only %d samples", nComponents, X.length)); 481 482 double max_log_prob = Double.NEGATIVE_INFINITY; 483 484 for (int j = 0; j < nInit; j++) { 485 gmm.gaussians = ctype.createGaussians(nComponents, X[0].length); 486 487 if (initOpts.contains(UpdateOptions.Means)) { 488 // initialise using k-means 489 final DoubleKMeans km = DoubleKMeans.createExact(nComponents); 490 final DoubleCentroidsResult means = km.cluster(X); 491 492 for (int i = 0; i < nComponents; i++) { 493 ((AbstractMultivariateGaussian) gmm.gaussians[i]).mean.getArray()[0] = means.centroids[i]; 494 } 495 } 496 497 if (initOpts.contains(UpdateOptions.Weights)) { 498 gmm.weights = new double[nComponents]; 499 Arrays.fill(gmm.weights, 1.0 / nComponents); 500 } 501 502 if (initOpts.contains(UpdateOptions.Covariances)) { 503 // cv = np.cov(X.T) + self.min_covar * np.eye(X.shape[1]) 504 final Matrix cv = MeanAndCovariance.computeCovariance(X); 505 506 ctype.setCovariances(gmm.gaussians, cv); 507 } 508 509 // EM algorithm 510 final TDoubleArrayList log_likelihood = new TDoubleArrayList(); 511 512 // reset converged to false 513 converged = false; 514 double[] bestWeights = null; 515 MultivariateGaussian[] bestMixture = null; 516 for (int i = 0; i < nIters; i++) { 517 // Expectation step 518 final IndependentPair<double[], double[][]> score = gmm.scoreSamples(X); 519 final double[] curr_log_likelihood = score.firstObject(); 520 final double[][] responsibilities = score.secondObject(); 521 log_likelihood.add(ArrayUtils.sumValues(curr_log_likelihood)); 522 523 // Check for convergence. 524 if (i > 0 && Math.abs(log_likelihood.get(i) - log_likelihood.get(i - 1)) < thresh) { 525 converged = true; 526 break; 527 } 528 529 // Perform the maximisation step 530 mstep(gmm, X, responsibilities); 531 532 // if the results are better, keep it 533 if (nIters > 0) { 534 if (log_likelihood.getQuick(i) > max_log_prob) { 535 max_log_prob = log_likelihood.getQuick(i); 536 bestWeights = gmm.weights; 537 bestMixture = gmm.gaussians; 538 } 539 } 540 541 // check the existence of an init param that was not subject to 542 // likelihood computation issue. 543 if (Double.isInfinite(max_log_prob) && nIters > 0) { 544 throw new RuntimeException( 545 "EM algorithm was never able to compute a valid likelihood given initial " + 546 "parameters. Try different init parameters (or increasing n_init) or " + 547 "check for degenerate data."); 548 } 549 550 if (nIters > 0) { 551 gmm.gaussians = bestMixture; 552 gmm.weights = bestWeights; 553 } 554 } 555 } 556 557 return gmm; 558 } 559 560 protected void mstep(EMGMM gmm, double[][] X, double[][] responsibilities) { 561 final double[] weights = ArrayUtils.colSum(responsibilities); 562 final Matrix resMat = new Matrix(responsibilities); 563 final Matrix Xmat = new Matrix(X); 564 565 final Matrix weighted_X_sum = resMat.transpose().times(Xmat); 566 final double[] inverse_weights = new double[weights.length]; 567 for (int i = 0; i < inverse_weights.length; i++) 568 inverse_weights[i] = 1.0 / (weights[i] + 10 * MathUtils.EPSILON); 569 570 if (iterOpts.contains(UpdateOptions.Weights)) { 571 final double sum = ArrayUtils.sumValues(weights); 572 for (int i = 0; i < weights.length; i++) { 573 gmm.weights[i] = (weights[i] / (sum + 10 * MathUtils.EPSILON) + MathUtils.EPSILON); 574 } 575 } 576 577 if (iterOpts.contains(UpdateOptions.Means)) { 578 // self.means_ = weighted_X_sum * inverse_weights 579 final double[][] wx = weighted_X_sum.getArray(); 580 581 for (int i = 0; i < nComponents; i++) { 582 final double[][] m = ((AbstractMultivariateGaussian) gmm.gaussians[i]).mean.getArray(); 583 584 for (int j = 0; j < m[0].length; j++) { 585 m[0][j] = wx[i][j] * inverse_weights[i]; 586 } 587 } 588 } 589 590 if (iterOpts.contains(UpdateOptions.Covariances)) { 591 ctype.mstep(gmm, this, Xmat, resMat, weighted_X_sum, inverse_weights); 592 } 593 } 594 595 @Override 596 public GaussianMixtureModelEM clone() { 597 try { 598 return (GaussianMixtureModelEM) super.clone(); 599 } catch (final CloneNotSupportedException e) { 600 throw new RuntimeException(e); 601 } 602 } 603}