/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.clustering.meanshift;

import gnu.trove.procedure.TIntObjectProcedure;

import java.util.List;
import java.util.Set;

import org.openimaj.math.statistics.distribution.MultivariateKernelDensityEstimate;
import org.openimaj.util.pair.ObjectDoublePair;
import org.openimaj.util.set.DisjointSetForest;
import org.openimaj.util.tree.DoubleKDTree;

/**
 * Exact mean shift implementation. The mean shift procedure is applied to every
 * underlying point. This can be quite slow, especially with many points.
 *
 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
 *
 */
public class ExactMeanShift {
	/** Upper bound on the number of mean shift updates applied to a single point. */
	private int maxIter = 300;

	private MultivariateKernelDensityEstimate kde;

	/** For each data point (in the order of {@code kde.getData()}), the index of its mode. */
	private int[] assignments;

	/** The merged modes; each row is the mean of the converged points in one cluster. */
	private double[][] modes;

	/** Number of data points merged into each mode. */
	private int[] counts;

	/**
	 * Perform the ExactMeanShift operation on the given KDE. The procedure
	 * runs immediately; results are available from {@link #getModes()} and
	 * {@link #getAssignments()} once construction completes.
	 *
	 * @param kde
	 *            the kernel density estimate supplying the data points, the
	 *            per-point support and the bandwidth
	 */
	public ExactMeanShift(MultivariateKernelDensityEstimate kde) {
		this.kde = kde;

		// NOTE(review): this invokes an overridable method from the
		// constructor; kept because subclasses may depend on it.
		performMeanShift();
	}

	/**
	 * Run the mean shift iteration on a copy of every data point (at most
	 * {@code maxIter} updates each) and then merge the resulting per-point
	 * modes into clusters.
	 */
	protected void performMeanShift() {
		final double[][] data = kde.getData();
		final double[][] modePerPoint = new double[data.length][];

		// perform the MS procedure on each point
		for (int i = 0; i < data.length; i++) {
			// work on a copy so the KDE's own data is never modified
			final double[] point = data[i].clone();

			for (int iter = 0; iter < maxIter; iter++) {
				if (computeMeanShift(point))
					break;
			}
			modePerPoint[i] = point;
		}

		// now need to merge modes that are <bandwidth away
		mergeModes(modePerPoint);
	}

	/**
	 * Get the modes
	 *
	 * @return the modes
	 */
	public double[][] getModes() {
		return modes;
	}

	/**
	 * Get the assignments
	 *
	 * @return the assignments; element {@code i} is the index into
	 *         {@link #getModes()} of the mode that data point {@code i} was
	 *         assigned to
	 */
	public int[] getAssignments() {
		return assignments;
	}

	/**
	 * Merge the per-point modes into clusters. Every pair of converged points
	 * found by a radius search of {@code kde.getScaledBandwidth()} is joined
	 * in a disjoint-set forest; each resulting subset becomes one mode whose
	 * position is the mean of its members.
	 * <p>
	 * Populates {@code assignments}, {@code modes} and {@code counts}.
	 *
	 * @param modePerPoint
	 *            the converged position of each data point, indexed like the
	 *            input data
	 */
	protected void mergeModes(double[][] modePerPoint) {
		final DisjointSetForest<double[]> forest = new DisjointSetForest<double[]>();

		for (int i = 0; i < modePerPoint.length; i++)
			forest.makeSet(modePerPoint[i]);

		final DoubleKDTree tree = new DoubleKDTree(modePerPoint);
		for (int i = 0; i < modePerPoint.length; i++) {
			final double[] point = modePerPoint[i];

			tree.radiusSearch(point, kde.getScaledBandwidth(), new TIntObjectProcedure<double[]>() {
				@Override
				public boolean execute(int a, double[] b) {
					forest.union(point, b);
					return true;
				}
			});
		}

		final Set<Set<double[]>> subsets = forest.getSubsets();
		this.assignments = new int[modePerPoint.length];
		this.modes = new double[subsets.size()][];
		this.counts = new int[subsets.size()];

		int current = 0;
		for (final Set<double[]> s : subsets) {
			final double[] mean = new double[modePerPoint[0].length];

			for (int i = 0; i < modePerPoint.length; i++) {
				if (s.contains(modePerPoint[i])) {
					assignments[i] = current;

					// accumulate (not overwrite) so that the division below
					// yields the mean of all members of the cluster
					for (int j = 0; j < mean.length; j++) {
						mean[j] += modePerPoint[i][j];
					}
				}
			}

			this.counts[current] = s.size();
			for (int j = 0; j < mean.length; j++) {
				mean[j] /= counts[current];
			}
			this.modes[current] = mean;
			current++;
		}
	}

	/**
	 * Apply a single mean shift update to the given point, in place. The
	 * point is replaced by the weighted mean of its support as reported by
	 * {@code kde.getSupport(pt)}.
	 *
	 * @param pt
	 *            the point to shift; overwritten with the shifted position
	 * @return {@code true} if the point is considered converged (the support
	 *         holds at most one element, the support weights sum to zero, or
	 *         the squared shift is below {@code 1e-3 * kde.getBandwidth()});
	 *         {@code false} if another update should be attempted
	 */
	protected boolean computeMeanShift(double[] pt) {
		final List<ObjectDoublePair<double[]>> support = kde.getSupport(pt);

		// A single supporting point cannot move the estimate any further; an
		// empty support would give 0/0 = NaN below, so stop there as well.
		if (support.size() <= 1) {
			return true;
		}

		double sum = 0;
		final double[] out = new double[pt.length];
		for (final ObjectDoublePair<double[]> p : support) {
			sum += p.second;

			for (int j = 0; j < out.length; j++) {
				out[j] += p.second * p.first[j];
			}
		}

		// No usable weight: leave the point where it is rather than
		// poisoning it with NaN/Infinity from a division by zero.
		if (sum == 0) {
			return true;
		}

		double dist = 0;
		for (int j = 0; j < out.length; j++) {
			out[j] /= sum;
			dist += (pt[j] - out[j]) * (pt[j] - out[j]);
		}

		System.arraycopy(out, 0, pt, 0, out.length);

		// NOTE(review): dist is a squared distance whereas the tolerance is
		// linear in the bandwidth - confirm whether a square root (or a
		// squared tolerance) was intended. Left unchanged to preserve the
		// existing convergence behaviour.
		return dist < 1e-3 * kde.getBandwidth();
	}
}