001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.math.statistics.distribution;
031
032import gnu.trove.procedure.TObjectDoubleProcedure;
033
034import java.util.ArrayList;
035import java.util.List;
036import java.util.Random;
037
038import org.openimaj.math.statistics.distribution.kernel.UnivariateKernel;
039import org.openimaj.util.pair.ObjectDoublePair;
040import org.openimaj.util.tree.DoubleKDTree;
041
042/**
043 * A Parzen window kernel density estimate using a univariate kernel and
044 * Euclidean distance. Uses a KD-Tree to for efficient neighbour search.
045 * 
046 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
047 * 
048 */
049public class MultivariateKernelDensityEstimate extends AbstractMultivariateDistribution {
050        double[][] data;
051        UnivariateKernel kernel;
052        private double bandwidth;
053        DoubleKDTree tree;
054
055        /**
056         * Construct with the given data, kernel and bandwidth
057         * 
058         * @param data
059         *            the data
060         * @param kernel
061         *            the kernel
062         * @param bandwidth
063         *            the bandwidth
064         */
065        public MultivariateKernelDensityEstimate(double[][] data, UnivariateKernel kernel, double bandwidth) {
066                this.data = data;
067                this.tree = new DoubleKDTree(data);
068                this.kernel = kernel;
069                this.bandwidth = bandwidth;
070        }
071
072        /**
073         * Construct with the given data, kernel and bandwidth
074         * 
075         * @param data
076         *            the data
077         * @param kernel
078         *            the kernel
079         * @param bandwidth
080         *            the bandwidth
081         */
082        public MultivariateKernelDensityEstimate(List<double[]> data, UnivariateKernel kernel, double bandwidth)
083        {
084                this.data = data.toArray(new double[data.size()][]);
085                this.tree = new DoubleKDTree(this.data);
086                this.kernel = kernel;
087                this.bandwidth = bandwidth;
088        }
089
090        @Override
091        public double[] sample(Random rng) {
092                final double[] pt = data[rng.nextInt(data.length)].clone();
093
094                for (int i = 0; i < pt.length; i++) {
095                        pt[i] = pt[i] + kernel.sample(rng) * this.getBandwidth();
096                }
097
098                return pt;
099        }
100
101        @Override
102        public double estimateProbability(double[] sample) {
103                final double[] prob = new double[1];
104                final int[] count = new int[1];
105
106                tree.coordinateRadiusSearch(sample, kernel.getCutOff() * getBandwidth(), new TObjectDoubleProcedure<double[]>() {
107                        @Override
108                        public boolean execute(double[] point, double distance) {
109                                prob[0] += kernel.evaluate(Math.sqrt(distance) / getBandwidth());
110                                count[0]++;
111
112                                return true;
113                        }
114                });
115
116                return prob[0] / (getBandwidth() * count[0]);
117        }
118
119        /**
120         * Get the underlying points that support the KDE within the window around
121         * the given point. Each point is returned together with its own density
122         * estimate.
123         * 
124         * @param sample
125         *            the point in the centre of the window
126         * @return the points in the window
127         */
128        public List<ObjectDoublePair<double[]>> getSupport(double[] sample) {
129                final List<ObjectDoublePair<double[]>> support = new ArrayList<ObjectDoublePair<double[]>>();
130
131                tree.coordinateRadiusSearch(sample, kernel.getCutOff() * getBandwidth(), new TObjectDoubleProcedure<double[]>() {
132                        @Override
133                        public boolean execute(double[] a, double b) {
134                                support.add(ObjectDoublePair.pair(a, kernel.evaluate(Math.sqrt(b) / getBandwidth())));
135
136                                return true;
137                        }
138                });
139
140                return support;
141        }
142
143        /**
144         * Get the underlying data
145         * 
146         * @return the data
147         */
148        public double[][] getData() {
149                return data;
150        }
151
152        /**
153         * Get the bandwidth
154         * 
155         * @return the bandwidth
156         */
157        public double getBandwidth() {
158                return bandwidth;
159        }
160
161        /**
162         * Get the bandwidth scaled by the kernel support.
163         * 
164         * @see UnivariateKernel#getCutOff()
165         * 
166         * @return the scaled bandwidth
167         */
168        public double getScaledBandwidth() {
169                return bandwidth * this.kernel.getCutOff();
170        }
171
172        @Override
173        public double[] estimateLogProbability(double[][] x) {
174                final double[] lps = new double[x.length];
175                for (int i = 0; i < x.length; i++)
176                        lps[i] = estimateLogProbability(x[i]);
177                return lps;
178        }
179}