001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.workinprogress;
031
032import java.util.Random;
033
034import Jama.Matrix;
035
036public class GD_SVD {
037        private static final int maxEpochs = 300;
038        private static final double initialLearningRate = 0.01;
039        private static final double annealingRate = maxEpochs * 0.1;
040
041        Matrix UprimeM;
042        Matrix VprimeM;
043        private Matrix UM;
044        private Matrix VM;
045        private Matrix SM;
046
047        public GD_SVD(Matrix MM, int maxOrder) {
048                final Random random = new Random(0);
049                final double initValue = 1 / Math.sqrt(maxOrder);
050
051                final int m = MM.getRowDimension();
052                final int n = MM.getColumnDimension();
053
054                UprimeM = new Matrix(m, maxOrder);
055                VprimeM = new Matrix(n, maxOrder);
056                final double[][] Uprime = UprimeM.getArray();
057                final double[][] Vprime = VprimeM.getArray();
058                final double[][] M = MM.getArray();
059
060                for (int k = 0; k < maxOrder; k++) {
061                        for (int i = 0; i < m; i++)
062                                Uprime[i][k] = random.nextGaussian() * initValue;
063                        for (int j = 0; j < n; j++)
064                                Vprime[j][k] = random.nextGaussian() * initValue;
065
066                        double lastError = Double.MAX_VALUE;
067                        for (int epoch = 0; epoch < maxEpochs; epoch++) {
068                                final double learningRate = initialLearningRate / (1 + epoch / annealingRate);
069
070                                double sq = 0;
071                                for (int i = 0; i < m; i++) {
072                                        for (int j = 0; j < n; j++) {
073                                                double pred = 0;
074                                                for (int kk = 0; kk <= k; kk++)
075                                                        pred += Uprime[i][kk] * Vprime[j][kk];
076
077                                                final double error = M[i][j] - pred;
078                                                System.out.println("Error: " + error + " " + M[i][j]);
079                                                sq += error * error;
080                                                final double uTemp = Uprime[i][k];
081                                                final double vTemp = Vprime[j][k];
082                                                // Uprime[i][k] += learningRate[epoch] * ( error * vTemp
083                                                // - regularization * uTemp );
084                                                // Vprime[j][k] += learningRate[epoch] * ( error * uTemp
085                                                // - regularization * vTemp );
086                                                Uprime[i][k] += learningRate * (error * vTemp);
087                                                Vprime[j][k] += learningRate * (error * uTemp);
088
089                                                // System.out.println(i + " " + learningRate * (error *
090                                                // vTemp));
091                                        }
092                                }
093
094                                if (lastError - sq < 0.000001)
095                                        break;
096
097                                lastError = sq;
098                        }
099                }
100
101                UM = new Matrix(m, maxOrder);
102                final double[][] U = UM.getArray();
103                SM = new Matrix(maxOrder, maxOrder);
104                final double[][] S = SM.getArray();
105                VM = new Matrix(maxOrder, n);
106                final double[][] V = VM.getArray();
107                for (int i = 0; i < maxOrder; i++) {
108                        double un = 0;
109                        double vn = 0;
110                        for (int j = 0; j < m; j++) {
111                                un += (Uprime[j][i] * Uprime[j][i]);
112                        }
113                        for (int j = 0; j < n; j++) {
114                                vn += (Vprime[j][i] * Vprime[j][i]);
115                        }
116
117                        un = Math.sqrt(un);
118                        vn = Math.sqrt(vn);
119
120                        for (int j = 0; j < m; j++) {
121                                U[j][i] = Uprime[j][i] / un;
122                        }
123                        for (int j = 0; j < n; j++) {
124                                V[i][j] = Vprime[j][i] / vn;
125                        }
126
127                        S[i][i] = un * vn;
128                }
129        }
130
131        public static void main(String[] args) {
132                // final Matrix m = Matrix.random(10, 10);
133                final Matrix m = new Matrix(new double[][] { { 0.5, 0.4 }, { 0.1, 0.7 } });
134
135                final GD_SVD gdsvd = new GD_SVD(m, 2);
136
137                // m.print(5, 5);
138                // gdsvd.UprimeM.print(5, 5);
139                // gdsvd.UprimeM.times(gdsvd.VprimeM.transpose()).print(5, 5);
140                // gdsvd.UM.times(gdsvd.SM.times(gdsvd.VM)).print(5, 5);
141                // gdsvd.UM.print(5, 5);
142                gdsvd.SM.print(5, 5);
143                // gdsvd.VM.print(5, 5);
144
145                // final ThinSingularValueDecomposition tsvd = new
146                // ThinSingularValueDecomposition(m, 2);
147                // tsvd.U.print(5, 5);
148                // System.out.println(Arrays.toString(tsvd.S));
149
150        }
151}