001package ch.akuhn.matrix;
002
003import java.util.Arrays;
004
005/**
006 * Matrix where <CODE>a<SUB>ij</SUB> = a<SUB>ji</SUB></CODE> for all elements.
007 * <P>
008 * 
009 * @author Adrian Kuhn
010 * 
011 */
012public class SymmetricMatrix extends DenseMatrix {
013
014        /**
015         * Construct with given size
016         * 
017         * @param size
018         */
019        public SymmetricMatrix(int size) {
020                super(size, size);
021        }
022
023        /**
024         * Construct with given values, which must be jagged and represent the lower
025         * triangular values
026         * 
027         * @param values
028         */
029        public SymmetricMatrix(double[][] values) {
030                super(values);
031        }
032
033        @Override
034        protected void assertInvariant() throws IllegalArgumentException {
035                for (int n = 0; n < values.length; n++) {
036                        if (values[n].length != (n + 1))
037                                throw new IllegalArgumentException();
038                }
039        }
040
041        @Override
042        protected double[][] makeValues(int rows, int columns) {
043                assert rows == columns;
044                final double[][] values = new double[rows][];
045                for (int n = 0; n < values.length; n++)
046                        values[n] = new double[n + 1];
047                return values;
048        }
049
050        @Override
051        public int columnCount() {
052                return rowCount();
053        }
054
055        @Override
056        public double get(int row, int column) {
057                return row > column ? values[row][column] : values[column][row];
058        }
059
060        @Override
061        public double put(int row, int column, double value) {
062                return row > column ? (values[row][column] = value) : (values[column][row] = value);
063        }
064
065        @Override
066        public int rowCount() {
067                return values.length;
068        }
069
070        /**
071         * Create from a square matrix
072         * 
073         * @param square
074         * @return the matrix
075         */
076        public static DenseMatrix fromSquare(double[][] square) {
077                final double[][] jagged = new double[square.length][];
078                for (int i = 0; i < jagged.length; i++) {
079                        assert square[i].length == square.length;
080                        jagged[i] = Arrays.copyOf(square[i], i + 1);
081                }
082                return new SymmetricMatrix(jagged);
083        }
084
085        /**
086         * Create from jagged low triangular values
087         * 
088         * @param values
089         * @return the matrix
090         */
091        public static DenseMatrix fromJagged(double[][] values) {
092                return new SymmetricMatrix(values);
093        }
094
095        @Override
096        public double[][] unwrap() {
097                return values;
098        }
099
100        @Override
101        public double[] rowwiseMean() {
102                final double[] mean = new double[rowCount()];
103                for (int i = 0; i < values.length; i++) {
104                        for (int j = 0; j < i; j++) {
105                                mean[i] += values[i][j];
106                                mean[j] += values[i][j];
107                        }
108                }
109                for (int n = 0; n < mean.length; n++)
110                        mean[n] /= mean.length;
111                return mean;
112        }
113
114        @Override
115        public Vector mult(Vector v) {
116                assert v.size() == values.length;
117                final double[] mult = new double[v.size()];
118                for (int i = 0; i < values.length; i++) {
119                        for (int j = 0; j < i; j++) {
120                                mult[i] += values[i][j] * v.get(j);
121                                mult[j] += values[i][j] * v.get(i);
122                        }
123                }
124                return Vector.wrap(mult);
125        }
126
127}