001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.clustering.kdtree;
031
032import org.apache.commons.math.FunctionEvaluationException;
033import org.apache.commons.math.analysis.MultivariateRealFunction;
034import org.apache.commons.math.optimization.GoalType;
035import org.apache.commons.math.optimization.RealPointValuePair;
036import org.apache.commons.math.optimization.SimpleRealPointChecker;
037import org.apache.commons.math.optimization.direct.NelderMead;
038import org.apache.commons.math.stat.descriptive.moment.Mean;
039import org.openimaj.math.matrix.DiagonalMatrix;
040import org.openimaj.math.matrix.MatlibMatrixUtils;
041import org.openimaj.util.array.ArrayUtils;
042import org.openimaj.util.pair.ObjectDoublePair;
043
044import scala.actors.threadpool.Arrays;
045import ch.akuhn.matrix.DenseMatrix;
046import ch.akuhn.matrix.SparseMatrix;
047import ch.akuhn.matrix.Vector;
048
049/**
050 * Given a vector, tell me the split
051 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
052 *
053 */
054public interface SplitDetectionMode{
055        /**
056         * minimise for y: (y' * (D - W) * y) / ( y' * D * y );
057         * s.t. y = (1 + x) - b * (1 - x);
058         * s.t. b = k / (1 - k);
059         * s.t. k = sum(d(x > 0)) / sum(d);
060         * and
061         * s.t. x is an indicator (-1 for less than t, 1 for greater than or equal to t)
062         * @author Sina Samangooei (ss@ecs.soton.ac.uk)
063         */
064        public class OPTIMISED implements SplitDetectionMode {
065                
066                private DiagonalMatrix D;
067                private SparseMatrix W;
068                private MEAN mean;
069
070                /**
071                 * @param D
072                 * @param W
073                 */
074                public OPTIMISED(DiagonalMatrix D, SparseMatrix W) {
075                        this.D = D;
076                        this.W = W;
077                        this.mean = new MEAN();
078                }
079                private ObjectDoublePair<double[]> indicator(double[] vec, double d) {
080                        double[] ind = new double[vec.length];
081                        double sumx = 0;
082                        for (int i = 0; i < ind.length; i++) {
083                                if(vec[i] > d){
084                                        ind[i] = 1;
085                                        sumx ++;
086                                }
087                                else{
088                                        ind[i] = -1;
089                                }
090                        }
091                        return ObjectDoublePair.pair(ind, sumx);
092                }
093                @Override
094                public double detect(final double[] vec) {
095                        double[] t = {this.mean.detect(vec)};
096                        MultivariateRealFunction func = new MultivariateRealFunction() {
097                                @Override
098                                public double value(double[] x) throws FunctionEvaluationException {
099                                        ObjectDoublePair<double[]> ind = indicator(vec,x[0]);
100                                        double sumd = MatlibMatrixUtils.sum(D);
101                                        double k = ind.second / sumd;
102                                        double b = k / (1-k);
103                                        double[][] y = new double[1][vec.length];
104                                        for (int i = 0; i < vec.length; i++) {
105                                                y[0][i] = ind.first[i] + 1 - b * (1 - ind.first[i]);
106                                        }
107                                        SparseMatrix dmw = MatlibMatrixUtils.minusInplace(D, W);
108                                        Vector yv = Vector.wrap(y[0]);
109                                        double nom = new DenseMatrix(y).mult(dmw.transposeMultiply(yv)).get(0); // y' * ( (D-W) * y)
110                                        double denom = new DenseMatrix(y).mult(D.transposeMultiply(yv)).get(0);
111                                        return nom/denom;
112                                }
113
114                        }; 
115                        
116//                      
117                        RealPointValuePair ret;
118                        try {
119                                NelderMead nelderMead = new NelderMead();
120                                nelderMead.setConvergenceChecker(new SimpleRealPointChecker(0.0001, -1));
121                                ret = nelderMead.optimize(func, GoalType.MINIMIZE, t);
122                                return ret.getPoint()[0];
123                        } catch (Exception e) {
124                                e.printStackTrace();
125                                System.err.println("Reverting to mean");
126                        }
127                        return t[0];
128                }
129                
130
131        }
132
133        /**
134         * Splits clusters becuase they don't have exactly the same value!
135         */
136        public static class MEDIAN implements SplitDetectionMode{
137                @Override
138                public double detect(double[] col) {
139                        double mid = ArrayUtils.quickSelect(col, col.length/2);
140                        if(ArrayUtils.minValue(col) == mid) 
141                                mid += Double.MIN_NORMAL;
142                        if(ArrayUtils.maxValue(col) == mid) 
143                                mid -= Double.MIN_NORMAL;
144                        return 0;
145                }
146
147                
148        }
149        
150        /**
151         * Use the mean to split
152         * @author Sina Samangooei (ss@ecs.soton.ac.uk)
153         *
154         */
155        public static class MEAN implements SplitDetectionMode{
156
157                @Override
158                public double detect(double[] vec) {
159                        return new Mean().evaluate(vec);
160                }
161
162                
163                
164        }
165        
166        /**
167         * Find the median, attempt to find a value which keeps clusters together
168         * @author Sina Samangooei (ss@ecs.soton.ac.uk)
169         *
170         */
171        public static class VARIABLE_MEDIAN implements SplitDetectionMode{
172
173                private double tolchange;
174                /**
175                 * Sets the change tolerance to 0.1 (i.e. if the next value is different by more than value * 0.1, we switch)
176                 */
177                public VARIABLE_MEDIAN() {
178                        this.tolchange = 0.0001;
179                }
180                
181                /**
182                 * @param tol if the next value is different by more than value * tol we found a border
183                 */
184                public VARIABLE_MEDIAN(double tol) {
185                        this.tolchange = tol;
186                }
187                
188                @Override
189                public double detect(double[] vec) {
190                        Arrays.sort(vec);
191                        // Find the median index
192                        int medInd = vec.length/2;
193                        double medVal = vec[medInd];
194                        if(vec.length % 2 == 0){
195                                medVal += vec[medInd+1];
196                                medVal /= 2.;
197                        }
198                        
199                        
200                        boolean maxWithinTol = withinTol(medVal,vec[vec.length-1]);
201                        boolean minWithinTol = withinTol(medVal,vec[0]);
202                        if(maxWithinTol && minWithinTol) 
203                        {
204                                // degenerate case, the min and max are not beyond the tolerance, return the median
205                                return medVal;
206                        }
207                        // The split works like:
208                        // < val go left
209                        // >= val go right
210                        if(maxWithinTol){
211                                // search left
212                                for (int i = medInd; i > 0; i--) {
213                                        if(!withinTol(vec[i],vec[i-1])){
214                                                return vec[i];
215                                        }
216                                }
217                        }
218                        else{
219                                // search right
220                                for (int i = medInd; i < vec.length-1; i++) {
221                                        if(!withinTol(vec[i],vec[i+1])){
222                                                return vec[i+1];
223                                        }
224                                }
225                        }
226                        
227                        
228                        
229                        return 0;
230                }
231
232                private boolean withinTol(double a, double d) {
233                        return Math.abs(a - d) / Math.abs(a) < this.tolchange;
234                }
235
236                
237                
238        };
239        /**
240         * @param vec
241         * @return find the split point
242         */
243        public abstract double detect(double[] vec);
244}