001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.math.statistics.distribution;
031
032import java.util.Random;
033
034import org.openimaj.math.matrix.MatrixUtils;
035
036import Jama.Matrix;
037
038/**
039 * Abstract base class for {@link MultivariateGaussian} implementations
040 * 
041 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
042 * 
043 */
044public abstract class AbstractMultivariateGaussian implements MultivariateGaussian {
045        /**
046         * The mean vector
047         */
048        public Matrix mean;
049
050        @Override
051        public Matrix getMean() {
052                return mean;
053        }
054
055        @Override
056        public double[] sample(Random rng) {
057                final int N = mean.getColumnDimension();
058                final Matrix chol = getCovariance().chol().getL();
059                final Matrix vec = new Matrix(N, 1);
060
061                for (int i = 0; i < N; i++)
062                        vec.set(i, 0, rng.nextGaussian());
063
064                final Matrix result = this.mean.plus(chol.times(vec).transpose());
065
066                return result.getArray()[0];
067        }
068
069        @Override
070        public double[][] sample(int nsamples, Random rng) {
071                if (nsamples == 0)
072                        return new double[0][0];
073
074                final int N = mean.getColumnDimension();
075                final Matrix chol = getCovariance().chol().getL();
076                final Matrix vec = new Matrix(N, nsamples);
077
078                for (int i = 0; i < N; i++)
079                        for (int j = 0; j < nsamples; j++)
080                                vec.set(i, j, rng.nextGaussian());
081
082                final Matrix result = chol.times(vec).transpose();
083                for (int i = 0; i < result.getRowDimension(); i++)
084                        for (int j = 0; j < result.getColumnDimension(); j++)
085                                result.set(i, j, result.get(i, j) + mean.get(0, j));
086
087                return result.getArray();
088        }
089
090        @Override
091        public int numDims() {
092                return mean.getColumnDimension();
093        }
094
095        @Override
096        public double estimateProbability(double[] sample) {
097                final int N = mean.getColumnDimension();
098                final Matrix inv_covar = getCovariance().inverse();
099                final double pdf_const_factor = 1.0 / Math.sqrt((Math.pow((2 * Math.PI), N) * getCovariance().det()));
100
101                final Matrix xm = new Matrix(1, N);
102                for (int i = 0; i < N; i++)
103                        xm.set(0, i, sample[i] - mean.get(0, i));
104
105                final Matrix xmt = xm.transpose();
106                final double v = xm.times(inv_covar.times(xmt)).get(0, 0);
107
108                return pdf_const_factor * Math.exp(-0.5 * v);
109        }
110
111        @Override
112        public double estimateLogProbability(double[] sample) {
113                final int N = mean.getColumnDimension();
114                final Matrix inv_covar = getCovariance().inverse();
115                final double cov_det = getCovariance().det();
116                final double pdf_const_factor = 1.0 / Math.sqrt((Math.pow((2 * Math.PI), N) * cov_det));
117
118                final Matrix xm = new Matrix(1, N);
119                for (int i = 0; i < N; i++)
120                        xm.set(0, i, sample[i] - mean.get(0, i));
121
122                final Matrix xmt = xm.transpose();
123                final double v = xm.times(inv_covar.times(xmt)).get(0, 0);
124
125                return Math.log(pdf_const_factor) + (-0.5 * v);
126        }
127
128        @Override
129        public String toString() {
130                // only pretty print with low dimensionality
131                if (this.numDims() < 5)
132                        return String.format("MultivariateGaussian[mean=%s,covar=%s]", MatrixUtils.toMatlabString(mean).trim(),
133                                        MatrixUtils.toMatlabString(getCovariance()));
134                return super.toString();
135        }
136
137        @Override
138        public double[] estimateLogProbability(double[][] x) {
139                final double[] lps = new double[x.length];
140                for (int i = 0; i < x.length; i++)
141                        lps[i] = estimateLogProbability(x[i]);
142                return lps;
143        }
144}