001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.ml.clustering.kdtree; 031 032import org.apache.commons.math.FunctionEvaluationException; 033import org.apache.commons.math.analysis.MultivariateRealFunction; 034import org.apache.commons.math.optimization.GoalType; 035import org.apache.commons.math.optimization.RealPointValuePair; 036import org.apache.commons.math.optimization.SimpleRealPointChecker; 037import org.apache.commons.math.optimization.direct.NelderMead; 038import org.apache.commons.math.stat.descriptive.moment.Mean; 039import org.openimaj.math.matrix.DiagonalMatrix; 040import org.openimaj.math.matrix.MatlibMatrixUtils; 041import org.openimaj.util.array.ArrayUtils; 042import org.openimaj.util.pair.ObjectDoublePair; 043 044import scala.actors.threadpool.Arrays; 045import ch.akuhn.matrix.DenseMatrix; 046import ch.akuhn.matrix.SparseMatrix; 047import ch.akuhn.matrix.Vector; 048 049/** 050 * Given a vector, tell me the split 051 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 052 * 053 */ 054public interface SplitDetectionMode{ 055 /** 056 * minimise for y: (y' * (D - W) * y) / ( y' * D * y ); 057 * s.t. y = (1 + x) - b * (1 - x); 058 * s.t. b = k / (1 - k); 059 * s.t. k = sum(d(x > 0)) / sum(d); 060 * and 061 * s.t. x is an indicator (-1 for less than t, 1 for greater than or equal to t) 062 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 063 */ 064 public class OPTIMISED implements SplitDetectionMode { 065 066 private DiagonalMatrix D; 067 private SparseMatrix W; 068 private MEAN mean; 069 070 /** 071 * @param D 072 * @param W 073 */ 074 public OPTIMISED(DiagonalMatrix D, SparseMatrix W) { 075 this.D = D; 076 this.W = W; 077 this.mean = new MEAN(); 078 } 079 private ObjectDoublePair<double[]> indicator(double[] vec, double d) { 080 double[] ind = new double[vec.length]; 081 double sumx = 0; 082 for (int i = 0; i < ind.length; i++) { 083 if(vec[i] > d){ 084 ind[i] = 1; 085 sumx ++; 086 } 087 else{ 088 ind[i] = -1; 089 } 090 } 091 return ObjectDoublePair.pair(ind, sumx); 092 } 093 @Override 094 public double detect(final double[] vec) { 095 double[] t = {this.mean.detect(vec)}; 096 MultivariateRealFunction func = new MultivariateRealFunction() { 097 @Override 098 public double value(double[] x) throws FunctionEvaluationException { 099 ObjectDoublePair<double[]> ind = indicator(vec,x[0]); 100 double sumd = MatlibMatrixUtils.sum(D); 101 double k = ind.second / sumd; 102 double b = k / (1-k); 103 double[][] y = new double[1][vec.length]; 104 for (int i = 0; i < vec.length; i++) { 105 y[0][i] = ind.first[i] + 1 - b * (1 - ind.first[i]); 106 } 107 SparseMatrix dmw = MatlibMatrixUtils.minusInplace(D, W); 108 Vector yv = Vector.wrap(y[0]); 109 double nom = new DenseMatrix(y).mult(dmw.transposeMultiply(yv)).get(0); // y' * ( (D-W) * y) 110 double denom = new DenseMatrix(y).mult(D.transposeMultiply(yv)).get(0); 111 return nom/denom; 112 } 113 114 }; 115 116// 117 RealPointValuePair ret; 118 try { 119 NelderMead nelderMead = new NelderMead(); 120 nelderMead.setConvergenceChecker(new SimpleRealPointChecker(0.0001, -1)); 121 ret = nelderMead.optimize(func, GoalType.MINIMIZE, t); 122 return ret.getPoint()[0]; 123 } catch (Exception e) { 124 e.printStackTrace(); 125 System.err.println("Reverting to mean"); 126 } 127 return t[0]; 128 } 129 130 131 } 132 133 /** 134 * Splits clusters becuase they don't have exactly the same value! 135 */ 136 public static class MEDIAN implements SplitDetectionMode{ 137 @Override 138 public double detect(double[] col) { 139 double mid = ArrayUtils.quickSelect(col, col.length/2); 140 if(ArrayUtils.minValue(col) == mid) 141 mid += Double.MIN_NORMAL; 142 if(ArrayUtils.maxValue(col) == mid) 143 mid -= Double.MIN_NORMAL; 144 return 0; 145 } 146 147 148 } 149 150 /** 151 * Use the mean to split 152 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 153 * 154 */ 155 public static class MEAN implements SplitDetectionMode{ 156 157 @Override 158 public double detect(double[] vec) { 159 return new Mean().evaluate(vec); 160 } 161 162 163 164 } 165 166 /** 167 * Find the median, attempt to find a value which keeps clusters together 168 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 169 * 170 */ 171 public static class VARIABLE_MEDIAN implements SplitDetectionMode{ 172 173 private double tolchange; 174 /** 175 * Sets the change tolerance to 0.1 (i.e. if the next value is different by more than value * 0.1, we switch) 176 */ 177 public VARIABLE_MEDIAN() { 178 this.tolchange = 0.0001; 179 } 180 181 /** 182 * @param tol if the next value is different by more than value * tol we found a border 183 */ 184 public VARIABLE_MEDIAN(double tol) { 185 this.tolchange = tol; 186 } 187 188 @Override 189 public double detect(double[] vec) { 190 Arrays.sort(vec); 191 // Find the median index 192 int medInd = vec.length/2; 193 double medVal = vec[medInd]; 194 if(vec.length % 2 == 0){ 195 medVal += vec[medInd+1]; 196 medVal /= 2.; 197 } 198 199 200 boolean maxWithinTol = withinTol(medVal,vec[vec.length-1]); 201 boolean minWithinTol = withinTol(medVal,vec[0]); 202 if(maxWithinTol && minWithinTol) 203 { 204 // degenerate case, the min and max are not beyond the tolerance, return the median 205 return medVal; 206 } 207 // The split works like: 208 // < val go left 209 // >= val go right 210 if(maxWithinTol){ 211 // search left 212 for (int i = medInd; i > 0; i--) { 213 if(!withinTol(vec[i],vec[i-1])){ 214 return vec[i]; 215 } 216 } 217 } 218 else{ 219 // search right 220 for (int i = medInd; i < vec.length-1; i++) { 221 if(!withinTol(vec[i],vec[i+1])){ 222 return vec[i+1]; 223 } 224 } 225 } 226 227 228 229 return 0; 230 } 231 232 private boolean withinTol(double a, double d) { 233 return Math.abs(a - d) / Math.abs(a) < this.tolchange; 234 } 235 236 237 238 }; 239 /** 240 * @param vec 241 * @return find the split point 242 */ 243 public abstract double detect(double[] vec); 244}