001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.math.matrix.similarity;
031
032import java.io.DataInput;
033import java.io.DataOutput;
034import java.io.IOException;
035import java.io.PrintWriter;
036import java.util.Arrays;
037import java.util.Scanner;
038
039import org.jgrapht.UndirectedGraph;
040import org.jgrapht.graph.DefaultEdge;
041import org.jgrapht.graph.SimpleGraph;
042import org.openimaj.io.ReadWriteable;
043import org.openimaj.math.matrix.ReadWriteableMatrix;
044import org.openimaj.math.matrix.similarity.processor.SimilarityMatrixProcessor;
045
046import Jama.Matrix;
047
048/**
049 * A similarity matrix is a square matrix with an associated index.
050 * It can be used to store all the similarities across a set
051 * of objects.
052 * 
053 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
054 */
055public class SimilarityMatrix extends ReadWriteableMatrix implements ReadWriteable {
056        private static final long serialVersionUID = 1L;
057
058        protected String[] index;
059
060        /**
061         * Construct an empty similarity matrix. Only for IOUtils use.
062         */
063        protected SimilarityMatrix() {
064                super();
065        }
066
067        /**
068         * Construct a similarity matrix with the given size
069         * and allocate the index accordingly.
070         * @param size the size of the matrix
071         */
072        public SimilarityMatrix(int size) {
073                super(size, size);
074                index = new String[size];
075        }
076
077        /**
078         * Construct a similarity matrix with the given index
079         * and set the matrix size based on the index length.
080         * @param index the index.
081         */
082        public SimilarityMatrix(String [] index) {
083                super(index.length, index.length);
084                this.index = index;
085        }
086
087        /**
088         * Construct a similarity matrix based on the given index
089         * and matrix. The matrix must be square and its dimensions
090         * must be the same as the index length.
091         * 
092         * @param index the index
093         * @param data the matrix
094         */
095        public SimilarityMatrix(String [] index, Matrix data) {
096                super(data);
097
098                if (data.getColumnDimension() != data.getRowDimension())
099                        throw new IllegalArgumentException("matrix must be square");
100
101                if (index.length != data.getRowDimension())
102                        throw new IllegalArgumentException("index must have same length as matrix sides");
103
104                this.index = index;
105        }
106        
107        /**
108         * Construct a similarity matrix based on the given index
109         * and matrix data. The matrix data must be square and its dimensions
110         * must be the same as the index length.
111         * 
112         * @param index the index
113         * @param data the matrix data
114         */
115        public SimilarityMatrix(String [] index, double[][] data) {
116                super(data);
117
118                if (index.length != this.getRowDimension())
119                        throw new IllegalArgumentException("index must have same length as matrix sides");
120
121                this.index = index;
122        }
123
124        /**
125         * Get the offset in the index for a given value 
126         * @param value the value
127         * @return the index
128         */
129        public int indexOf(String value) {
130                return Arrays.binarySearch(index, value);
131        }
132
133        /**
134         * Set the value of the index at a given offset
135         * @param i the offset
136         * @param value the value
137         */
138        public void setIndexValue(int i, String value) {
139                index[i] = value;
140        }
141
142        /**
143         * Get a value from the index
144         * @param i the offset into the index
145         * @return the value
146         */
147        public String getIndexValue(int i) {
148                return index[i];
149        }
150
151        /**
152         * Get the index
153         * @return the index
154         */
155        public String [] getIndex() {
156                return index;
157        }
158
159        @Override
160        public void readASCII(Scanner in) throws IOException {
161                super.readASCII(in);
162
163                index = new String[this.getRowDimension()];
164
165                for (int i=0; i<index.length; i++)
166                        index[i] = in.nextLine();
167        }
168
169        @Override
170        public String asciiHeader() {
171                return this.getClass().getName() + " ";
172        }
173
174        @Override
175        public void readBinary(DataInput in) throws IOException {
176                super.readBinary(in);
177
178                index = new String[this.getRowDimension()];
179
180                for (int i=0; i<index.length; i++)
181                        index[i] = in.readUTF();
182        }
183
184        @Override
185        public byte[] binaryHeader() {
186                return "SimMat".getBytes();
187        }
188
189        @Override
190        public void writeASCII(PrintWriter out) throws IOException {
191                super.writeASCII(out);
192
193                for (String s : index)
194                        out.println(s);
195        }
196
197        @Override
198        public void writeBinary(DataOutput out) throws IOException {
199                super.writeBinary(out);
200
201                for (String s : index)
202                        out.writeUTF(s);
203        }
204
205        /**
206         * Convert the similarity matrix to an unweighted, undirected
207         * graph representation. A threshold is used to determine
208         * if edges should be created. If the value at [r][c] is bigger
209         * than the threshold, then an edge will be created between the
210         * vertices represented by index[r] and index[c].
211         * 
212         * @param threshold the threshold
213         * @return the graph
214         */
215        public UndirectedGraph<String, DefaultEdge> toUndirectedUnweightedGraph(double threshold) {
216                UndirectedGraph<String, DefaultEdge> graph = new SimpleGraph<String, DefaultEdge>(DefaultEdge.class);
217
218                final int rows = this.getRowDimension();
219                final int cols = this.getColumnDimension();
220                final double[][] data = this.getArray(); 
221
222                for (String s : index) {
223                        graph.addVertex(s);
224                }
225
226                for (int r=0; r<rows; r++) {
227                        for (int c=0; c<cols; c++) {
228                                if (r != c && data[r][c] > threshold)
229                                        graph.addEdge(index[r], index[c]);
230                        }
231                }
232
233                return graph;
234        }
235
236        @Override   
237        public SimilarityMatrix copy() {
238                double[][] C = this.getArrayCopy();
239                String[] i = Arrays.copyOf(index, index.length);
240                
241                return new SimilarityMatrix(i, C);
242        }
243        
244        @Override   
245        public SimilarityMatrix clone() {
246                return copy();
247        }
248
249        /**
250         * Process a copy of this similarity matrix with the
251         * given processor and return the copy.
252         * 
253         * @param proc the processor
254         * @return a processed copy of this matrix
255         */
256        public SimilarityMatrix process(SimilarityMatrixProcessor proc) {
257                SimilarityMatrix mat = this.clone();
258                proc.process(mat);
259                return mat;
260        }
261
262        /**
263         * Process this matrix with the given processor.
264         * @param proc the processor
265         * @return this.
266         */
267        public SimilarityMatrix processInplace(SimilarityMatrixProcessor proc) {
268                proc.process(this);
269                return this;
270        }
271
272        @Override
273        public String toString() {
274                StringBuilder sb = new StringBuilder();
275
276                int maxIndexLength = 0;
277                for (String s : index) 
278                        if (s.length() > maxIndexLength) 
279                                maxIndexLength = s.length();
280
281                final int maxIndexCountLength = (index.length + "").length();
282                final String indexFormatString = "%"+(maxIndexCountLength+2)+"s %" + maxIndexLength + "s ";  
283
284                final int rows = this.getRowDimension();
285                final int cols = this.getColumnDimension();
286                final double[][] data = this.getArray(); 
287
288                sb.append(String.format("%"+(maxIndexLength+maxIndexCountLength+3)+"s", ""));
289                for (int r=0; r<rows; r++) {
290                        sb.append(String.format("%9s", String.format("(%d)", r)));
291                }
292                sb.append("\n");
293
294                for (int r=0; r<rows; r++) {
295                        sb.append(String.format(indexFormatString, String.format("(%d)", r), index[r]));
296
297                        for (int c=0; c<cols; c++) {
298                                sb.append(String.format("%8.3f ", data[r][c]));
299                        }
300                        sb.append("\n");
301                }
302
303                return sb.toString();
304        }
305}