001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.math.matrix.similarity; 031 032import java.io.DataInput; 033import java.io.DataOutput; 034import java.io.IOException; 035import java.io.PrintWriter; 036import java.util.Arrays; 037import java.util.Scanner; 038 039import org.jgrapht.UndirectedGraph; 040import org.jgrapht.graph.DefaultEdge; 041import org.jgrapht.graph.SimpleGraph; 042import org.openimaj.io.ReadWriteable; 043import org.openimaj.math.matrix.ReadWriteableMatrix; 044import org.openimaj.math.matrix.similarity.processor.SimilarityMatrixProcessor; 045 046import Jama.Matrix; 047 048/** 049 * A similarity matrix is a square matrix with an associated index. 050 * It can be used to store all the similarities across a set 051 * of objects. 052 * 053 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 054 */ 055public class SimilarityMatrix extends ReadWriteableMatrix implements ReadWriteable { 056 private static final long serialVersionUID = 1L; 057 058 protected String[] index; 059 060 /** 061 * Construct an empty similarity matrix. Only for IOUtils use. 062 */ 063 protected SimilarityMatrix() { 064 super(); 065 } 066 067 /** 068 * Construct a similarity matrix with the given size 069 * and allocate the index accordingly. 070 * @param size the size of the matrix 071 */ 072 public SimilarityMatrix(int size) { 073 super(size, size); 074 index = new String[size]; 075 } 076 077 /** 078 * Construct a similarity matrix with the given index 079 * and set the matrix size based on the index length. 080 * @param index the index. 081 */ 082 public SimilarityMatrix(String [] index) { 083 super(index.length, index.length); 084 this.index = index; 085 } 086 087 /** 088 * Construct a similarity matrix based on the given index 089 * and matrix. The matrix must be square and its dimensions 090 * must be the same as the index length. 091 * 092 * @param index the index 093 * @param data the matrix 094 */ 095 public SimilarityMatrix(String [] index, Matrix data) { 096 super(data); 097 098 if (data.getColumnDimension() != data.getRowDimension()) 099 throw new IllegalArgumentException("matrix must be square"); 100 101 if (index.length != data.getRowDimension()) 102 throw new IllegalArgumentException("index must have same length as matrix sides"); 103 104 this.index = index; 105 } 106 107 /** 108 * Construct a similarity matrix based on the given index 109 * and matrix data. The matrix data must be square and its dimensions 110 * must be the same as the index length. 111 * 112 * @param index the index 113 * @param data the matrix data 114 */ 115 public SimilarityMatrix(String [] index, double[][] data) { 116 super(data); 117 118 if (index.length != this.getRowDimension()) 119 throw new IllegalArgumentException("index must have same length as matrix sides"); 120 121 this.index = index; 122 } 123 124 /** 125 * Get the offset in the index for a given value 126 * @param value the value 127 * @return the index 128 */ 129 public int indexOf(String value) { 130 return Arrays.binarySearch(index, value); 131 } 132 133 /** 134 * Set the value of the index at a given offset 135 * @param i the offset 136 * @param value the value 137 */ 138 public void setIndexValue(int i, String value) { 139 index[i] = value; 140 } 141 142 /** 143 * Get a value from the index 144 * @param i the offset into the index 145 * @return the value 146 */ 147 public String getIndexValue(int i) { 148 return index[i]; 149 } 150 151 /** 152 * Get the index 153 * @return the index 154 */ 155 public String [] getIndex() { 156 return index; 157 } 158 159 @Override 160 public void readASCII(Scanner in) throws IOException { 161 super.readASCII(in); 162 163 index = new String[this.getRowDimension()]; 164 165 for (int i=0; i<index.length; i++) 166 index[i] = in.nextLine(); 167 } 168 169 @Override 170 public String asciiHeader() { 171 return this.getClass().getName() + " "; 172 } 173 174 @Override 175 public void readBinary(DataInput in) throws IOException { 176 super.readBinary(in); 177 178 index = new String[this.getRowDimension()]; 179 180 for (int i=0; i<index.length; i++) 181 index[i] = in.readUTF(); 182 } 183 184 @Override 185 public byte[] binaryHeader() { 186 return "SimMat".getBytes(); 187 } 188 189 @Override 190 public void writeASCII(PrintWriter out) throws IOException { 191 super.writeASCII(out); 192 193 for (String s : index) 194 out.println(s); 195 } 196 197 @Override 198 public void writeBinary(DataOutput out) throws IOException { 199 super.writeBinary(out); 200 201 for (String s : index) 202 out.writeUTF(s); 203 } 204 205 /** 206 * Convert the similarity matrix to an unweighted, undirected 207 * graph representation. A threshold is used to determine 208 * if edges should be created. If the value at [r][c] is bigger 209 * than the threshold, then an edge will be created between the 210 * vertices represented by index[r] and index[c]. 211 * 212 * @param threshold the threshold 213 * @return the graph 214 */ 215 public UndirectedGraph<String, DefaultEdge> toUndirectedUnweightedGraph(double threshold) { 216 UndirectedGraph<String, DefaultEdge> graph = new SimpleGraph<String, DefaultEdge>(DefaultEdge.class); 217 218 final int rows = this.getRowDimension(); 219 final int cols = this.getColumnDimension(); 220 final double[][] data = this.getArray(); 221 222 for (String s : index) { 223 graph.addVertex(s); 224 } 225 226 for (int r=0; r<rows; r++) { 227 for (int c=0; c<cols; c++) { 228 if (r != c && data[r][c] > threshold) 229 graph.addEdge(index[r], index[c]); 230 } 231 } 232 233 return graph; 234 } 235 236 @Override 237 public SimilarityMatrix copy() { 238 double[][] C = this.getArrayCopy(); 239 String[] i = Arrays.copyOf(index, index.length); 240 241 return new SimilarityMatrix(i, C); 242 } 243 244 @Override 245 public SimilarityMatrix clone() { 246 return copy(); 247 } 248 249 /** 250 * Process a copy of this similarity matrix with the 251 * given processor and return the copy. 252 * 253 * @param proc the processor 254 * @return a processed copy of this matrix 255 */ 256 public SimilarityMatrix process(SimilarityMatrixProcessor proc) { 257 SimilarityMatrix mat = this.clone(); 258 proc.process(mat); 259 return mat; 260 } 261 262 /** 263 * Process this matrix with the given processor. 264 * @param proc the processor 265 * @return this. 266 */ 267 public SimilarityMatrix processInplace(SimilarityMatrixProcessor proc) { 268 proc.process(this); 269 return this; 270 } 271 272 @Override 273 public String toString() { 274 StringBuilder sb = new StringBuilder(); 275 276 int maxIndexLength = 0; 277 for (String s : index) 278 if (s.length() > maxIndexLength) 279 maxIndexLength = s.length(); 280 281 final int maxIndexCountLength = (index.length + "").length(); 282 final String indexFormatString = "%"+(maxIndexCountLength+2)+"s %" + maxIndexLength + "s "; 283 284 final int rows = this.getRowDimension(); 285 final int cols = this.getColumnDimension(); 286 final double[][] data = this.getArray(); 287 288 sb.append(String.format("%"+(maxIndexLength+maxIndexCountLength+3)+"s", "")); 289 for (int r=0; r<rows; r++) { 290 sb.append(String.format("%9s", String.format("(%d)", r))); 291 } 292 sb.append("\n"); 293 294 for (int r=0; r<rows; r++) { 295 sb.append(String.format(indexFormatString, String.format("(%d)", r), index[r])); 296 297 for (int c=0; c<cols; c++) { 298 sb.append(String.format("%8.3f ", data[r][c])); 299 } 300 sb.append("\n"); 301 } 302 303 return sb.toString(); 304 } 305}