View Javadoc

1   /**
2    * Copyright (c) 2011, The University of Southampton and the individual contributors.
3    * All rights reserved.
4    *
5    * Redistribution and use in source and binary forms, with or without modification,
6    * are permitted provided that the following conditions are met:
7    *
8    *   * 	Redistributions of source code must retain the above copyright notice,
9    * 	this list of conditions and the following disclaimer.
10   *
11   *   *	Redistributions in binary form must reproduce the above copyright notice,
12   * 	this list of conditions and the following disclaimer in the documentation
13   * 	and/or other materials provided with the distribution.
14   *
15   *   *	Neither the name of the University of Southampton nor the names of its
16   * 	contributors may be used to endorse or promote products derived from this
17   * 	software without specific prior written permission.
18   *
19   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21   * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22   * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
23   * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24   * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26   * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27   * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28   * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29   */
30  package org.openimaj.ml.clustering.kdtree;
31  
32  import org.apache.commons.math.FunctionEvaluationException;
33  import org.apache.commons.math.analysis.MultivariateRealFunction;
34  import org.apache.commons.math.optimization.GoalType;
35  import org.apache.commons.math.optimization.RealPointValuePair;
36  import org.apache.commons.math.optimization.SimpleRealPointChecker;
37  import org.apache.commons.math.optimization.direct.NelderMead;
38  import org.apache.commons.math.stat.descriptive.moment.Mean;
39  import org.openimaj.math.matrix.DiagonalMatrix;
40  import org.openimaj.math.matrix.MatlibMatrixUtils;
41  import org.openimaj.util.array.ArrayUtils;
42  import org.openimaj.util.pair.ObjectDoublePair;
43  
44  import scala.actors.threadpool.Arrays;
45  import ch.akuhn.matrix.DenseMatrix;
46  import ch.akuhn.matrix.SparseMatrix;
47  import ch.akuhn.matrix.Vector;
48  
49  /**
50   * Given a vector, tell me the split
51   * @author Sina Samangooei (ss@ecs.soton.ac.uk)
52   *
53   */
54  public interface SplitDetectionMode{
55  	/**
56  	 * minimise for y: (y' * (D - W) * y) / ( y' * D * y );
57  	 * s.t. y = (1 + x) - b * (1 - x);
58  	 * s.t. b = k / (1 - k);
59  	 * s.t. k = sum(d(x > 0)) / sum(d);
60  	 * and
61  	 * s.t. x is an indicator (-1 for less than t, 1 for greater than or equal to t)
62  	 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
63  	 */
64  	public class OPTIMISED implements SplitDetectionMode {
65  		
66  		private DiagonalMatrix D;
67  		private SparseMatrix W;
68  		private MEAN mean;
69  
70  		/**
71  		 * @param D
72  		 * @param W
73  		 */
74  		public OPTIMISED(DiagonalMatrix D, SparseMatrix W) {
75  			this.D = D;
76  			this.W = W;
77  			this.mean = new MEAN();
78  		}
79  		private ObjectDoublePair<double[]> indicator(double[] vec, double d) {
80  			double[] ind = new double[vec.length];
81  			double sumx = 0;
82  			for (int i = 0; i < ind.length; i++) {
83  				if(vec[i] > d){
84  					ind[i] = 1;
85  					sumx ++;
86  				}
87  				else{
88  					ind[i] = -1;
89  				}
90  			}
91  			return ObjectDoublePair.pair(ind, sumx);
92  		}
93  		@Override
94  		public double detect(final double[] vec) {
95  			double[] t = {this.mean.detect(vec)};
96  			MultivariateRealFunction func = new MultivariateRealFunction() {
97  				@Override
98  				public double value(double[] x) throws FunctionEvaluationException {
99  					ObjectDoublePair<double[]> ind = indicator(vec,x[0]);
100 					double sumd = MatlibMatrixUtils.sum(D);
101 					double k = ind.second / sumd;
102 					double b = k / (1-k);
103 					double[][] y = new double[1][vec.length];
104 					for (int i = 0; i < vec.length; i++) {
105 						y[0][i] = ind.first[i] + 1 - b * (1 - ind.first[i]);
106 					}
107 					SparseMatrix dmw = MatlibMatrixUtils.minusInplace(D, W);
108 					Vector yv = Vector.wrap(y[0]);
109 					double nom = new DenseMatrix(y).mult(dmw.transposeMultiply(yv)).get(0); // y' * ( (D-W) * y)
110 					double denom = new DenseMatrix(y).mult(D.transposeMultiply(yv)).get(0);
111 					return nom/denom;
112 				}
113 
114 			}; 
115 			
116 //			
117 			RealPointValuePair ret;
118 			try {
119 				NelderMead nelderMead = new NelderMead();
120 				nelderMead.setConvergenceChecker(new SimpleRealPointChecker(0.0001, -1));
121 				ret = nelderMead.optimize(func, GoalType.MINIMIZE, t);
122 				return ret.getPoint()[0];
123 			} catch (Exception e) {
124 				e.printStackTrace();
125 				System.err.println("Reverting to mean");
126 			}
127 			return t[0];
128 		}
129 		
130 
131 	}
132 
133 	/**
134 	 * Splits clusters becuase they don't have exactly the same value!
135 	 */
136 	public static class MEDIAN implements SplitDetectionMode{
137 		@Override
138 		public double detect(double[] col) {
139 			double mid = ArrayUtils.quickSelect(col, col.length/2);
140 			if(ArrayUtils.minValue(col) == mid) 
141 				mid += Double.MIN_NORMAL;
142 			if(ArrayUtils.maxValue(col) == mid) 
143 				mid -= Double.MIN_NORMAL;
144 			return 0;
145 		}
146 
147 		
148 	}
149 	
150 	/**
151 	 * Use the mean to split
152 	 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
153 	 *
154 	 */
155 	public static class MEAN implements SplitDetectionMode{
156 
157 		@Override
158 		public double detect(double[] vec) {
159 			return new Mean().evaluate(vec);
160 		}
161 
162 		
163 		
164 	}
165 	
166 	/**
167 	 * Find the median, attempt to find a value which keeps clusters together
168 	 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
169 	 *
170 	 */
171 	public static class VARIABLE_MEDIAN implements SplitDetectionMode{
172 
173 		private double tolchange;
174 		/**
175 		 * Sets the change tolerance to 0.1 (i.e. if the next value is different by more than value * 0.1, we switch)
176 		 */
177 		public VARIABLE_MEDIAN() {
178 			this.tolchange = 0.0001;
179 		}
180 		
181 		/**
182 		 * @param tol if the next value is different by more than value * tol we found a border
183 		 */
184 		public VARIABLE_MEDIAN(double tol) {
185 			this.tolchange = tol;
186 		}
187 		
188 		@Override
189 		public double detect(double[] vec) {
190 			Arrays.sort(vec);
191 			// Find the median index
192 			int medInd = vec.length/2;
193 			double medVal = vec[medInd];
194 			if(vec.length % 2 == 0){
195 				medVal += vec[medInd+1];
196 				medVal /= 2.;
197 			}
198 			
199 			
200 			boolean maxWithinTol = withinTol(medVal,vec[vec.length-1]);
201 			boolean minWithinTol = withinTol(medVal,vec[0]);
202 			if(maxWithinTol && minWithinTol) 
203 			{
204 				// degenerate case, the min and max are not beyond the tolerance, return the median
205 				return medVal;
206 			}
207 			// The split works like:
208 			// < val go left
209 			// >= val go right
210 			if(maxWithinTol){
211 				// search left
212 				for (int i = medInd; i > 0; i--) {
213 					if(!withinTol(vec[i],vec[i-1])){
214 						return vec[i];
215 					}
216 				}
217 			}
218 			else{
219 				// search right
220 				for (int i = medInd; i < vec.length-1; i++) {
221 					if(!withinTol(vec[i],vec[i+1])){
222 						return vec[i+1];
223 					}
224 				}
225 			}
226 			
227 			
228 			
229 			return 0;
230 		}
231 
232 		private boolean withinTol(double a, double d) {
233 			return Math.abs(a - d) / Math.abs(a) < this.tolchange;
234 		}
235 
236 		
237 		
238 	};
239 	/**
240 	 * @param vec
241 	 * @return find the split point
242 	 */
243 	public abstract double detect(double[] vec);
244 }