/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */ 030package org.openimaj.math.statistics.distribution; 031 032import gnu.trove.procedure.TObjectDoubleProcedure; 033 034import java.util.ArrayList; 035import java.util.List; 036import java.util.Random; 037 038import org.openimaj.math.statistics.distribution.kernel.UnivariateKernel; 039import org.openimaj.util.pair.ObjectDoublePair; 040import org.openimaj.util.tree.DoubleKDTree; 041 042/** 043 * A Parzen window kernel density estimate using a univariate kernel and 044 * Euclidean distance. Uses a KD-Tree to for efficient neighbour search. 045 * 046 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 047 * 048 */ 049public class MultivariateKernelDensityEstimate extends AbstractMultivariateDistribution { 050 double[][] data; 051 UnivariateKernel kernel; 052 private double bandwidth; 053 DoubleKDTree tree; 054 055 /** 056 * Construct with the given data, kernel and bandwidth 057 * 058 * @param data 059 * the data 060 * @param kernel 061 * the kernel 062 * @param bandwidth 063 * the bandwidth 064 */ 065 public MultivariateKernelDensityEstimate(double[][] data, UnivariateKernel kernel, double bandwidth) { 066 this.data = data; 067 this.tree = new DoubleKDTree(data); 068 this.kernel = kernel; 069 this.bandwidth = bandwidth; 070 } 071 072 /** 073 * Construct with the given data, kernel and bandwidth 074 * 075 * @param data 076 * the data 077 * @param kernel 078 * the kernel 079 * @param bandwidth 080 * the bandwidth 081 */ 082 public MultivariateKernelDensityEstimate(List<double[]> data, UnivariateKernel kernel, double bandwidth) 083 { 084 this.data = data.toArray(new double[data.size()][]); 085 this.tree = new DoubleKDTree(this.data); 086 this.kernel = kernel; 087 this.bandwidth = bandwidth; 088 } 089 090 @Override 091 public double[] sample(Random rng) { 092 final double[] pt = data[rng.nextInt(data.length)].clone(); 093 094 for (int i = 0; i < pt.length; i++) { 095 pt[i] = pt[i] + kernel.sample(rng) * this.getBandwidth(); 096 } 097 098 return pt; 099 } 100 101 
@Override 102 public double estimateProbability(double[] sample) { 103 final double[] prob = new double[1]; 104 final int[] count = new int[1]; 105 106 tree.coordinateRadiusSearch(sample, kernel.getCutOff() * getBandwidth(), new TObjectDoubleProcedure<double[]>() { 107 @Override 108 public boolean execute(double[] point, double distance) { 109 prob[0] += kernel.evaluate(Math.sqrt(distance) / getBandwidth()); 110 count[0]++; 111 112 return true; 113 } 114 }); 115 116 return prob[0] / (getBandwidth() * count[0]); 117 } 118 119 /** 120 * Get the underlying points that support the KDE within the window around 121 * the given point. Each point is returned together with its own density 122 * estimate. 123 * 124 * @param sample 125 * the point in the centre of the window 126 * @return the points in the window 127 */ 128 public List<ObjectDoublePair<double[]>> getSupport(double[] sample) { 129 final List<ObjectDoublePair<double[]>> support = new ArrayList<ObjectDoublePair<double[]>>(); 130 131 tree.coordinateRadiusSearch(sample, kernel.getCutOff() * getBandwidth(), new TObjectDoubleProcedure<double[]>() { 132 @Override 133 public boolean execute(double[] a, double b) { 134 support.add(ObjectDoublePair.pair(a, kernel.evaluate(Math.sqrt(b) / getBandwidth()))); 135 136 return true; 137 } 138 }); 139 140 return support; 141 } 142 143 /** 144 * Get the underlying data 145 * 146 * @return the data 147 */ 148 public double[][] getData() { 149 return data; 150 } 151 152 /** 153 * Get the bandwidth 154 * 155 * @return the bandwidth 156 */ 157 public double getBandwidth() { 158 return bandwidth; 159 } 160 161 /** 162 * Get the bandwidth scaled by the kernel support. 
163 * 164 * @see UnivariateKernel#getCutOff() 165 * 166 * @return the scaled bandwidth 167 */ 168 public double getScaledBandwidth() { 169 return bandwidth * this.kernel.getCutOff(); 170 } 171 172 @Override 173 public double[] estimateLogProbability(double[][] x) { 174 final double[] lps = new double[x.length]; 175 for (int i = 0; i < x.length; i++) 176 lps[i] = estimateLogProbability(x[i]); 177 return lps; 178 } 179}