View Javadoc

1   /**
2    * Copyright (c) 2011, The University of Southampton and the individual contributors.
3    * All rights reserved.
4    *
5    * Redistribution and use in source and binary forms, with or without modification,
6    * are permitted provided that the following conditions are met:
7    *
8    *   * 	Redistributions of source code must retain the above copyright notice,
9    * 	this list of conditions and the following disclaimer.
10   *
11   *   *	Redistributions in binary form must reproduce the above copyright notice,
12   * 	this list of conditions and the following disclaimer in the documentation
13   * 	and/or other materials provided with the distribution.
14   *
15   *   *	Neither the name of the University of Southampton nor the names of its
16   * 	contributors may be used to endorse or promote products derived from this
17   * 	software without specific prior written permission.
18   *
19   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21   * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22   * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
23   * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24   * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26   * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27   * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28   * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29   */
30  package org.openimaj.workinprogress;
31  
32  import java.util.Arrays;
33  import java.util.Random;
34  
35  import org.openimaj.data.AbstractDataSource;
36  import org.openimaj.math.matrix.ThinSingularValueDecomposition;
37  import org.openimaj.workinprogress.optimisation.DifferentiableObjectiveFunction;
38  import org.openimaj.workinprogress.optimisation.EpochAnnealedLearningRate;
39  import org.openimaj.workinprogress.optimisation.SGD;
40  import org.openimaj.workinprogress.optimisation.params.KeyedParameters;
41  import org.openimaj.workinprogress.optimisation.params.KeyedParameters.ObjectDoubleEntry;
42  
43  import Jama.Matrix;
44  
45  public class GD_SVD2 {
46  	static class GD_SVD2_DOF implements DifferentiableObjectiveFunction<GD_SVD2, double[], KeyedParameters<String>> {
47  		public int k;
48  
49  		@Override
50  		public double value(GD_SVD2 model, double[] data) {
51  			final int i = (int) data[0];
52  			final int j = (int) data[1];
53  			final int m = (int) data[2];
54  
55  			final double error = m - model.predict(i, j, k);
56  
57  			return error * error;
58  		}
59  
60  		@Override
61  		public KeyedParameters<String> derivative(GD_SVD2 model, double[] data) {
62  			final int i = (int) data[0];
63  			final int j = (int) data[1];
64  			final double m = data[2];
65  			final double[][] Uprime = model.UprimeM.getArray();
66  			final double[][] Vprime = model.VprimeM.getArray();
67  
68  			final double error = m - model.predict(i, j, k);
69  			final double uTemp = Uprime[i][k];
70  			final double vTemp = Vprime[j][k];
71  
72  			System.out.println("Error: " + error + " " + m);
73  
74  			final KeyedParameters<String> params = new KeyedParameters<String>();
75  			params.set("i" + i, error * vTemp);
76  			params.set("j" + j, error * uTemp);
77  			return params;
78  		}
79  
80  		@Override
81  		public void updateModel(GD_SVD2 model, KeyedParameters<String> weights) {
82  			final double[][] Uprime = model.UprimeM.getArray();
83  			final double[][] Vprime = model.VprimeM.getArray();
84  
85  			// for (final Entry e : weights.firstObject().entries()) {
86  			// Uprime[e.index][k] += e.value;
87  			//
88  			// System.out.println(e.index + " " + e.value);
89  			// }
90  			// for (final Entry e : weights.secondObject().entries()) {
91  			// Vprime[e.index][k] += e.value;
92  			// }
93  			for (final ObjectDoubleEntry<String> e : weights) {
94  				final char type = e.key.charAt(0);
95  				final int idx = Integer.parseInt(e.key.substring(1));
96  
97  				if (type == 'i') {
98  					Uprime[idx][k] += e.value;
99  				} else {
100 					Vprime[idx][k] += e.value;
101 				}
102 			}
103 
104 			// model.UprimeM.print(5, 5);
105 		}
106 	}
107 
108 	private static final int maxEpochs = 300;
109 	private static final double initialLearningRate = 0.01;
110 	private static final double annealingRate = maxEpochs * 0.1;
111 
112 	Matrix UprimeM;
113 	Matrix VprimeM;
114 	private Matrix UM;
115 	private Matrix VM;
116 	private Matrix SM;
117 
118 	protected double predict(int i, int j, int k) {
119 		final double[][] Uprime = UprimeM.getArray();
120 		final double[][] Vprime = VprimeM.getArray();
121 
122 		double pred = 0;
123 		for (int kk = 0; kk <= k; kk++)
124 			pred += Uprime[i][kk] * Vprime[j][kk];
125 
126 		return pred;
127 	}
128 
129 	public GD_SVD2(Matrix MM, int maxOrder) {
130 		final Random random = new Random(0);
131 		final double initValue = 1 / Math.sqrt(maxOrder);
132 
133 		final int m = MM.getRowDimension();
134 		final int n = MM.getColumnDimension();
135 
136 		UprimeM = new Matrix(m, maxOrder);
137 		VprimeM = new Matrix(n, maxOrder);
138 		final double[][] Uprime = UprimeM.getArray();
139 		final double[][] Vprime = VprimeM.getArray();
140 		final double[][] M = MM.getArray();
141 
142 		final SGD<GD_SVD2, double[], KeyedParameters<String>> sgd = new SGD<GD_SVD2, double[], KeyedParameters<String>>();
143 		sgd.fcn = new GD_SVD2_DOF();
144 		sgd.batchSize = 1;
145 		sgd.maxEpochs = 300;
146 		sgd.learningRate = new EpochAnnealedLearningRate(0.01, 300);
147 		sgd.model = this;
148 
149 		for (((GD_SVD2_DOF) sgd.fcn).k = 0; ((GD_SVD2_DOF) sgd.fcn).k < maxOrder; ((GD_SVD2_DOF) sgd.fcn).k++) {
150 			for (int i = 0; i < m; i++)
151 				Uprime[i][((GD_SVD2_DOF) sgd.fcn).k] = random.nextGaussian() * initValue;
152 			for (int j = 0; j < n; j++)
153 				Vprime[j][((GD_SVD2_DOF) sgd.fcn).k] = random.nextGaussian() * initValue;
154 
155 			sgd.train(new AbstractDataSource<double[]>() {
156 
157 				@Override
158 				public void getData(int startRow, int stopRow, double[][] data) {
159 					for (int idx = startRow, kkk = 0; idx < stopRow; idx++, kkk++) {
160 						final int row = idx / M[0].length;
161 						final int col = idx % M[0].length;
162 						data[kkk][0] = row;
163 						data[kkk][1] = col;
164 						data[kkk][2] = M[row][col];
165 					}
166 				}
167 
168 				@Override
169 				public double[] getData(int idx) {
170 					final int row = idx / M[0].length;
171 					final int col = idx % M[0].length;
172 
173 					return new double[] { row, col, M[row][col] };
174 				}
175 
176 				@Override
177 				public int numDimensions() {
178 					return 3;
179 				}
180 
181 				@Override
182 				public int size() {
183 					return M[0].length * M.length;
184 				}
185 
186 				@Override
187 				public double[][] createTemporaryArray(int size) {
188 					return new double[size][3];
189 				}
190 
191 			});
192 		}
193 
194 		UM = new Matrix(m, maxOrder);
195 		final double[][] U = UM.getArray();
196 		SM = new Matrix(maxOrder, maxOrder);
197 		final double[][] S = SM.getArray();
198 		VM = new Matrix(maxOrder, n);
199 		final double[][] V = VM.getArray();
200 		for (int i = 0; i < maxOrder; i++) {
201 			double un = 0;
202 			double vn = 0;
203 			for (int j = 0; j < m; j++) {
204 				un += (Uprime[j][i] * Uprime[j][i]);
205 			}
206 			for (int j = 0; j < n; j++) {
207 				vn += (Vprime[j][i] * Vprime[j][i]);
208 			}
209 
210 			un = Math.sqrt(un);
211 			vn = Math.sqrt(vn);
212 
213 			for (int j = 0; j < m; j++) {
214 				U[j][i] = Uprime[j][i] / un;
215 			}
216 			for (int j = 0; j < n; j++) {
217 				V[i][j] = Vprime[j][i] / vn;
218 			}
219 
220 			S[i][i] = un * vn;
221 		}
222 	}
223 
224 	public static void main(String[] args) {
225 		final Matrix m = new Matrix(new double[][] { { 0.5, 0.4 }, { 0.1, 0.7 } });
226 
227 		final GD_SVD2 gdsvd = new GD_SVD2(m, 2);
228 
229 		// m.print(5, 5);
230 		// gdsvd.UprimeM.print(5, 5);
231 		// gdsvd.UprimeM.times(gdsvd.VprimeM.transpose()).print(5, 5);
232 		// gdsvd.UM.times(gdsvd.SM.times(gdsvd.VM)).print(5, 5);
233 		// gdsvd.UM.print(5, 5);
234 		gdsvd.SM.print(5, 5);
235 		// gdsvd.VM.print(5, 5);
236 
237 		final ThinSingularValueDecomposition tsvd = new ThinSingularValueDecomposition(m, 2);
238 		// tsvd.U.print(5, 5);
239 		System.out.println(Arrays.toString(tsvd.S));
240 	}
241 }