001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.clustering.incremental;
031
032import gnu.trove.list.array.TIntArrayList;
033import gnu.trove.set.TIntSet;
034import gnu.trove.set.hash.TIntHashSet;
035
036import java.util.ArrayList;
037import java.util.HashMap;
038import java.util.List;
039import java.util.Map;
040import java.util.Map.Entry;
041
042import org.apache.logging.log4j.Logger;
043import org.apache.logging.log4j.LogManager;
044
045import org.openimaj.experiment.evaluation.cluster.analyser.FScoreClusterAnalyser;
046import org.openimaj.math.matrix.MatlibMatrixUtils;
047import org.openimaj.ml.clustering.IndexClusters;
048import org.openimaj.ml.clustering.SparseMatrixClusterer;
049import org.openimaj.util.pair.IntDoublePair;
050
051import ch.akuhn.matrix.SparseMatrix;
052
053/**
054 *
055 * An incremental clusterer which holds old {@link SparseMatrix} instances internally, 
056 * only forgetting rows once they have been clustered and are relatively stable.
057 * 
058 * The criteria for row removal is cluster stability.
059 * The defenition of cluster stability is maximum f1-score achieving a threshold between
060 * clusters in the previous round and the current round. Once one round of stability is achieved
061 * the cluster is stable and its elements are removed.
062 * 
063 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
064 */
065public class IncrementalSparseClusterer implements SparseMatrixClusterer<IndexClusters>{
066        
067        private SparseMatrixClusterer<? extends IndexClusters> clusterer;
068        private int window;
069        protected double threshold;
070        private int maxwindow = -1;
071        private final static Logger logger = LogManager.getLogger(IncrementalSparseClusterer.class);
072        
073
074        /**
075         * @param clusterer the underlying clusterer
076         * @param window 
077         */
078        public IncrementalSparseClusterer(SparseMatrixClusterer<? extends IndexClusters> clusterer, int window) {
079                this.clusterer = clusterer;
080                this.window = window;
081                this.threshold = 1.;
082        }
083        
084        /**
085         * @param clusterer the underlying clusterer
086         * @param window 
087         * @param threshold 
088         */
089        public IncrementalSparseClusterer(SparseMatrixClusterer<? extends IndexClusters> clusterer, int window, double threshold) {
090                this.clusterer = clusterer;
091                this.window = window;
092                this.threshold = threshold;
093        }
094        
095        /**
096         * @param clusterer the underlying clusterer
097         * @param window 
098         * @param maxwindow 
099         */
100        @SuppressWarnings("unused")
101        private IncrementalSparseClusterer(SparseMatrixClusterer<? extends IndexClusters> clusterer, int window, int maxwindow) {
102                this.clusterer = clusterer;
103                this.window = window;
104                if(maxwindow>0 ){
105                        if(maxwindow < window * 2)
106                                maxwindow = window * 2;
107                }
108                if(maxwindow <= 0){
109                        maxwindow = -1;
110                }
111                this.maxwindow  = maxwindow;
112                this.threshold = 1.;
113        }
114        
115        class WindowedSparseMatrix{
116                SparseMatrix window;
117                Map<Integer,Integer> indexCorrection;
118                
119                public WindowedSparseMatrix(SparseMatrix sm, int nextwindow, TIntSet inactive) {
120                        TIntArrayList active = new TIntArrayList(nextwindow);
121                        indexCorrection = new HashMap<Integer, Integer>();
122                        for (int i = 0; i < nextwindow; i++) {
123                                if(!inactive.contains(i)){
124                                        indexCorrection.put(active.size(), i);
125                                        active.add(i);
126                                }
127                        }
128                        window = MatlibMatrixUtils.subMatrix(sm, active, active);
129                }
130                
131                public void correctClusters(IndexClusters clstrs){
132                        int[][] clusters = clstrs.clusters();
133                        for (int i = 0; i < clusters.length; i++) {
134                                int[] cluster = clusters[i];
135                                for (int j = 0; j < cluster.length; j++) {
136                                        cluster[j] = indexCorrection.get(cluster[j]);
137                                }
138                        }
139                }
140        }
141
142        @Override
143        public IndexClusters cluster(SparseMatrix data) {
144                if(window >= data.rowCount()) window = data.rowCount();
145                SparseMatrix seen = MatlibMatrixUtils.subMatrix(data, 0, window, 0, window);
146                int seenrows = window;
147                TIntSet inactiveRows = new TIntHashSet(window);
148                logger.debug("First clustering!: " + seen.rowCount() + "x" + seen.columnCount());
149                IndexClusters oldClusters = clusterer.cluster(seen);
150                logger.debug("First clusters:\n" + oldClusters);
151                List<int[]> completedClusters = new ArrayList<int[]>();
152                while(seenrows < data.rowCount()){
153                        int nextwindow = seenrows + window;
154                        if(nextwindow >= data.rowCount()) nextwindow = data.rowCount();
155                        if(this.maxwindow > 0 && nextwindow - inactiveRows.size() > this.maxwindow){
156                                logger.debug(String.format("Window size (%d) without inactive (%d) = (%d), greater than maximum (%d)",nextwindow, inactiveRows.size(), nextwindow - inactiveRows.size(), this.maxwindow));
157                                deactiveOldItemsAsNoise(nextwindow,inactiveRows,completedClusters);
158                        }
159                        WindowedSparseMatrix wsp = new WindowedSparseMatrix(data, nextwindow, inactiveRows);
160                        logger.debug("Clustering: " + wsp.window.rowCount() + "x" + wsp.window.columnCount());
161                        IndexClusters newClusters = clusterer.cluster(wsp.window);
162                        wsp.correctClusters(newClusters);
163                        logger.debug("New clusters:\n" + newClusters);
164                        // if stability == 1 for any cluster, it was the same last window, we should not include those items next round
165                        detectInactive(oldClusters, newClusters, inactiveRows, completedClusters);
166                        
167                        oldClusters = newClusters;
168                        seenrows += window;
169                        logger.debug("Seen rows: " + seenrows);
170                        logger.debug("Inactive rows: " + inactiveRows.size());
171                }
172                for (int i = 0; i < oldClusters.clusters().length; i++) {
173                        int[] cluster = oldClusters.clusters()[i];
174                        if(cluster.length!=0) 
175                                completedClusters.add(cluster);
176                }
177                
178                return new IndexClusters(completedClusters);
179        }
180
181        private void deactiveOldItemsAsNoise(int nextwindow, TIntSet inactiveRows, List<int[]> completedClusters) {
182                int toDeactivate = 0;
183                while(nextwindow - inactiveRows.size() > this.maxwindow){
184                        if(!inactiveRows.contains(toDeactivate)){
185                                logger.debug("Forcing the deactivation of: " + toDeactivate);
186                                inactiveRows.add(toDeactivate);
187                                completedClusters.add(new int[]{toDeactivate});
188                        }
189                        toDeactivate++;
190                }
191        }
192
193        /**
194         * Given the old and new clusters, make a decision as to which rows are now inactive,
195         * and therefore which clusters are now completed
196         * @param oldClusters
197         * @param newClusters
198         * @param inactiveRows
199         * @param completedClusters
200         */
201        protected void detectInactive(IndexClusters oldClusters, IndexClusters newClusters, TIntSet inactiveRows, List<int[]> completedClusters) {
202                Map<Integer, IntDoublePair> stability = calculateStability(oldClusters,newClusters,inactiveRows);
203                for (Entry<Integer, IntDoublePair> e : stability.entrySet()) {
204                        if(e.getValue().second >= threshold){
205                                int[] completedCluster = oldClusters.clusters()[e.getKey()];
206                                inactiveRows.addAll(completedCluster);
207                                completedClusters.add(completedCluster);
208                                if(threshold == 1){
209                                        newClusters.clusters()[e.getValue().first] = new int[0];
210                                }
211                        }
212                }
213        }
214
215        protected Map<Integer, IntDoublePair> calculateStability(IndexClusters c1, IndexClusters c2, TIntSet inactiveRows) {
216                
217                Map<Integer, IntDoublePair> stability = new HashMap<Integer, IntDoublePair>();
218                int[][] clusters1 = c1.clusters();
219                int[][] clusters2 = c2.clusters();
220                for (int i = 0; i < clusters1.length; i++) {
221                        if(clusters1[i].length == 0) continue;
222                        double maxnmi = 0;
223                        int maxj = -1;
224                        TIntArrayList cluster = new TIntArrayList(clusters1[i].length);
225                        for (int j = 0; j < clusters1[i].length; j++) {
226                                if(inactiveRows.contains(clusters1[i][j]))
227                                        continue;
228                                cluster.add(clusters1[i][j]);
229                        }
230                        int[][] correct = new int[][]{cluster.toArray()};
231                        for (int j = 0; j < clusters2.length; j++) {
232                                int[][] estimated = new int[][]{clusters2[j]};
233//                              NMIAnalysis nmi = new NMIClusterAnalyser().analyse(correct, estimated);
234                                double score = 0;
235                                if(correct[0].length == 1 && estimated[0].length == 1){
236                                        // BOTH 1, either they are the same or not!
237                                        score = correct[0][0] == estimated[0][0] ? 1 : 0;
238                                }
239                                else{                                   
240                                        score = new FScoreClusterAnalyser().analyse(correct, estimated).score();
241                                }
242                                if(!Double.isNaN(score))
243                                {
244                                        if(score > maxnmi){
245                                                maxnmi = score;
246                                                maxj = j;
247                                        }
248                                }
249                        }
250                        stability.put(i, IntDoublePair.pair(maxj, maxnmi));
251                }
252                logger.debug(String.format("The stability is:\n%s",stability));
253                return stability;
254        }
255
256        @Override
257        public int[][] performClustering(SparseMatrix data) {
258                return this.cluster(data).clusters();
259        }
260
261        
262}