/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.clustering.incremental;

import gnu.trove.list.array.TIntArrayList;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import org.openimaj.experiment.evaluation.cluster.analyser.FScoreClusterAnalyser;
import org.openimaj.math.matrix.MatlibMatrixUtils;
import org.openimaj.ml.clustering.IndexClusters;
import org.openimaj.ml.clustering.SparseMatrixClusterer;
import org.openimaj.util.pair.IntDoublePair;

import ch.akuhn.matrix.SparseMatrix;

/**
 * An incremental clusterer which clusters a {@link SparseMatrix} in growing windows,
 * holding rows internally and only forgetting them once they have been clustered and
 * are relatively stable.
 *
 * The criterion for row removal is cluster stability. A cluster is considered stable
 * when the maximum f1-score between it and a cluster from the previous round meets a
 * threshold. Once one round of stability is achieved the cluster is treated as final
 * and its elements are removed from the working matrix.
 *
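 * For illustration, a minimal usage sketch (the underlying batch clusterer is left
 * abstract here; any {@link SparseMatrixClusterer} over a similarity matrix could be
 * supplied, and the window and threshold values below are arbitrary):
 *
 * <pre>
 * {@code
 * SparseMatrixClusterer<IndexClusters> base = ...; // e.g. a DBSCAN-style similarity clusterer
 * IncrementalSparseClusterer inc = new IncrementalSparseClusterer(base, 1000, 0.9);
 * IndexClusters clusters = inc.cluster(similarity); // similarity is a ch.akuhn.matrix.SparseMatrix
 * }
 * </pre>
 *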
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 */
public class IncrementalSparseClusterer implements SparseMatrixClusterer<IndexClusters> {

    private SparseMatrixClusterer<? extends IndexClusters> clusterer;
    private int window;
    protected double threshold;
    private int maxwindow = -1;
    private static final Logger logger = LogManager.getLogger(IncrementalSparseClusterer.class);

    /**
     * @param clusterer the underlying clusterer
     * @param window the number of new rows introduced in each clustering round
     */
    public IncrementalSparseClusterer(SparseMatrixClusterer<? extends IndexClusters> clusterer, int window) {
        this.clusterer = clusterer;
        this.window = window;
        this.threshold = 1.;
    }

    /**
     * @param clusterer the underlying clusterer
     * @param window the number of new rows introduced in each clustering round
     * @param threshold the minimum f1-score at which a cluster is considered stable
     */
    public IncrementalSparseClusterer(SparseMatrixClusterer<? extends IndexClusters> clusterer, int window, double threshold) {
        this.clusterer = clusterer;
        this.window = window;
        this.threshold = threshold;
    }

    /**
     * @param clusterer the underlying clusterer
     * @param window the number of new rows introduced in each clustering round
     * @param maxwindow the maximum number of active rows held at any time; forced to be at least twice the window
     */
    @SuppressWarnings("unused")
    private IncrementalSparseClusterer(SparseMatrixClusterer<? extends IndexClusters> clusterer, int window, int maxwindow) {
        this.clusterer = clusterer;
        this.window = window;
        if (maxwindow > 0) {
            if (maxwindow < window * 2)
                maxwindow = window * 2;
        }
        if (maxwindow <= 0) {
            maxwindow = -1;
        }
        this.maxwindow = maxwindow;
        this.threshold = 1.;
    }

    class WindowedSparseMatrix {
        SparseMatrix window;
        Map<Integer, Integer> indexCorrection;

        public WindowedSparseMatrix(SparseMatrix sm, int nextwindow, TIntSet inactive) {
            TIntArrayList active = new TIntArrayList(nextwindow);
            indexCorrection = new HashMap<Integer, Integer>();
            for (int i = 0; i < nextwindow; i++) {
                if (!inactive.contains(i)) {
                    indexCorrection.put(active.size(), i);
                    active.add(i);
                }
            }
            window = MatlibMatrixUtils.subMatrix(sm, active, active);
        }

        public void correctClusters(IndexClusters clstrs) {
            int[][] clusters = clstrs.clusters();
            for (int i = 0; i < clusters.length; i++) {
                int[] cluster = clusters[i];
                for (int j = 0; j < cluster.length; j++) {
                    cluster[j] = indexCorrection.get(cluster[j]);
                }
            }
        }
    }

    @Override
    public IndexClusters cluster(SparseMatrix data) {
        // Clamp the window locally so clustering a small matrix does not shrink the configured window.
        int window = Math.min(this.window, data.rowCount());
        SparseMatrix seen = MatlibMatrixUtils.subMatrix(data, 0, window, 0, window);
        int seenrows = window;
        TIntSet inactiveRows = new TIntHashSet(window);
        logger.debug("First clustering!: " + seen.rowCount() + "x" + seen.columnCount());
        IndexClusters oldClusters = clusterer.cluster(seen);
        logger.debug("First clusters:\n" + oldClusters);
        List<int[]> completedClusters = new ArrayList<int[]>();
        while (seenrows < data.rowCount()) {
            int nextwindow = seenrows + window;
            if (nextwindow >= data.rowCount()) nextwindow = data.rowCount();
            if (this.maxwindow > 0 && nextwindow - inactiveRows.size() > this.maxwindow) {
                logger.debug(String.format("Window size (%d) without inactive (%d) = (%d), greater than maximum (%d)", nextwindow, inactiveRows.size(), nextwindow - inactiveRows.size(), this.maxwindow));
                deactivateOldItemsAsNoise(nextwindow, inactiveRows, completedClusters);
            }
            WindowedSparseMatrix wsp = new WindowedSparseMatrix(data, nextwindow, inactiveRows);
            logger.debug("Clustering: " + wsp.window.rowCount() + "x" + wsp.window.columnCount());
            IndexClusters newClusters = clusterer.cluster(wsp.window);
            wsp.correctClusters(newClusters);
            logger.debug("New clusters:\n" + newClusters);
            // if stability == 1 for any cluster, it was the same last window; those items are excluded from the next round
            detectInactive(oldClusters, newClusters, inactiveRows, completedClusters);

            oldClusters = newClusters;
            seenrows += window;
            logger.debug("Seen rows: " + seenrows);
            logger.debug("Inactive rows: " + inactiveRows.size());
        }
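        // Clusters still active after the final window have had no further chance to stabilise;
        // emit them as-is alongside the clusters completed during the incremental rounds.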
        for (int i = 0; i < oldClusters.clusters().length; i++) {
            int[] cluster = oldClusters.clusters()[i];
            if (cluster.length != 0)
                completedClusters.add(cluster);
        }

        return new IndexClusters(completedClusters);
    }

    private void deactivateOldItemsAsNoise(int nextwindow, TIntSet inactiveRows, List<int[]> completedClusters) {
        int toDeactivate = 0;
        while (nextwindow - inactiveRows.size() > this.maxwindow) {
            if (!inactiveRows.contains(toDeactivate)) {
                logger.debug("Forcing the deactivation of: " + toDeactivate);
                inactiveRows.add(toDeactivate);
                completedClusters.add(new int[]{toDeactivate});
            }
            toDeactivate++;
        }
    }

    /**
     * Given the old and new clusters, decide which rows are now inactive,
     * and therefore which clusters are now completed.
     * @param oldClusters the clusters from the previous round
     * @param newClusters the clusters from the current round
     * @param inactiveRows the rows to exclude from future rounds; updated in place
     * @param completedClusters the clusters considered finished; updated in place
     */
    protected void detectInactive(IndexClusters oldClusters, IndexClusters newClusters, TIntSet inactiveRows, List<int[]> completedClusters) {
        Map<Integer, IntDoublePair> stability = calculateStability(oldClusters, newClusters, inactiveRows);
        for (Entry<Integer, IntDoublePair> e : stability.entrySet()) {
            if (e.getValue().second >= threshold) {
                int[] completedCluster = oldClusters.clusters()[e.getKey()];
                inactiveRows.addAll(completedCluster);
                completedClusters.add(completedCluster);
                if (threshold == 1) {
                    newClusters.clusters()[e.getValue().first] = new int[0];
                }
            }
        }
    }

    /**
     * For each cluster in {@code c1}, find the best-matching cluster in {@code c2} and the
     * f1-score of that match, ignoring rows that are already inactive. Singleton clusters
     * are compared directly by identity rather than by f1-score.
     * @param c1 the clusters from the previous round
     * @param c2 the clusters from the current round
     * @param inactiveRows rows excluded from the comparison
     * @return a map from cluster index in {@code c1} to the index of the best match in {@code c2} and its score
     */
    protected Map<Integer, IntDoublePair> calculateStability(IndexClusters c1, IndexClusters c2, TIntSet inactiveRows) {
        Map<Integer, IntDoublePair> stability = new HashMap<Integer, IntDoublePair>();
        int[][] clusters1 = c1.clusters();
        int[][] clusters2 = c2.clusters();
        for (int i = 0; i < clusters1.length; i++) {
            if (clusters1[i].length == 0) continue;
            double maxscore = 0;
            int maxj = -1;
            // strip inactive rows from the old cluster before comparison
            TIntArrayList cluster = new TIntArrayList(clusters1[i].length);
            for (int j = 0; j < clusters1[i].length; j++) {
                if (inactiveRows.contains(clusters1[i][j]))
                    continue;
                cluster.add(clusters1[i][j]);
            }
            int[][] correct = new int[][]{cluster.toArray()};
            for (int j = 0; j < clusters2.length; j++) {
                int[][] estimated = new int[][]{clusters2[j]};
                double score = 0;
                if (correct[0].length == 1 && estimated[0].length == 1) {
                    // both singletons: either they are the same item or not
                    score = correct[0][0] == estimated[0][0] ? 1 : 0;
                } else {
                    score = new FScoreClusterAnalyser().analyse(correct, estimated).score();
                }
                if (!Double.isNaN(score) && score > maxscore) {
                    maxscore = score;
                    maxj = j;
                }
            }
            stability.put(i, IntDoublePair.pair(maxj, maxscore));
        }
        logger.debug(String.format("The stability is:\n%s", stability));
        return stability;
    }

    @Override
    public int[][] performClustering(SparseMatrix data) {
        return this.cluster(data).clusters();
    }
}