001package ch.akuhn.matrix;
002
003import java.io.File;
004import java.io.FileWriter;
005import java.io.IOException;
006import java.io.PrintWriter;
007import java.io.StringWriter;
008import java.io.Writer;
009import java.util.Arrays;
010import java.util.Iterator;
011import java.util.NoSuchElementException;
012
013import ch.akuhn.matrix.Vector.Entry;
014
015/**
016 * Two-dimensional table of floating point numbers.
017 * <P>
018 * 
019 * @author Adrian Kuhn
020 * 
021 */
022public abstract class Matrix {
023
024        private static final int MAX_PRINT = 100;
025
026        /**
027         * Add to the value at the given row/column
028         * 
029         * @param row
030         * @param column
031         * @param value
032         * @return the new value
033         */
034        public double add(int row, int column, double value) {
035                return put(row, column, get(row, column) + value);
036        }
037
038        /**
039         * Get an {@link Iterable} over the rows
040         * 
041         * @return an {@link Iterable} over the rows
042         */
043        public Iterable<Vector> rows() {
044                return vecs(/* isRow */true);
045        }
046
047        private Iterable<Vector> vecs(final boolean isRow) {
048                return new Iterable<Vector>() {
049                        @Override
050                        public Iterator<Vector> iterator() {
051                                return new Iterator<Vector>() {
052
053                                        private int count = 0;
054
055                                        @Override
056                                        public boolean hasNext() {
057                                                return count < (isRow ? rowCount() : columnCount());
058                                        }
059
060                                        @Override
061                                        public Vector next() {
062                                                if (!hasNext())
063                                                        throw new NoSuchElementException();
064                                                return new Vec(count++, isRow);
065                                        }
066
067                                        @Override
068                                        public void remove() {
069                                                throw new UnsupportedOperationException();
070                                        }
071                                };
072                        }
073                };
074        }
075
076        /**
077         * Get an {@link Iterable} over the columns
078         * 
079         * @return an {@link Iterable} over the columns
080         */
081        public Iterable<Vector> columns() {
082                return vecs(/* isRow */false);
083        }
084
085        /**
086         * Get the number of columns
087         * 
088         * @return the number of columns
089         */
090        public abstract int columnCount();
091
092        /**
093         * Get the density
094         * 
095         * @return the density
096         */
097        public double density() {
098                return (double) used() / elementCount();
099        }
100
101        /**
102         * @return the number of elements
103         */
104        public int elementCount() {
105                return rowCount() * columnCount();
106        }
107
108        /**
109         * @param row
110         * @param column
111         * @return the value at the given row and column
112         */
113        public abstract double get(int row, int column);
114
115        /**
116         * Set the value at the given row/column
117         * 
118         * @param row
119         * @param column
120         * @param value
121         * @return the value being set
122         */
123        public abstract double put(int row, int column, double value);
124
125        /**
126         * @return the number of rows
127         */
128        public abstract int rowCount();
129
130        /**
131         * @return the number of non-zero elements
132         */
133        public abstract int used();
134
135        /**
136         * I/O
137         * 
138         * @param appendable
139         * @throws IOException
140         * @see "http://tedlab.mit.edu/~dr/svdlibc/SVD_F_ST.html"
141         */
142        public void storeSparseOn(Appendable appendable) throws IOException {
143                // this stores the transposed matrix, but as we will transpose it again
144                // when reading it, this can be done without loss of generality.
145                appendable.append(this.columnCount() + " ");
146                appendable.append(this.rowCount() + " ");
147                appendable.append(this.used() + "\r");
148                for (final Vector row : rows()) {
149                        appendable.append(row.used() + "\r");
150                        for (final Entry each : row.entries()) {
151                                appendable.append(each.index + " " + each.value + " ");
152                        }
153                        appendable.append("\r");
154                }
155        }
156
157        /**
158         * Write to file
159         * 
160         * @param fname
161         *            filename
162         * @throws IOException
163         */
164        public void storeSparseOn(String fname) throws IOException {
165                final FileWriter fw = new FileWriter(new File(fname));
166                storeSparseOn(fw);
167                fw.close();
168        }
169
170        /**
171         * Get the given row as a vector
172         * 
173         * @param row
174         * @return the row
175         */
176        public Vector row(int row) {
177                return new Vec(row, /* isRow */true);
178        }
179
180        /**
181         * Get the given column as a vector
182         * 
183         * @param column
184         * @return the column
185         */
186        public Vector column(int column) {
187                return new Vec(column, /* isRow */false);
188        }
189
190        /**
191         * Get the matrix data as a 2D dense array
192         * 
193         * @return the array representation
194         */
195        public double[][] asArray() {
196                final double[][] result = new double[rowCount()][columnCount()];
197                for (int x = 0; x < result.length; x++) {
198                        for (int y = 0; y < result[x].length; y++) {
199                                result[x][y] = get(x, y);
200                        }
201                }
202                return result;
203        }
204
205        /**
206         * Get the index of the given vector
207         * 
208         * @param vec
209         * @return the index
210         */
211        public static int indexOf(Vector vec) {
212                return ((Vec) vec).index0;
213        }
214
215        private class Vec extends Vector {
216
217                int index0;
218                private boolean isRow;
219
220                Vec(int n, boolean isRow) {
221                        this.isRow = isRow;
222                        this.index0 = n;
223                }
224
225                @Override
226                public int size() {
227                        return isRow ? columnCount() : rowCount();
228                }
229
230                @Override
231                public double put(int index, double value) {
232                        return isRow ? Matrix.this.put(this.index0, index, value)
233                                        : Matrix.this.put(index, this.index0, value);
234                }
235
236                @Override
237                public double get(int index) {
238                        return isRow ? Matrix.this.get(this.index0, index)
239                                        : Matrix.this.get(index, this.index0);
240                }
241
242                @Override
243                public boolean equals(Vector v, double epsilon) {
244                        throw new Error("Not yet implemented");
245                }
246
247                @Override
248                public Vector times(double scalar) {
249                        throw new Error("Not yet implemented");
250                }
251
252                @Override
253                public Vector timesEquals(double scalar) {
254                        throw new Error("Not yet implemented");
255                }
256        }
257
258        /**
259         * Returns <code>y = Ax</code>.
260         * 
261         * @param x
262         * @return the result
263         * 
264         */
265        public Vector mult(Vector x) {
266                assert x.size() == this.columnCount();
267                final Vector y = Vector.dense(this.rowCount());
268                int i = 0;
269                for (final Vector row : rows())
270                        y.put(i++, row.dot(x));
271                return y;
272        }
273
274        /**
275         * Returns <code>y = (A^T)x</code>.
276         * 
277         * @param x
278         * @return the result
279         */
280        public Vector transposeMultiply(Vector x) {
281                assert x.size() == this.rowCount();
282                final Vector y = Vector.dense(this.columnCount());
283                int i = 0;
284                for (final Vector row : rows())
285                        row.scaleAndAddTo(x.get(i++), y);
286                return y;
287        }
288
289        /**
290         * Returns <code>y = (A^T)Ax</code>.
291         * <P>
292         * Useful for doing singular decomposition using ARPACK's dsaupd routine.
293         * 
294         * @param x
295         * @return the result
296         */
297        public Vector transposeNonTransposeMultiply(Vector x) {
298                return this.transposeMultiply(this.mult(x));
299        }
300
301        /**
302         * Build a matrix from the given values (row-major)
303         * 
304         * @param n
305         * @param m
306         * @param values
307         * @return the matrix
308         */
309        public static Matrix from(int n, int m, double... values) {
310                assert n * m == values.length;
311                final double[][] data = new double[n][];
312                for (int i = 0; i < n; i++)
313                        data[i] = Arrays.copyOfRange(values, i * m, (i + 1) * m);
314                return new DenseMatrix(data);
315        }
316
317        /**
318         * Create a zeroed dense matrix
319         * 
320         * @param n
321         * @param m
322         * @return the matrix
323         */
324        public static Matrix dense(int n, int m) {
325                return new DenseMatrix(n, m);
326        }
327
328        /**
329         * @return true of matrix is square; false otherwise
330         */
331        public boolean isSquare() {
332                return columnCount() == rowCount();
333        }
334
335        /**
336         * Get in col-major format
337         * 
338         * @return the data in column major format
339         */
340        public double[] asColumnMajorArray() {
341                final double[] data = new double[columnCount() * rowCount()];
342                final int n = columnCount();
343                int i = 0;
344                for (final Vector row : rows()) {
345                        for (final Entry each : row.entries()) {
346                                data[i + each.index * n] = each.value;
347                        }
348                        i++;
349                }
350                return data;
351        }
352
353        /**
354         * Create a sparse matrix
355         * 
356         * @param n
357         * @param m
358         * @return new sparse matrix
359         */
360        public static SparseMatrix sparse(int n, int m) {
361                return new SparseMatrix(n, m);
362        }
363
364        /**
365         * @return max value in matrix
366         */
367        public double max() {
368                return Util.max(this.unwrap(), Double.NaN);
369        }
370
371        /**
372         * @return min value in matrix
373         */
374        public double min() {
375                return Util.min(this.unwrap(), Double.NaN);
376        }
377
378        /**
379         * @return mean value of matrix
380         */
381        public double mean() {
382                final double[][] values = unwrap();
383                return Util.sum(values) / Util.count(values);
384        }
385
386        /**
387         * @return unwrapped matrix
388         */
389        public double[][] unwrap() {
390                throw new IllegalStateException("cannot unwrap instance of " + this.getClass().getSimpleName());
391        }
392
393        /**
394         * @return mean of each row
395         */
396        public double[] rowwiseMean() {
397                final double[] mean = new double[rowCount()];
398                int i = 0;
399                for (final Vector row : rows())
400                        mean[i++] = row.mean();
401                return mean;
402        }
403
404        /**
405         * @return the histogram
406         */
407        public int[] getHistogram() {
408                return Util.getHistogram(this.unwrap(), 100);
409        }
410
411        /**
412         * @return an empty instance of this matrix type
413         */
414        public Matrix newInstance() {
415                return newInstance(rowCount(), columnCount());
416        }
417
418        /**
419         * @param rows
420         * @param cols
421         * @return an empty instance of this matrix type
422         */
423        public abstract Matrix newInstance(int rows, int cols);
424
425        @Override
426        public String toString() {
427                final Writer sw = new StringWriter();
428                final PrintWriter writer = new PrintWriter(sw);
429                writer.println("NRows = " + rowCount());
430                writer.println("NCols = " + columnCount());
431                final int maxPrint = Math.min(rowCount() * columnCount(), MAX_PRINT);
432                int i;
433                for (i = 0; i < maxPrint; i++) {
434                        final int row = i / columnCount();
435                        final int col = i - (row * columnCount());
436                        writer.printf("%d\t%d\t%2.5f\n", row, col, this.get(row, col));
437                }
438                if (i < rowCount() * columnCount() - 1) {
439                        writer.printf("...");
440                }
441                writer.flush();
442                return sw.toString();
443        }
444}