/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   *  Redistributions of source code must retain the above copyright notice,
 *      this list of conditions and the following disclaimer.
 *
 *   *  Redistributions in binary form must reproduce the above copyright notice,
 *      this list of conditions and the following disclaimer in the documentation
 *      and/or other materials provided with the distribution.
 *
 *   *  Neither the name of the University of Southampton nor the names of its
 *      contributors may be used to endorse or promote products derived from this
 *      software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.regression;

import java.util.Arrays;
import java.util.List;

import no.uib.cipr.matrix.NotConvergedException;

import org.openimaj.math.matrix.MatrixUtils;
import org.openimaj.math.model.EstimatableModel;
import org.openimaj.util.pair.IndependentPair;

import Jama.Matrix;

/**
 * Given a set of independent variables, linear regression finds the weight
 * vector b that minimises the squared error (y - Xb)^{T}(y - Xb).
 * 
 * The minimiser is found by noting that this error is a convex function of b
 * (reasonable as the model is linear) and calculating the point at which its
 * first derivative is 0, i.e.:
 * 
 * d/db (y - Xb)^{T}(y - Xb) = -X^{T}(y - Xb) - X^{T}(y - Xb)
 *                           = -2 X^{T}(y - Xb)
 * 
 * which is 0 when:
 * 
 * -2 X^{T}(y - Xb) = 0
 * X^{T}y - X^{T}Xb = 0
 * X^{T}y = X^{T}Xb
 * b = (X^{T}X)^{-1} X^{T}y
 * 
 * Calculating this expression directly behaves badly numerically when X is
 * extremely skinny and tall (i.e. lots of data, few dimensions), so instead it
 * is computed using the SVD. The SVD decomposes X as:
 * 
 * X = UDV^{T}
 * 
 * s.t. U and V are orthonormal, from which we can calculate:
 * 
 * b = V D^{-1} U^{T} y
 * 
 * which is equivalent but more numerically stable.
 * 
 * Note that on input each vector of independent variables {x1, ..., xn} is
 * automatically turned into an (n + 1)-vector {1, x1, ..., xn}; the leading 1
 * accounts for the constant (intercept) term added to y.
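 * 
 * A minimal usage sketch (the two-point data below is hypothetical and purely
 * illustrative; the fitted model is y = 1 + 2x):
 * 
 * <pre>
 * final LinearRegression lr = new LinearRegression();
 * final double[][] x = { { 1 }, { 2 } };
 * final double[][] y = { { 3 }, { 5 } };
 * lr.estimate(y, x);
 * final double[] p = lr.predict(new double[] { 3 }); // approximately { 7 }
 * </pre>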
 * 
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 * 
 */
public class LinearRegression implements EstimatableModel<double[], double[]> {

        private Matrix weights;

        /**
         * Construct a linear regression model with no weights; the weights are
         * learned through a call to one of the estimate methods.
         */
        public LinearRegression() {
        }

        @Override
        public boolean estimate(List<? extends IndependentPair<double[], double[]>> data) {
                if (data.size() == 0)
                        return false;

                final int correctedx = data.get(0).firstObject().length + 1;
                final int correctedy = data.get(0).secondObject().length;
                final double[][] y = new double[data.size()][correctedy];
                final double[][] x = new double[data.size()][correctedx];

                int i = 0;
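                // copy each sample into the design matrix, prepending a constant 1 so
                // that the first learned weight acts as the intercept term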
                for (final IndependentPair<double[], double[]> item : data) {
                        y[i] = item.secondObject();
                        x[i][0] = 1;
                        System.arraycopy(item.firstObject(), 0, x[i], 1, item.firstObject().length);
                        i += 1;
                }

                estimate_internal(new Matrix(y), new Matrix(x));

                return true;
        }

        /**
         * As in {@link #estimate(List)} but using double arrays for efficiency.
         * 
         * @param yd
         *            the dependent values, one row per observation
         * @param xd
         *            the independent values, one row per observation
         */
        public void estimate(double[][] yd, double[][] xd) {
                final double[][] x = appendConstant(xd);
                estimate_internal(new Matrix(yd), new Matrix(x));
        }

        private double[][] appendConstant(double[][] xd) {
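                // prepend a column of ones to the data so the regression can learn a
                // constant (intercept) term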
                final int corrected = xd[0].length + 1;
                final double[][] x = new double[xd.length][corrected];

                for (int i = 0; i < xd.length; i++) {
                        x[i][0] = 1;
                        System.arraycopy(xd[i], 0, x[i], 1, xd[i].length);
                }
                return x;
        }

        /**
         * As in {@link #estimate(List)} but using matrices directly. Estimates:
         * b = V D^{-1} U^{T} y s.t. X = UDV^{T}
         * 
         * @param y
         *            the dependent values, one row per observation
         * @param x
         *            the independent values, one row per observation
         */
        public void estimate(Matrix y, Matrix x) {
                estimate(y.getArray(), x.getArray());
        }

        private void estimate_internal(Matrix y, Matrix x) {
                try {
                        final no.uib.cipr.matrix.DenseMatrix mjtX = new no.uib.cipr.matrix.DenseMatrix(x.getArray());
                        final no.uib.cipr.matrix.SVD svd = no.uib.cipr.matrix.SVD.factorize(mjtX);
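                        // convert the MTJ factors to Jama matrices, keeping only the thin
                        // (economy-size) parts of U and V that match the singular values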
                        final Matrix u = MatrixUtils.convert(svd.getU(), svd.getU().numRows(), svd.getS().length);
                        final Matrix v = MatrixUtils.convert(svd.getVt(), svd.getS().length, svd.getVt().numColumns()).transpose();
                        final Matrix d = MatrixUtils.diag(svd.getS());

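                        // b = V D^{+} U^{T} y; the pseudo-inverse of D guards against zero
                        // singular values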
                        weights = v.times(MatrixUtils.pseudoInverse(d)).times(u.transpose()).times(y);
                } catch (final NotConvergedException e) {
                        throw new RuntimeException(e);
                }
        }

        @Override
        public double[] predict(double[] data) {
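                // prepend the constant 1, mirroring the design matrix built at
                // estimation time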
                final double[][] corrected = new double[][] { new double[data.length + 1] };
                corrected[0][0] = 1;
                System.arraycopy(data, 0, corrected[0], 1, data.length);
                final Matrix x = new Matrix(corrected);

                return x.times(this.weights).transpose().getArray()[0];
        }

        /**
         * Helper function which adds the constant component to x and returns the
         * predicted values of y, one row per observation.
         * 
         * @param x
         *            the independent values, one row per observation
         * @return predicted y
         */
        public Matrix predict(Matrix x) {
                x = new Matrix(appendConstant(x.getArray()));
                return x.times(this.weights);
        }

        @Override
        public int numItemsToEstimate() {
                return 2;
        }

        @Override
        public LinearRegression clone() {
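                // returns a fresh, untrained model; the learned weights are not copied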
                return new LinearRegression();
        }

        @Override
        public boolean equals(Object obj) {
                if (!(obj instanceof LinearRegression))
                        return false;
                final LinearRegression that = (LinearRegression) obj;
                final double[][] thatw = that.weights.getArray();
                final double[][] thisw = this.weights.getArray();
                for (int i = 0; i < thisw.length; i++) {
                        if (!Arrays.equals(thatw[i], thisw[i]))
                                return false;
                }
                return true;
        }

        @Override
        public String toString() {
                return "LinearRegression with coefficients: " + Arrays.toString(this.weights.transpose().getArray()[0]);
        }

}