001package ch.akuhn.matrix;
002
003import java.util.ArrayList;
004import java.util.Collections;
005import java.util.List;
006import java.util.Random;
007import java.util.Scanner;
008
009import ch.akuhn.matrix.Vector.Entry;
010
011/**
012 * A sparse matrix
013 * 
014 * @author Adrian Kuhn
015 */
016public class SparseMatrix extends Matrix {
017
018        private int columns;
019
020        private List<Vector> rows;
021
022        /**
023         * Construct with the given values
024         * 
025         * @param values
026         */
027        public SparseMatrix(double[][] values) {
028                this.columns = values[0].length;
029                this.rows = new ArrayList<Vector>(values.length);
030                for (final double[] each : values)
031                        addRow(each);
032        }
033
034        /**
035         * Construct with the given size
036         * 
037         * @param rows
038         * @param columns
039         */
040        public SparseMatrix(int rows, int columns) {
041                this.columns = columns;
042                this.rows = new ArrayList<Vector>(rows);
043                for (int times = 0; times < rows; times++)
044                        addRow();
045        }
046
047        @Override
048        public double add(int row, int column, double sum) {
049                return rows.get(row).add(column, sum);
050        }
051
052        /**
053         * Add a new column to the end, increasing the number of columns by 1
054         * 
055         * @return number of cols BEFORE new one was added
056         */
057        public int addColumn() {
058                columns++;
059                for (final Vector each : rows)
060                        ((SparseVector) each).resizeTo(columns);
061                return columns - 1;
062        }
063
064        /**
065         * Add a new row to the end, increasing the number of rows by 1
066         * 
067         * @return number of rows BEFORE new one was added
068         */
069        public int addRow() {
070                rows.add(new SparseVector(columns));
071                return rowCount() - 1;
072        }
073
074        protected int addRow(double[] values) {
075                rows.add(new SparseVector(values));
076                return rowCount() - 1;
077        }
078
079        /**
080         * Add the given values to the given row
081         * 
082         * @param row
083         * @param values
084         */
085        public void addToRow(int row, Vector values) {
086                final Vector v = rows.get(row);
087                for (final Entry each : values.entries())
088                        v.add(each.index, each.value);
089        }
090
091        /**
092         * Convert to a dense 2d double array
093         * 
094         * @return 2d double array
095         */
096        public double[][] asDenseDoubleDouble() {
097                final double[][] dense = new double[rowCount()][columnCount()];
098
099                for (int ri = 0; ri < rows.size(); ri++) {
100                        final Vector row = rows.get(ri);
101
102                        for (final Entry column : row.entries()) {
103                                dense[ri][column.index] = column.value;
104                        }
105                }
106                return dense;
107        }
108
109        @Override
110        public int columnCount() {
111                return columns;
112        }
113
114        @Override
115        public boolean equals(Object obj) {
116                return obj instanceof SparseMatrix && rows.equals(((SparseMatrix) obj).rows);
117        }
118
119        @Override
120        public double get(int row, int column) {
121                return rows.get(row).get(column);
122        }
123
124        @Override
125        public int hashCode() {
126                return rows.hashCode();
127        }
128
129        @Override
130        public double put(int row, int column, double value) {
131                return rows.get(row).put(column, value);
132        }
133
134        @Override
135        public Iterable<Vector> rows() {
136                return Collections.unmodifiableCollection(rows);
137        }
138
139        @Override
140        public Vector row(int row) {
141                return this.rows.get(row);
142        }
143
144        @Override
145        public int rowCount() {
146                return rows.size();
147        }
148
149        /**
150         * Sets the row, no check is made on {@link SparseVector#size()} Use with
151         * care.
152         * 
153         * @param row
154         * @param values
155         */
156        public void setRow(int row, SparseVector values) {
157                rows.set(row, values);
158        }
159
160        @Override
161        public int used() {
162                int used = 0;
163                for (final Vector each : rows)
164                        used += each.used();
165                return used;
166        }
167
168        /**
169         * Trim each row
170         */
171        public void trim() {
172                for (final Vector each : rows) {
173                        ((SparseVector) each).trim();
174                }
175        }
176
177        /**
178         * Read matrix from {@link Scanner}
179         * 
180         * @param scan
181         * @return the matrix
182         */
183        public static SparseMatrix readFrom(Scanner scan) {
184                final int columns = scan.nextInt();
185                final int rows = scan.nextInt();
186                final int used = scan.nextInt();
187                final SparseMatrix matrix = new SparseMatrix(rows, columns);
188                for (int row = 0; row < rows; row++) {
189                        final int len = scan.nextInt();
190                        for (int i = 0; i < len; i++) {
191                                final int column = scan.nextInt();
192                                final double value = scan.nextDouble();
193                                matrix.put(row, column, value);
194                        }
195                }
196                assert matrix.used() == used;
197                return matrix;
198        }
199
200        /**
201         * Create a random matrix
202         * 
203         * @param n
204         * @param m
205         * @param density
206         * @return the matrix
207         */
208        public static SparseMatrix random(int n, int m, double density) {
209                final Random random = new Random();
210                final SparseMatrix A = new SparseMatrix(n, m);
211                for (int i = 0; i < n; i++) {
212                        for (int j = 0; j < m; j++) {
213                                if (random.nextDouble() > density)
214                                        continue;
215                                A.put(i, j, random.nextDouble());
216                        }
217                }
218                return A;
219        }
220
221        @Override
222        public Vector mult(Vector dense) {
223                assert dense.size() == this.columnCount();
224                final double[] y = new double[this.rowCount()];
225                final double[] x = ((DenseVector) dense).values;
226                for (int i = 0; i < y.length; i++) {
227                        final SparseVector row = (SparseVector) rows.get(i);
228                        double sum = 0;
229                        for (int k = 0; k < row.used; k++) {
230                                sum += x[row.keys[k]] * row.values[k];
231                        }
232                        y[i] = sum;
233                }
234                return Vector.wrap(y);
235        }
236
237        @Override
238        public Vector transposeMultiply(Vector dense) {
239                assert dense.size() == this.rowCount();
240                final double[] y = new double[this.columnCount()];
241                final double[] x = ((DenseVector) dense).values;
242                for (int i = 0; i < x.length; i++) {
243                        final SparseVector row = (SparseVector) rows.get(i);
244                        for (int k = 0; k < row.used; k++) {
245                                y[row.keys[k]] += x[i] * row.values[k];
246                        }
247                }
248                return Vector.wrap(y);
249        }
250
251        @Override
252        public Matrix newInstance(int rows, int cols) {
253                return new SparseMatrix(rows, cols);
254        }
255}