001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.math.statistics.distribution; 031 032import java.util.Random; 033 034import org.openimaj.math.statistics.MeanAndCovariance; 035 036import Jama.Matrix; 037 038/** 039 * A single multidimensional Gaussian. This implementation computes the inverse 040 * and Cholesky decomposition of the covariance matrix and caches them for 041 * efficient sampling and probability computation. 042 * 043 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 044 * 045 */ 046public class CachingMultivariateGaussian extends AbstractMultivariateGaussian implements MultivariateGaussian { 047 protected Matrix covar; 048 protected int N; 049 050 private Matrix inv_covar; 051 private double pdf_const_factor; 052 private Matrix chol; 053 054 protected CachingMultivariateGaussian() { 055 } 056 057 /** 058 * Construct the Gaussian with the provided center and covariance 059 * 060 * @param mean 061 * centre of the Gaussian 062 * @param covar 063 * covariance of the Gaussian 064 */ 065 public CachingMultivariateGaussian(Matrix mean, Matrix covar) { 066 N = mean.getColumnDimension(); 067 this.mean = mean; 068 this.covar = covar; 069 cacheValues(); 070 } 071 072 /** 073 * Construct the Gaussian with the zero mean and unit variance 074 * 075 * @param ndims 076 * number of dimensions 077 */ 078 public CachingMultivariateGaussian(int ndims) { 079 N = ndims; 080 this.mean = new Matrix(1, N); 081 this.covar = Matrix.identity(N, N); 082 cacheValues(); 083 } 084 085 protected void cacheValues() { 086 inv_covar = covar.inverse(); 087 pdf_const_factor = 1.0 / Math.sqrt((Math.pow((2 * Math.PI), N) * covar.det())); 088 089 chol = covar.chol().getL(); 090 } 091 092 /** 093 * Estimate a multidimensional Gaussian from the data 094 * 095 * @param samples 096 * the data 097 * @return the Gaussian with the best fit to the data 098 */ 099 public static CachingMultivariateGaussian estimate(float[][] samples) { 100 final int ndims = samples[0].length; 101 102 final CachingMultivariateGaussian gauss = new CachingMultivariateGaussian(); 103 gauss.N = ndims; 104 105 final MeanAndCovariance res = new MeanAndCovariance(samples); 106 gauss.mean = res.mean; 107 gauss.covar = res.covar; 108 109 gauss.cacheValues(); 110 111 return gauss; 112 } 113 114 /** 115 * Estimate a multidimensional Gaussian from the data 116 * 117 * @param samples 118 * the data 119 * @return the Gaussian with the best fit to the data 120 */ 121 public static MultivariateGaussian estimate(Matrix samples) { 122 return estimate(samples.getArray()); 123 } 124 125 /** 126 * Estimate a multidimensional Gaussian from the data 127 * 128 * @param samples 129 * the data 130 * @return the Gaussian with the best fit to the data 131 */ 132 public static MultivariateGaussian estimate(double[][] samples) { 133 final int ndims = samples[0].length; 134 135 final CachingMultivariateGaussian gauss = new CachingMultivariateGaussian(); 136 gauss.N = ndims; 137 138 final MeanAndCovariance res = new MeanAndCovariance(samples); 139 gauss.mean = res.mean; 140 gauss.covar = res.covar; 141 142 gauss.cacheValues(); 143 144 return gauss; 145 } 146 147 /** 148 * Get the probability for a given point in space relative to the PDF 149 * represented by this Gaussian. 150 * 151 * @param sample 152 * the point 153 * @return the probability 154 */ 155 @Override 156 public double estimateProbability(double[] sample) { 157 final Matrix xm = new Matrix(1, N); 158 for (int i = 0; i < N; i++) 159 xm.set(0, i, sample[i] - mean.get(0, i)); 160 161 final Matrix xmt = xm.transpose(); 162 163 final double v = xm.times(inv_covar.times(xmt)).get(0, 0); 164 165 return pdf_const_factor * Math.exp(-0.5 * v); 166 } 167 168 /** 169 * Get the probability for a given point in space relative to the PDF 170 * represented by this Gaussian. 171 * 172 * @param sample 173 * the point 174 * @return the probability 175 */ 176 public double estimateProbability(Float[] sample) { 177 final Matrix xm = new Matrix(1, N); 178 for (int i = 0; i < N; i++) 179 xm.set(0, i, sample[i] - mean.get(0, i)); 180 181 final Matrix xmt = xm.transpose(); 182 183 final double v = xm.times(inv_covar.times(xmt)).get(0, 0); 184 185 return pdf_const_factor * Math.exp(-0.5 * v); 186 } 187 188 @Override 189 public double estimateLogProbability(double[] sample) { 190 final Matrix xm = new Matrix(1, N); 191 for (int i = 0; i < N; i++) 192 xm.set(0, i, sample[i] - mean.get(0, i)); 193 194 final Matrix xmt = xm.transpose(); 195 196 final double v = xm.times(inv_covar.times(xmt)).get(0, 0); 197 198 return Math.log(pdf_const_factor) + (-0.5 * v); 199 } 200 201 /* 202 * (non-Javadoc) 203 * 204 * @see 205 * org.openimaj.math.statistics.distribution.MultivariateGaussian#getCovariance 206 * () 207 */ 208 @Override 209 public Matrix getCovariance() { 210 return covar; 211 } 212 213 /* 214 * (non-Javadoc) 215 * 216 * @see 217 * org.openimaj.math.statistics.distribution.MultivariateGaussian#numDims() 218 */ 219 @Override 220 public int numDims() { 221 return N; 222 } 223 224 @Override 225 public double[] sample(Random rng) { 226 final Matrix vec = new Matrix(N, 1); 227 228 for (int i = 0; i < N; i++) 229 vec.set(i, 0, rng.nextGaussian()); 230 231 final Matrix result = this.mean.plus(chol.times(vec).transpose()); 232 233 return result.getArray()[0]; 234 } 235 236 @Override 237 public double[][] sample(int count, Random rng) { 238 final double[][] samples = new double[count][]; 239 240 for (int i = 0; i < count; i++) 241 samples[i] = sample(rng); 242 243 return samples; 244 } 245 246 @Override 247 public double getCovariance(int row, int column) { 248 return this.covar.get(row, column); 249 } 250}