001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.workinprogress;
031
032import java.util.Arrays;
033import java.util.Random;
034
035import org.openimaj.data.AbstractDataSource;
036import org.openimaj.math.matrix.ThinSingularValueDecomposition;
037import org.openimaj.workinprogress.optimisation.DifferentiableObjectiveFunction;
038import org.openimaj.workinprogress.optimisation.EpochAnnealedLearningRate;
039import org.openimaj.workinprogress.optimisation.SGD;
040import org.openimaj.workinprogress.optimisation.params.KeyedParameters;
041import org.openimaj.workinprogress.optimisation.params.KeyedParameters.ObjectDoubleEntry;
042
043import Jama.Matrix;
044
045public class GD_SVD2 {
046        static class GD_SVD2_DOF implements DifferentiableObjectiveFunction<GD_SVD2, double[], KeyedParameters<String>> {
047                public int k;
048
049                @Override
050                public double value(GD_SVD2 model, double[] data) {
051                        final int i = (int) data[0];
052                        final int j = (int) data[1];
053                        final int m = (int) data[2];
054
055                        final double error = m - model.predict(i, j, k);
056
057                        return error * error;
058                }
059
060                @Override
061                public KeyedParameters<String> derivative(GD_SVD2 model, double[] data) {
062                        final int i = (int) data[0];
063                        final int j = (int) data[1];
064                        final double m = data[2];
065                        final double[][] Uprime = model.UprimeM.getArray();
066                        final double[][] Vprime = model.VprimeM.getArray();
067
068                        final double error = m - model.predict(i, j, k);
069                        final double uTemp = Uprime[i][k];
070                        final double vTemp = Vprime[j][k];
071
072                        System.out.println("Error: " + error + " " + m);
073
074                        final KeyedParameters<String> params = new KeyedParameters<String>();
075                        params.set("i" + i, error * vTemp);
076                        params.set("j" + j, error * uTemp);
077                        return params;
078                }
079
080                @Override
081                public void updateModel(GD_SVD2 model, KeyedParameters<String> weights) {
082                        final double[][] Uprime = model.UprimeM.getArray();
083                        final double[][] Vprime = model.VprimeM.getArray();
084
085                        // for (final Entry e : weights.firstObject().entries()) {
086                        // Uprime[e.index][k] += e.value;
087                        //
088                        // System.out.println(e.index + " " + e.value);
089                        // }
090                        // for (final Entry e : weights.secondObject().entries()) {
091                        // Vprime[e.index][k] += e.value;
092                        // }
093                        for (final ObjectDoubleEntry<String> e : weights) {
094                                final char type = e.key.charAt(0);
095                                final int idx = Integer.parseInt(e.key.substring(1));
096
097                                if (type == 'i') {
098                                        Uprime[idx][k] += e.value;
099                                } else {
100                                        Vprime[idx][k] += e.value;
101                                }
102                        }
103
104                        // model.UprimeM.print(5, 5);
105                }
106        }
107
108        private static final int maxEpochs = 300;
109        private static final double initialLearningRate = 0.01;
110        private static final double annealingRate = maxEpochs * 0.1;
111
112        Matrix UprimeM;
113        Matrix VprimeM;
114        private Matrix UM;
115        private Matrix VM;
116        private Matrix SM;
117
118        protected double predict(int i, int j, int k) {
119                final double[][] Uprime = UprimeM.getArray();
120                final double[][] Vprime = VprimeM.getArray();
121
122                double pred = 0;
123                for (int kk = 0; kk <= k; kk++)
124                        pred += Uprime[i][kk] * Vprime[j][kk];
125
126                return pred;
127        }
128
129        public GD_SVD2(Matrix MM, int maxOrder) {
130                final Random random = new Random(0);
131                final double initValue = 1 / Math.sqrt(maxOrder);
132
133                final int m = MM.getRowDimension();
134                final int n = MM.getColumnDimension();
135
136                UprimeM = new Matrix(m, maxOrder);
137                VprimeM = new Matrix(n, maxOrder);
138                final double[][] Uprime = UprimeM.getArray();
139                final double[][] Vprime = VprimeM.getArray();
140                final double[][] M = MM.getArray();
141
142                final SGD<GD_SVD2, double[], KeyedParameters<String>> sgd = new SGD<GD_SVD2, double[], KeyedParameters<String>>();
143                sgd.fcn = new GD_SVD2_DOF();
144                sgd.batchSize = 1;
145                sgd.maxEpochs = 300;
146                sgd.learningRate = new EpochAnnealedLearningRate(0.01, 300);
147                sgd.model = this;
148
149                for (((GD_SVD2_DOF) sgd.fcn).k = 0; ((GD_SVD2_DOF) sgd.fcn).k < maxOrder; ((GD_SVD2_DOF) sgd.fcn).k++) {
150                        for (int i = 0; i < m; i++)
151                                Uprime[i][((GD_SVD2_DOF) sgd.fcn).k] = random.nextGaussian() * initValue;
152                        for (int j = 0; j < n; j++)
153                                Vprime[j][((GD_SVD2_DOF) sgd.fcn).k] = random.nextGaussian() * initValue;
154
155                        sgd.train(new AbstractDataSource<double[]>() {
156
157                                @Override
158                                public void getData(int startRow, int stopRow, double[][] data) {
159                                        for (int idx = startRow, kkk = 0; idx < stopRow; idx++, kkk++) {
160                                                final int row = idx / M[0].length;
161                                                final int col = idx % M[0].length;
162                                                data[kkk][0] = row;
163                                                data[kkk][1] = col;
164                                                data[kkk][2] = M[row][col];
165                                        }
166                                }
167
168                                @Override
169                                public double[] getData(int idx) {
170                                        final int row = idx / M[0].length;
171                                        final int col = idx % M[0].length;
172
173                                        return new double[] { row, col, M[row][col] };
174                                }
175
176                                @Override
177                                public int numDimensions() {
178                                        return 3;
179                                }
180
181                                @Override
182                                public int size() {
183                                        return M[0].length * M.length;
184                                }
185
186                                @Override
187                                public double[][] createTemporaryArray(int size) {
188                                        return new double[size][3];
189                                }
190
191                        });
192                }
193
194                UM = new Matrix(m, maxOrder);
195                final double[][] U = UM.getArray();
196                SM = new Matrix(maxOrder, maxOrder);
197                final double[][] S = SM.getArray();
198                VM = new Matrix(maxOrder, n);
199                final double[][] V = VM.getArray();
200                for (int i = 0; i < maxOrder; i++) {
201                        double un = 0;
202                        double vn = 0;
203                        for (int j = 0; j < m; j++) {
204                                un += (Uprime[j][i] * Uprime[j][i]);
205                        }
206                        for (int j = 0; j < n; j++) {
207                                vn += (Vprime[j][i] * Vprime[j][i]);
208                        }
209
210                        un = Math.sqrt(un);
211                        vn = Math.sqrt(vn);
212
213                        for (int j = 0; j < m; j++) {
214                                U[j][i] = Uprime[j][i] / un;
215                        }
216                        for (int j = 0; j < n; j++) {
217                                V[i][j] = Vprime[j][i] / vn;
218                        }
219
220                        S[i][i] = un * vn;
221                }
222        }
223
224        public static void main(String[] args) {
225                final Matrix m = new Matrix(new double[][] { { 0.5, 0.4 }, { 0.1, 0.7 } });
226
227                final GD_SVD2 gdsvd = new GD_SVD2(m, 2);
228
229                // m.print(5, 5);
230                // gdsvd.UprimeM.print(5, 5);
231                // gdsvd.UprimeM.times(gdsvd.VprimeM.transpose()).print(5, 5);
232                // gdsvd.UM.times(gdsvd.SM.times(gdsvd.VM)).print(5, 5);
233                // gdsvd.UM.print(5, 5);
234                gdsvd.SM.print(5, 5);
235                // gdsvd.VM.print(5, 5);
236
237                final ThinSingularValueDecomposition tsvd = new ThinSingularValueDecomposition(m, 2);
238                // tsvd.U.print(5, 5);
239                System.out.println(Arrays.toString(tsvd.S));
240        }
241}