View Javadoc

1   /**
2    * Copyright (c) 2011, The University of Southampton and the individual contributors.
3    * All rights reserved.
4    *
5    * Redistribution and use in source and binary forms, with or without modification,
6    * are permitted provided that the following conditions are met:
7    *
8    *   * 	Redistributions of source code must retain the above copyright notice,
9    * 	this list of conditions and the following disclaimer.
10   *
11   *   *	Redistributions in binary form must reproduce the above copyright notice,
12   * 	this list of conditions and the following disclaimer in the documentation
13   * 	and/or other materials provided with the distribution.
14   *
15   *   *	Neither the name of the University of Southampton nor the names of its
16   * 	contributors may be used to endorse or promote products derived from this
17   * 	software without specific prior written permission.
18   *
19   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21   * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22   * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
23   * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24   * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26   * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27   * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28   * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29   */
30  package org.openimaj.ml.gmm;
31  
32  import java.util.Arrays;
33  import java.util.EnumSet;
34  
35  import org.apache.commons.math.util.MathUtils;
36  import org.openimaj.math.matrix.MatrixUtils;
37  import org.openimaj.math.statistics.MeanAndCovariance;
38  import org.openimaj.math.statistics.distribution.AbstractMultivariateGaussian;
39  import org.openimaj.math.statistics.distribution.DiagonalMultivariateGaussian;
40  import org.openimaj.math.statistics.distribution.FullMultivariateGaussian;
41  import org.openimaj.math.statistics.distribution.MixtureOfGaussians;
42  import org.openimaj.math.statistics.distribution.MultivariateGaussian;
43  import org.openimaj.math.statistics.distribution.SphericalMultivariateGaussian;
44  import org.openimaj.ml.clustering.DoubleCentroidsResult;
45  import org.openimaj.ml.clustering.kmeans.DoubleKMeans;
46  import org.openimaj.util.array.ArrayUtils;
47  import org.openimaj.util.pair.IndependentPair;
48  
49  import Jama.Matrix;
50  import gnu.trove.list.array.TDoubleArrayList;
51  
52  /**
53   * Gaussian mixture model learning using the EM algorithm. Supports a range of
54   * different shapes Gaussian through different covariance matrix forms. An
55   * initialisation step is used to learn the initial means using K-Means,
56   * although this can be disabled in the constructor.
57   * <p>
58   * Implementation was originally inspired by the SciPy's "gmm.py".
59   *
60   * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
61   */
62  public class GaussianMixtureModelEM {
63  	/**
64  	 * Different forms of covariance matrix supported by the
65  	 * {@link GaussianMixtureModelEM}.
66  	 *
67  	 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
68  	 */
69  	public static enum CovarianceType {
70  		/**
71  		 * Spherical Gaussians: variance is the same along all axes and zero
72  		 * across-axes.
73  		 */
74  		Spherical {
75  			@Override
76  			protected void setCovariances(MultivariateGaussian[] gaussians, Matrix cv) {
77  				double mean = 0;
78  
79  				for (int i = 0; i < cv.getRowDimension(); i++)
80  					for (int j = 0; j < cv.getColumnDimension(); j++)
81  						mean += cv.get(i, j);
82  				mean /= (cv.getColumnDimension() * cv.getRowDimension());
83  
84  				for (final MultivariateGaussian mg : gaussians) {
85  					((SphericalMultivariateGaussian) mg).variance = mean;
86  				}
87  			}
88  
89  			@Override
90  			protected MultivariateGaussian[] createGaussians(int ngauss, int ndims) {
91  				final MultivariateGaussian[] arr = new MultivariateGaussian[ngauss];
92  				for (int i = 0; i < ngauss; i++) {
93  					arr[i] = new SphericalMultivariateGaussian(ndims);
94  				}
95  
96  				return arr;
97  			}
98  
99  			@Override
100 			protected void mstep(EMGMM gmm, GaussianMixtureModelEM learner, Matrix X, Matrix responsibilities,
101 					Matrix weightedXsum,
102 					double[] norm)
103 		{
104 				final Matrix avgX2uw = responsibilities.transpose().times(X.arrayTimes(X));
105 
106 				for (int i = 0; i < gmm.gaussians.length; i++) {
107 					final Matrix weightedXsumi = new Matrix(new double[][] { weightedXsum.getArray()[i] });
108 					final Matrix avgX2uwi = new Matrix(new double[][] { avgX2uw.getArray()[i] });
109 
110 					final Matrix avgX2 = avgX2uwi.times(norm[i]);
111 					final Matrix mu = ((AbstractMultivariateGaussian) gmm.gaussians[i]).mean;
112 					final Matrix avgMeans2 = MatrixUtils.pow(mu, 2);
113 					final Matrix avgXmeans = mu.arrayTimes(weightedXsumi).times(norm[i]);
114 					final Matrix covar = MatrixUtils.plus(avgX2.minus(avgXmeans.times(2)).plus(avgMeans2),
115 							learner.minCovar);
116 
117 					((SphericalMultivariateGaussian) gmm.gaussians[i]).variance = MatrixUtils.sum(covar)
118 							/ X.getColumnDimension();
119 				}
120 			}
121 		},
122 		/**
123 		 * Gaussians with diagonal covariance matrices.
124 		 */
125 		Diagonal {
126 			@Override
127 			protected void setCovariances(MultivariateGaussian[] gaussians, Matrix cv) {
128 				for (final MultivariateGaussian mg : gaussians) {
129 					((DiagonalMultivariateGaussian) mg).variance = MatrixUtils.diagVector(cv);
130 				}
131 			}
132 
133 			@Override
134 			protected MultivariateGaussian[] createGaussians(int ngauss, int ndims) {
135 				final MultivariateGaussian[] arr = new MultivariateGaussian[ngauss];
136 				for (int i = 0; i < ngauss; i++) {
137 					arr[i] = new DiagonalMultivariateGaussian(ndims);
138 				}
139 
140 				return arr;
141 			}
142 
143 			@Override
144 			protected void mstep(EMGMM gmm, GaussianMixtureModelEM learner, Matrix X, Matrix responsibilities,
145 					Matrix weightedXsum,
146 					double[] norm)
147 		{
148 				final Matrix avgX2uw = responsibilities.transpose().times(X.arrayTimes(X));
149 
150 				for (int i = 0; i < gmm.gaussians.length; i++) {
151 					final Matrix weightedXsumi = new Matrix(new double[][] { weightedXsum.getArray()[i] });
152 					final Matrix avgX2uwi = new Matrix(new double[][] { avgX2uw.getArray()[i] });
153 
154 					final Matrix avgX2 = avgX2uwi.times(norm[i]);
155 					final Matrix mu = ((AbstractMultivariateGaussian) gmm.gaussians[i]).mean;
156 					final Matrix avgMeans2 = MatrixUtils.pow(mu, 2);
157 					final Matrix avgXmeans = mu.arrayTimes(weightedXsumi).times(norm[i]);
158 
159 					final Matrix covar = MatrixUtils.plus(avgX2.minus(avgXmeans.times(2)).plus(avgMeans2),
160 							learner.minCovar);
161 
162 					((DiagonalMultivariateGaussian) gmm.gaussians[i]).variance = covar.getArray()[0];
163 				}
164 			}
165 		},
166 		/**
167 		 * Gaussians with full covariance
168 		 */
169 		Full {
170 			@Override
171 			protected MultivariateGaussian[] createGaussians(int ngauss, int ndims) {
172 				final MultivariateGaussian[] arr = new MultivariateGaussian[ngauss];
173 				for (int i = 0; i < ngauss; i++) {
174 					arr[i] = new FullMultivariateGaussian(ndims);
175 				}
176 
177 				return arr;
178 			}
179 
180 			@Override
181 			protected void setCovariances(MultivariateGaussian[] gaussians, Matrix cv) {
182 				for (final MultivariateGaussian mg : gaussians) {
183 					((FullMultivariateGaussian) mg).covar = cv.copy();
184 				}
185 			}
186 
187 			@Override
188 			protected void mstep(EMGMM gmm, GaussianMixtureModelEM learner, Matrix X, Matrix responsibilities,
189 					Matrix weightedXsum,
190 					double[] norm)
191 		{
192 				// Eq. 12 from K. Murphy,
193 				// "Fitting a Conditional Linear Gaussian Distribution"
194 				final int nfeatures = X.getColumnDimension();
195 				for (int c = 0; c < learner.nComponents; c++) {
196 					final Matrix post = responsibilities.getMatrix(0, X.getRowDimension() - 1, c, c).transpose();
197 
198 					final double factor = 1.0 / (ArrayUtils.sumValues(post.getArray()) + 10 * MathUtils.EPSILON);
199 
200 					final Matrix pXt = X.transpose();
201 					for (int i = 0; i < pXt.getRowDimension(); i++)
202 						for (int j = 0; j < pXt.getColumnDimension(); j++)
203 							pXt.set(i, j, pXt.get(i, j) * post.get(0, j));
204 
205 					final Matrix argcv = pXt.times(X).times(factor);
206 					final Matrix mu = ((FullMultivariateGaussian) gmm.gaussians[c]).mean;
207 
208 					((FullMultivariateGaussian) gmm.gaussians[c]).covar = argcv.minusEquals(mu.transpose().times(mu))
209 							.plusEquals(Matrix.identity(nfeatures, nfeatures).times(learner.minCovar));
210 				}
211 			}
212 		},
213 		/**
214 		 * Gaussians with a tied covariance matrix; the same covariance matrix
215 		 * is shared by all the gaussians.
216 		 */
217 		Tied {
218 			// @Override
219 			// protected double[][] logProbability(double[][] x,
220 			// MultivariateGaussian[] gaussians)
221 			// {
222 			// final int ndim = x[0].length;
223 			// final int nmix = gaussians.length;
224 			// final int nsamples = x.length;
225 			// final Matrix X = new Matrix(x);
226 			//
227 			// final double[][] logProb = new double[nsamples][nmix];
228 			// final Matrix cv = ((FullMultivariateGaussian)
229 			// gaussians[0]).covar;
230 			//
231 			// final CholeskyDecomposition chol = cv.chol();
232 			// Matrix cvChol;
233 			// if (chol.isSPD()) {
234 			// cvChol = chol.getL();
235 			// } else {
236 			// // covar probably doesn't have enough samples, so
237 			// // recondition it
238 			// final Matrix m = cv.plus(Matrix.identity(ndim, ndim).timesEquals(
239 			// MixtureOfGaussians.MIN_COVAR_RECONDITION));
240 			// cvChol = m.chol().getL();
241 			// }
242 			//
243 			// double cvLogDet = 0;
244 			// final double[][] cvCholD = cvChol.getArray();
245 			// for (int j = 0; j < ndim; j++) {
246 			// cvLogDet += Math.log(cvCholD[j][j]);
247 			// }
248 			// cvLogDet *= 2;
249 			//
250 			// for (int i = 0; i < nmix; i++) {
251 			// final Matrix mu = ((FullMultivariateGaussian) gaussians[i]).mean;
252 			// final Matrix cvSol = cvChol.solve(MatrixUtils.minusRow(X,
253 			// mu.getArray()[0]).transpose())
254 			// .transpose();
255 			// for (int k = 0; k < nsamples; k++) {
256 			// double sum = 0;
257 			// for (int j = 0; j < ndim; j++) {
258 			// sum += cvSol.get(k, j) * cvSol.get(k, j);
259 			// }
260 			//
261 			// logProb[k][i] = -0.5 * (sum + cvLogDet + ndim * Math.log(2 *
262 			// Math.PI));
263 			// }
264 			// }
265 			//
266 			// return logProb;
267 			// }
268 
269 			@Override
270 			protected void setCovariances(MultivariateGaussian[] gaussians,
271 					Matrix cv)
272 		{
273 				for (final MultivariateGaussian mg : gaussians) {
274 					((FullMultivariateGaussian) mg).covar = cv;
275 				}
276 			}
277 
278 			@Override
279 			protected MultivariateGaussian[] createGaussians(int ngauss, int ndims) {
280 				final MultivariateGaussian[] arr = new MultivariateGaussian[ngauss];
281 				final Matrix covar = new Matrix(ndims, ndims);
282 
283 				for (int i = 0; i < ngauss; i++) {
284 					arr[i] = new FullMultivariateGaussian(new Matrix(1, ndims), covar);
285 				}
286 
287 				return arr;
288 			}
289 
290 			@Override
291 			protected void mstep(EMGMM gmm, GaussianMixtureModelEM learner, Matrix X, Matrix responsibilities,
292 					Matrix weightedXsum, double[] norm)
293 		{
294 				// Eq. 15 from K. Murphy, "Fitting a Conditional Linear Gaussian
295 				final int nfeatures = X.getColumnDimension();
296 
297 				final Matrix avgX2 = X.transpose().times(X);
298 				final double[][] mudata = new double[gmm.gaussians.length][];
299 				for (int i = 0; i < mudata.length; i++)
300 					mudata[i] = ((FullMultivariateGaussian) gmm.gaussians[i]).mean.getArray()[0];
301 				final Matrix mu = new Matrix(mudata);
302 
303 				final Matrix avgMeans2 = mu.transpose().times(weightedXsum);
304 				final Matrix covar = avgX2.minus(avgMeans2)
305 						.plus(Matrix.identity(nfeatures, nfeatures).times(learner.minCovar))
306 						.times(1.0 / X.getRowDimension());
307 
308 				for (int i = 0; i < learner.nComponents; i++)
309 					((FullMultivariateGaussian) gmm.gaussians[i]).covar = covar;
310 			}
311 		};
312 
313 		protected abstract MultivariateGaussian[] createGaussians(int ngauss, int ndims);
314 
315 		protected abstract void setCovariances(MultivariateGaussian[] gaussians, Matrix cv);
316 
317 		/**
318 		 * Mode specific maximisation-step. Implementors should use the state to
319 		 * update the covariance of each of the
320 		 * {@link GaussianMixtureModelEM#gaussians}.
321 		 *
322 		 * @param gmm
323 		 *            the mixture model being learned
324 		 * @param X
325 		 *            the data
326 		 * @param responsibilities
327 		 *            matrix with the same number of rows as X where each col is
328 		 *            the amount that the data point belongs to each gaussian
329 		 * @param weightedXsum
330 		 *            responsibilities.T * X
331 		 * @param inverseWeights
332 		 *            1/weights
333 		 */
334 		protected abstract void mstep(EMGMM gmm, GaussianMixtureModelEM learner, Matrix X,
335 				Matrix responsibilities, Matrix weightedXsum, double[] inverseWeights);
336 	}
337 
338 	/**
339 	 * Options for controlling what gets updated during the initialisation
340 	 * and/or iterations.
341 	 *
342 	 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
343 	 */
344 	public static enum UpdateOptions {
345 		/**
346 		 * Update the means
347 		 */
348 		Means,
349 		/**
350 		 * Update the weights
351 		 */
352 		Weights,
353 		/**
354 		 * Update the covariances
355 		 */
356 		Covariances
357 	}
358 
359 	protected static class EMGMM extends MixtureOfGaussians {
360 		EMGMM(int nComponents) {
361 			super(null, null);
362 
363 			this.weights = new double[nComponents];
364 			Arrays.fill(this.weights, 1.0 / nComponents);
365 		}
366 	}
367 
368 	private static final double DEFAULT_THRESH = 1e-2;
369 	private static final double DEFAULT_MIN_COVAR = 1e-3;
370 	private static final int DEFAULT_NITERS = 100;
371 	private static final int DEFAULT_NINIT = 1;
372 
373 	CovarianceType ctype;
374 	int nComponents;
375 	private double thresh;
376 	private double minCovar;
377 	private int nIters;
378 	private int nInit;
379 
380 	private boolean converged = false;
381 	private EnumSet<UpdateOptions> initOpts;
382 	private EnumSet<UpdateOptions> iterOpts;
383 
384 	/**
385 	 * Construct with the given arguments.
386 	 *
387 	 * @param nComponents
388 	 *            the number of gaussian components
389 	 * @param ctype
390 	 *            the form of the covariance matrices
391 	 * @param thresh
392 	 *            the threshold at which to stop iterating
393 	 * @param minCovar
394 	 *            the minimum value allowed in the diagonal of the estimated
395 	 *            covariance matrices to prevent overfitting
396 	 * @param nIters
397 	 *            the maximum number of iterations
398 	 * @param nInit
399 	 *            the number of runs of the algorithm to perform; the best
400 	 *            result will be kept.
401 	 * @param iterOpts
402 	 *            options controlling what is updated during iteration
403 	 * @param initOpts
404 	 *            options controlling what is updated during initialisation.
405 	 *            Enabling the {@link UpdateOptions#Means} option will cause
406 	 *            K-Means to be used to generate initial starting points for the
407 	 *            means.
408 	 */
409 	public GaussianMixtureModelEM(int nComponents, CovarianceType ctype, double thresh, double minCovar,
410 			int nIters, int nInit, EnumSet<UpdateOptions> iterOpts, EnumSet<UpdateOptions> initOpts)
411 	{
412 		this.ctype = ctype;
413 		this.nComponents = nComponents;
414 		this.thresh = thresh;
415 		this.minCovar = minCovar;
416 		this.nIters = nIters;
417 		this.nInit = nInit;
418 		this.iterOpts = iterOpts;
419 		this.initOpts = initOpts;
420 
421 		if (nInit < 1) {
422 			throw new IllegalArgumentException("GMM estimation requires at least one run");
423 		}
424 		this.converged = false;
425 	}
426 
427 	/**
428 	 * Construct with the given arguments.
429 	 *
430 	 * @param nComponents
431 	 *            the number of gaussian components
432 	 * @param ctype
433 	 *            the form of the covariance matrices
434 	 */
435 	public GaussianMixtureModelEM(int nComponents, CovarianceType ctype) {
436 		this(nComponents, ctype, DEFAULT_THRESH, DEFAULT_MIN_COVAR, DEFAULT_NITERS, DEFAULT_NINIT, EnumSet
437 				.allOf(UpdateOptions.class), EnumSet.allOf(UpdateOptions.class));
438 	}
439 
440 	/**
441 	 * Get's the convergence state of the algorithm. Will return false if
442 	 * {@link #estimate(double[][])} has not been called, or if the last call to
443 	 * {@link #estimate(double[][])} failed to reach convergence before running
444 	 * out of iterations.
445 	 *
446 	 * @return true if the last call to {@link #estimate(double[][])} reached
447 	 *         convergence; false otherwise
448 	 */
449 	public boolean hasConverged() {
450 		return converged;
451 	}
452 
453 	/**
454 	 * Estimate a new {@link MixtureOfGaussians} from the given data. Use
455 	 * {@link #hasConverged()} to check whether the EM algorithm reached
456 	 * convergence in the estimation of the returned model.
457 	 *
458 	 * @param X
459 	 *            the data matrix.
460 	 * @return the generated GMM.
461 	 */
462 	public MixtureOfGaussians estimate(Matrix X) {
463 		return estimate(X.getArray());
464 	}
465 
466 	/**
467 	 * Estimate a new {@link MixtureOfGaussians} from the given data. Use
468 	 * {@link #hasConverged()} to check whether the EM algorithm reached
469 	 * convergence in the estimation of the returned model.
470 	 *
471 	 * @param X
472 	 *            the data array.
473 	 * @return the generated GMM.
474 	 */
475 	public MixtureOfGaussians estimate(double[][] X) {
476 		final EMGMM gmm = new EMGMM(nComponents);
477 
478 		if (X.length < nComponents)
479 			throw new IllegalArgumentException(String.format(
480 					"GMM estimation with %d components, but got only %d samples", nComponents, X.length));
481 
482 		double max_log_prob = Double.NEGATIVE_INFINITY;
483 
484 		for (int j = 0; j < nInit; j++) {
485 			gmm.gaussians = ctype.createGaussians(nComponents, X[0].length);
486 
487 			if (initOpts.contains(UpdateOptions.Means)) {
488 				// initialise using k-means
489 				final DoubleKMeans km = DoubleKMeans.createExact(nComponents);
490 				final DoubleCentroidsResult means = km.cluster(X);
491 
492 				for (int i = 0; i < nComponents; i++) {
493 					((AbstractMultivariateGaussian) gmm.gaussians[i]).mean.getArray()[0] = means.centroids[i];
494 				}
495 			}
496 
497 			if (initOpts.contains(UpdateOptions.Weights)) {
498 				gmm.weights = new double[nComponents];
499 				Arrays.fill(gmm.weights, 1.0 / nComponents);
500 			}
501 
502 			if (initOpts.contains(UpdateOptions.Covariances)) {
503 				// cv = np.cov(X.T) + self.min_covar * np.eye(X.shape[1])
504 				final Matrix cv = MeanAndCovariance.computeCovariance(X);
505 
506 				ctype.setCovariances(gmm.gaussians, cv);
507 			}
508 
509 			// EM algorithm
510 			final TDoubleArrayList log_likelihood = new TDoubleArrayList();
511 
512 			// reset converged to false
513 			converged = false;
514 			double[] bestWeights = null;
515 			MultivariateGaussian[] bestMixture = null;
516 			for (int i = 0; i < nIters; i++) {
517 				// Expectation step
518 				final IndependentPair<double[], double[][]> score = gmm.scoreSamples(X);
519 				final double[] curr_log_likelihood = score.firstObject();
520 				final double[][] responsibilities = score.secondObject();
521 				log_likelihood.add(ArrayUtils.sumValues(curr_log_likelihood));
522 
523 				// Check for convergence.
524 				if (i > 0 && Math.abs(log_likelihood.get(i) - log_likelihood.get(i - 1)) < thresh) {
525 					converged = true;
526 					break;
527 				}
528 
529 				// Perform the maximisation step
530 				mstep(gmm, X, responsibilities);
531 
532 				// if the results are better, keep it
533 				if (nIters > 0) {
534 					if (log_likelihood.getQuick(i) > max_log_prob) {
535 						max_log_prob = log_likelihood.getQuick(i);
536 						bestWeights = gmm.weights;
537 						bestMixture = gmm.gaussians;
538 					}
539 				}
540 
541 				// check the existence of an init param that was not subject to
542 				// likelihood computation issue.
543 				if (Double.isInfinite(max_log_prob) && nIters > 0) {
544 					throw new RuntimeException(
545 							"EM algorithm was never able to compute a valid likelihood given initial " +
546 									"parameters. Try different init parameters (or increasing n_init) or " +
547 									"check for degenerate data.");
548 				}
549 
550 				if (nIters > 0) {
551 					gmm.gaussians = bestMixture;
552 					gmm.weights = bestWeights;
553 				}
554 			}
555 		}
556 
557 		return gmm;
558 	}
559 
560 	protected void mstep(EMGMM gmm, double[][] X, double[][] responsibilities) {
561 		final double[] weights = ArrayUtils.colSum(responsibilities);
562 		final Matrix resMat = new Matrix(responsibilities);
563 		final Matrix Xmat = new Matrix(X);
564 
565 		final Matrix weighted_X_sum = resMat.transpose().times(Xmat);
566 		final double[] inverse_weights = new double[weights.length];
567 		for (int i = 0; i < inverse_weights.length; i++)
568 			inverse_weights[i] = 1.0 / (weights[i] + 10 * MathUtils.EPSILON);
569 
570 		if (iterOpts.contains(UpdateOptions.Weights)) {
571 			final double sum = ArrayUtils.sumValues(weights);
572 			for (int i = 0; i < weights.length; i++) {
573 				gmm.weights[i] = (weights[i] / (sum + 10 * MathUtils.EPSILON) + MathUtils.EPSILON);
574 			}
575 		}
576 
577 		if (iterOpts.contains(UpdateOptions.Means)) {
578 			// self.means_ = weighted_X_sum * inverse_weights
579 			final double[][] wx = weighted_X_sum.getArray();
580 
581 			for (int i = 0; i < nComponents; i++) {
582 				final double[][] m = ((AbstractMultivariateGaussian) gmm.gaussians[i]).mean.getArray();
583 
584 				for (int j = 0; j < m[0].length; j++) {
585 					m[0][j] = wx[i][j] * inverse_weights[i];
586 				}
587 			}
588 		}
589 
590 		if (iterOpts.contains(UpdateOptions.Covariances)) {
591 			ctype.mstep(gmm, this, Xmat, resMat, weighted_X_sum, inverse_weights);
592 		}
593 	}
594 
595 	@Override
596 	public GaussianMixtureModelEM clone() {
597 		try {
598 			return (GaussianMixtureModelEM) super.clone();
599 		} catch (final CloneNotSupportedException e) {
600 			throw new RuntimeException(e);
601 		}
602 	}
603 }