001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.math.statistics.distribution; 031 032import java.util.Random; 033 034import org.openimaj.math.matrix.MatrixUtils; 035 036import Jama.Matrix; 037 038/** 039 * Abstract base class for {@link MultivariateGaussian} implementations 040 * 041 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 042 * 043 */ 044public abstract class AbstractMultivariateGaussian implements MultivariateGaussian { 045 /** 046 * The mean vector 047 */ 048 public Matrix mean; 049 050 @Override 051 public Matrix getMean() { 052 return mean; 053 } 054 055 @Override 056 public double[] sample(Random rng) { 057 final int N = mean.getColumnDimension(); 058 final Matrix chol = getCovariance().chol().getL(); 059 final Matrix vec = new Matrix(N, 1); 060 061 for (int i = 0; i < N; i++) 062 vec.set(i, 0, rng.nextGaussian()); 063 064 final Matrix result = this.mean.plus(chol.times(vec).transpose()); 065 066 return result.getArray()[0]; 067 } 068 069 @Override 070 public double[][] sample(int nsamples, Random rng) { 071 if (nsamples == 0) 072 return new double[0][0]; 073 074 final int N = mean.getColumnDimension(); 075 final Matrix chol = getCovariance().chol().getL(); 076 final Matrix vec = new Matrix(N, nsamples); 077 078 for (int i = 0; i < N; i++) 079 for (int j = 0; j < nsamples; j++) 080 vec.set(i, j, rng.nextGaussian()); 081 082 final Matrix result = chol.times(vec).transpose(); 083 for (int i = 0; i < result.getRowDimension(); i++) 084 for (int j = 0; j < result.getColumnDimension(); j++) 085 result.set(i, j, result.get(i, j) + mean.get(0, j)); 086 087 return result.getArray(); 088 } 089 090 @Override 091 public int numDims() { 092 return mean.getColumnDimension(); 093 } 094 095 @Override 096 public double estimateProbability(double[] sample) { 097 final int N = mean.getColumnDimension(); 098 final Matrix inv_covar = getCovariance().inverse(); 099 final double pdf_const_factor = 1.0 / Math.sqrt((Math.pow((2 * Math.PI), N) * getCovariance().det())); 100 101 final Matrix xm = new Matrix(1, N); 102 for (int i = 0; i < N; i++) 103 xm.set(0, i, sample[i] - mean.get(0, i)); 104 105 final Matrix xmt = xm.transpose(); 106 final double v = xm.times(inv_covar.times(xmt)).get(0, 0); 107 108 return pdf_const_factor * Math.exp(-0.5 * v); 109 } 110 111 @Override 112 public double estimateLogProbability(double[] sample) { 113 final int N = mean.getColumnDimension(); 114 final Matrix inv_covar = getCovariance().inverse(); 115 final double cov_det = getCovariance().det(); 116 final double pdf_const_factor = 1.0 / Math.sqrt((Math.pow((2 * Math.PI), N) * cov_det)); 117 118 final Matrix xm = new Matrix(1, N); 119 for (int i = 0; i < N; i++) 120 xm.set(0, i, sample[i] - mean.get(0, i)); 121 122 final Matrix xmt = xm.transpose(); 123 final double v = xm.times(inv_covar.times(xmt)).get(0, 0); 124 125 return Math.log(pdf_const_factor) + (-0.5 * v); 126 } 127 128 @Override 129 public String toString() { 130 // only pretty print with low dimensionality 131 if (this.numDims() < 5) 132 return String.format("MultivariateGaussian[mean=%s,covar=%s]", MatrixUtils.toMatlabString(mean).trim(), 133 MatrixUtils.toMatlabString(getCovariance())); 134 return super.toString(); 135 } 136 137 @Override 138 public double[] estimateLogProbability(double[][] x) { 139 final double[] lps = new double[x.length]; 140 for (int i = 0; i < x.length; i++) 141 lps[i] = estimateLogProbability(x[i]); 142 return lps; 143 } 144}