001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.math.statistics.distribution;
031
import java.util.Random;

import org.openimaj.math.statistics.MeanAndCovariance;

import Jama.CholeskyDecomposition;
import Jama.Matrix;
037
038/**
039 * A single multidimensional Gaussian. This implementation computes the inverse
040 * and Cholesky decomposition of the covariance matrix and caches them for
041 * efficient sampling and probability computation.
042 * 
043 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
044 * 
045 */
046public class CachingMultivariateGaussian extends AbstractMultivariateGaussian implements MultivariateGaussian {
047        protected Matrix covar;
048        protected int N;
049
050        private Matrix inv_covar;
051        private double pdf_const_factor;
052        private Matrix chol;
053
054        protected CachingMultivariateGaussian() {
055        }
056
057        /**
058         * Construct the Gaussian with the provided center and covariance
059         * 
060         * @param mean
061         *            centre of the Gaussian
062         * @param covar
063         *            covariance of the Gaussian
064         */
065        public CachingMultivariateGaussian(Matrix mean, Matrix covar) {
066                N = mean.getColumnDimension();
067                this.mean = mean;
068                this.covar = covar;
069                cacheValues();
070        }
071
072        /**
073         * Construct the Gaussian with the zero mean and unit variance
074         * 
075         * @param ndims
076         *            number of dimensions
077         */
078        public CachingMultivariateGaussian(int ndims) {
079                N = ndims;
080                this.mean = new Matrix(1, N);
081                this.covar = Matrix.identity(N, N);
082                cacheValues();
083        }
084
085        protected void cacheValues() {
086                inv_covar = covar.inverse();
087                pdf_const_factor = 1.0 / Math.sqrt((Math.pow((2 * Math.PI), N) * covar.det()));
088
089                chol = covar.chol().getL();
090        }
091
092        /**
093         * Estimate a multidimensional Gaussian from the data
094         * 
095         * @param samples
096         *            the data
097         * @return the Gaussian with the best fit to the data
098         */
099        public static CachingMultivariateGaussian estimate(float[][] samples) {
100                final int ndims = samples[0].length;
101
102                final CachingMultivariateGaussian gauss = new CachingMultivariateGaussian();
103                gauss.N = ndims;
104
105                final MeanAndCovariance res = new MeanAndCovariance(samples);
106                gauss.mean = res.mean;
107                gauss.covar = res.covar;
108
109                gauss.cacheValues();
110
111                return gauss;
112        }
113
114        /**
115         * Estimate a multidimensional Gaussian from the data
116         * 
117         * @param samples
118         *            the data
119         * @return the Gaussian with the best fit to the data
120         */
121        public static MultivariateGaussian estimate(Matrix samples) {
122                return estimate(samples.getArray());
123        }
124
125        /**
126         * Estimate a multidimensional Gaussian from the data
127         * 
128         * @param samples
129         *            the data
130         * @return the Gaussian with the best fit to the data
131         */
132        public static MultivariateGaussian estimate(double[][] samples) {
133                final int ndims = samples[0].length;
134
135                final CachingMultivariateGaussian gauss = new CachingMultivariateGaussian();
136                gauss.N = ndims;
137
138                final MeanAndCovariance res = new MeanAndCovariance(samples);
139                gauss.mean = res.mean;
140                gauss.covar = res.covar;
141
142                gauss.cacheValues();
143
144                return gauss;
145        }
146
147        /**
148         * Get the probability for a given point in space relative to the PDF
149         * represented by this Gaussian.
150         * 
151         * @param sample
152         *            the point
153         * @return the probability
154         */
155        @Override
156        public double estimateProbability(double[] sample) {
157                final Matrix xm = new Matrix(1, N);
158                for (int i = 0; i < N; i++)
159                        xm.set(0, i, sample[i] - mean.get(0, i));
160
161                final Matrix xmt = xm.transpose();
162
163                final double v = xm.times(inv_covar.times(xmt)).get(0, 0);
164
165                return pdf_const_factor * Math.exp(-0.5 * v);
166        }
167
168        /**
169         * Get the probability for a given point in space relative to the PDF
170         * represented by this Gaussian.
171         * 
172         * @param sample
173         *            the point
174         * @return the probability
175         */
176        public double estimateProbability(Float[] sample) {
177                final Matrix xm = new Matrix(1, N);
178                for (int i = 0; i < N; i++)
179                        xm.set(0, i, sample[i] - mean.get(0, i));
180
181                final Matrix xmt = xm.transpose();
182
183                final double v = xm.times(inv_covar.times(xmt)).get(0, 0);
184
185                return pdf_const_factor * Math.exp(-0.5 * v);
186        }
187
188        @Override
189        public double estimateLogProbability(double[] sample) {
190                final Matrix xm = new Matrix(1, N);
191                for (int i = 0; i < N; i++)
192                        xm.set(0, i, sample[i] - mean.get(0, i));
193
194                final Matrix xmt = xm.transpose();
195
196                final double v = xm.times(inv_covar.times(xmt)).get(0, 0);
197
198                return Math.log(pdf_const_factor) + (-0.5 * v);
199        }
200
201        /*
202         * (non-Javadoc)
203         * 
204         * @see
205         * org.openimaj.math.statistics.distribution.MultivariateGaussian#getCovariance
206         * ()
207         */
208        @Override
209        public Matrix getCovariance() {
210                return covar;
211        }
212
213        /*
214         * (non-Javadoc)
215         * 
216         * @see
217         * org.openimaj.math.statistics.distribution.MultivariateGaussian#numDims()
218         */
219        @Override
220        public int numDims() {
221                return N;
222        }
223
224        @Override
225        public double[] sample(Random rng) {
226                final Matrix vec = new Matrix(N, 1);
227
228                for (int i = 0; i < N; i++)
229                        vec.set(i, 0, rng.nextGaussian());
230
231                final Matrix result = this.mean.plus(chol.times(vec).transpose());
232
233                return result.getArray()[0];
234        }
235
236        @Override
237        public double[][] sample(int count, Random rng) {
238                final double[][] samples = new double[count][];
239
240                for (int i = 0; i < count; i++)
241                        samples[i] = sample(rng);
242
243                return samples;
244        }
245
246        @Override
247        public double getCovariance(int row, int column) {
248                return this.covar.get(row, column);
249        }
250}