1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package org.openimaj.ml.clustering.kdtree;
31
32 import org.apache.commons.math.FunctionEvaluationException;
33 import org.apache.commons.math.analysis.MultivariateRealFunction;
34 import org.apache.commons.math.optimization.GoalType;
35 import org.apache.commons.math.optimization.RealPointValuePair;
36 import org.apache.commons.math.optimization.SimpleRealPointChecker;
37 import org.apache.commons.math.optimization.direct.NelderMead;
38 import org.apache.commons.math.stat.descriptive.moment.Mean;
39 import org.openimaj.math.matrix.DiagonalMatrix;
40 import org.openimaj.math.matrix.MatlibMatrixUtils;
41 import org.openimaj.util.array.ArrayUtils;
42 import org.openimaj.util.pair.ObjectDoublePair;
43
44 import scala.actors.threadpool.Arrays;
45 import ch.akuhn.matrix.DenseMatrix;
46 import ch.akuhn.matrix.SparseMatrix;
47 import ch.akuhn.matrix.Vector;
48
49
50
51
52
53
54 public interface SplitDetectionMode{
55
56
57
58
59
60
61
62
63
64 public class OPTIMISED implements SplitDetectionMode {
65
66 private DiagonalMatrix D;
67 private SparseMatrix W;
68 private MEAN mean;
69
70
71
72
73
74 public OPTIMISED(DiagonalMatrix D, SparseMatrix W) {
75 this.D = D;
76 this.W = W;
77 this.mean = new MEAN();
78 }
79 private ObjectDoublePair<double[]> indicator(double[] vec, double d) {
80 double[] ind = new double[vec.length];
81 double sumx = 0;
82 for (int i = 0; i < ind.length; i++) {
83 if(vec[i] > d){
84 ind[i] = 1;
85 sumx ++;
86 }
87 else{
88 ind[i] = -1;
89 }
90 }
91 return ObjectDoublePair.pair(ind, sumx);
92 }
93 @Override
94 public double detect(final double[] vec) {
95 double[] t = {this.mean.detect(vec)};
96 MultivariateRealFunction func = new MultivariateRealFunction() {
97 @Override
98 public double value(double[] x) throws FunctionEvaluationException {
99 ObjectDoublePair<double[]> ind = indicator(vec,x[0]);
100 double sumd = MatlibMatrixUtils.sum(D);
101 double k = ind.second / sumd;
102 double b = k / (1-k);
103 double[][] y = new double[1][vec.length];
104 for (int i = 0; i < vec.length; i++) {
105 y[0][i] = ind.first[i] + 1 - b * (1 - ind.first[i]);
106 }
107 SparseMatrix dmw = MatlibMatrixUtils.minusInplace(D, W);
108 Vector yv = Vector.wrap(y[0]);
109 double nom = new DenseMatrix(y).mult(dmw.transposeMultiply(yv)).get(0);
110 double denom = new DenseMatrix(y).mult(D.transposeMultiply(yv)).get(0);
111 return nom/denom;
112 }
113
114 };
115
116
117 RealPointValuePair ret;
118 try {
119 NelderMead nelderMead = new NelderMead();
120 nelderMead.setConvergenceChecker(new SimpleRealPointChecker(0.0001, -1));
121 ret = nelderMead.optimize(func, GoalType.MINIMIZE, t);
122 return ret.getPoint()[0];
123 } catch (Exception e) {
124 e.printStackTrace();
125 System.err.println("Reverting to mean");
126 }
127 return t[0];
128 }
129
130
131 }
132
133
134
135
136 public static class MEDIAN implements SplitDetectionMode{
137 @Override
138 public double detect(double[] col) {
139 double mid = ArrayUtils.quickSelect(col, col.length/2);
140 if(ArrayUtils.minValue(col) == mid)
141 mid += Double.MIN_NORMAL;
142 if(ArrayUtils.maxValue(col) == mid)
143 mid -= Double.MIN_NORMAL;
144 return 0;
145 }
146
147
148 }
149
150
151
152
153
154
155 public static class MEAN implements SplitDetectionMode{
156
157 @Override
158 public double detect(double[] vec) {
159 return new Mean().evaluate(vec);
160 }
161
162
163
164 }
165
166
167
168
169
170
171 public static class VARIABLE_MEDIAN implements SplitDetectionMode{
172
173 private double tolchange;
174
175
176
177 public VARIABLE_MEDIAN() {
178 this.tolchange = 0.0001;
179 }
180
181
182
183
184 public VARIABLE_MEDIAN(double tol) {
185 this.tolchange = tol;
186 }
187
188 @Override
189 public double detect(double[] vec) {
190 Arrays.sort(vec);
191
192 int medInd = vec.length/2;
193 double medVal = vec[medInd];
194 if(vec.length % 2 == 0){
195 medVal += vec[medInd+1];
196 medVal /= 2.;
197 }
198
199
200 boolean maxWithinTol = withinTol(medVal,vec[vec.length-1]);
201 boolean minWithinTol = withinTol(medVal,vec[0]);
202 if(maxWithinTol && minWithinTol)
203 {
204
205 return medVal;
206 }
207
208
209
210 if(maxWithinTol){
211
212 for (int i = medInd; i > 0; i--) {
213 if(!withinTol(vec[i],vec[i-1])){
214 return vec[i];
215 }
216 }
217 }
218 else{
219
220 for (int i = medInd; i < vec.length-1; i++) {
221 if(!withinTol(vec[i],vec[i+1])){
222 return vec[i+1];
223 }
224 }
225 }
226
227
228
229 return 0;
230 }
231
232 private boolean withinTol(double a, double d) {
233 return Math.abs(a - d) / Math.abs(a) < this.tolchange;
234 }
235
236
237
238 };
239
240
241
242
243 public abstract double detect(double[] vec);
244 }