1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package org.openimaj.workinprogress;
31
32 import java.util.Arrays;
33 import java.util.Random;
34
35 import org.openimaj.data.AbstractDataSource;
36 import org.openimaj.math.matrix.ThinSingularValueDecomposition;
37 import org.openimaj.workinprogress.optimisation.DifferentiableObjectiveFunction;
38 import org.openimaj.workinprogress.optimisation.EpochAnnealedLearningRate;
39 import org.openimaj.workinprogress.optimisation.SGD;
40 import org.openimaj.workinprogress.optimisation.params.KeyedParameters;
41 import org.openimaj.workinprogress.optimisation.params.KeyedParameters.ObjectDoubleEntry;
42
43 import Jama.Matrix;
44
45 public class GD_SVD2 {
46 static class GD_SVD2_DOF implements DifferentiableObjectiveFunction<GD_SVD2, double[], KeyedParameters<String>> {
47 public int k;
48
49 @Override
50 public double value(GD_SVD2 model, double[] data) {
51 final int i = (int) data[0];
52 final int j = (int) data[1];
53 final int m = (int) data[2];
54
55 final double error = m - model.predict(i, j, k);
56
57 return error * error;
58 }
59
60 @Override
61 public KeyedParameters<String> derivative(GD_SVD2 model, double[] data) {
62 final int i = (int) data[0];
63 final int j = (int) data[1];
64 final double m = data[2];
65 final double[][] Uprime = model.UprimeM.getArray();
66 final double[][] Vprime = model.VprimeM.getArray();
67
68 final double error = m - model.predict(i, j, k);
69 final double uTemp = Uprime[i][k];
70 final double vTemp = Vprime[j][k];
71
72 System.out.println("Error: " + error + " " + m);
73
74 final KeyedParameters<String> params = new KeyedParameters<String>();
75 params.set("i" + i, error * vTemp);
76 params.set("j" + j, error * uTemp);
77 return params;
78 }
79
80 @Override
81 public void updateModel(GD_SVD2 model, KeyedParameters<String> weights) {
82 final double[][] Uprime = model.UprimeM.getArray();
83 final double[][] Vprime = model.VprimeM.getArray();
84
85
86
87
88
89
90
91
92
93 for (final ObjectDoubleEntry<String> e : weights) {
94 final char type = e.key.charAt(0);
95 final int idx = Integer.parseInt(e.key.substring(1));
96
97 if (type == 'i') {
98 Uprime[idx][k] += e.value;
99 } else {
100 Vprime[idx][k] += e.value;
101 }
102 }
103
104
105 }
106 }
107
108 private static final int maxEpochs = 300;
109 private static final double initialLearningRate = 0.01;
110 private static final double annealingRate = maxEpochs * 0.1;
111
112 Matrix UprimeM;
113 Matrix VprimeM;
114 private Matrix UM;
115 private Matrix VM;
116 private Matrix SM;
117
118 protected double predict(int i, int j, int k) {
119 final double[][] Uprime = UprimeM.getArray();
120 final double[][] Vprime = VprimeM.getArray();
121
122 double pred = 0;
123 for (int kk = 0; kk <= k; kk++)
124 pred += Uprime[i][kk] * Vprime[j][kk];
125
126 return pred;
127 }
128
129 public GD_SVD2(Matrix MM, int maxOrder) {
130 final Random random = new Random(0);
131 final double initValue = 1 / Math.sqrt(maxOrder);
132
133 final int m = MM.getRowDimension();
134 final int n = MM.getColumnDimension();
135
136 UprimeM = new Matrix(m, maxOrder);
137 VprimeM = new Matrix(n, maxOrder);
138 final double[][] Uprime = UprimeM.getArray();
139 final double[][] Vprime = VprimeM.getArray();
140 final double[][] M = MM.getArray();
141
142 final SGD<GD_SVD2, double[], KeyedParameters<String>> sgd = new SGD<GD_SVD2, double[], KeyedParameters<String>>();
143 sgd.fcn = new GD_SVD2_DOF();
144 sgd.batchSize = 1;
145 sgd.maxEpochs = 300;
146 sgd.learningRate = new EpochAnnealedLearningRate(0.01, 300);
147 sgd.model = this;
148
149 for (((GD_SVD2_DOF) sgd.fcn).k = 0; ((GD_SVD2_DOF) sgd.fcn).k < maxOrder; ((GD_SVD2_DOF) sgd.fcn).k++) {
150 for (int i = 0; i < m; i++)
151 Uprime[i][((GD_SVD2_DOF) sgd.fcn).k] = random.nextGaussian() * initValue;
152 for (int j = 0; j < n; j++)
153 Vprime[j][((GD_SVD2_DOF) sgd.fcn).k] = random.nextGaussian() * initValue;
154
155 sgd.train(new AbstractDataSource<double[]>() {
156
157 @Override
158 public void getData(int startRow, int stopRow, double[][] data) {
159 for (int idx = startRow, kkk = 0; idx < stopRow; idx++, kkk++) {
160 final int row = idx / M[0].length;
161 final int col = idx % M[0].length;
162 data[kkk][0] = row;
163 data[kkk][1] = col;
164 data[kkk][2] = M[row][col];
165 }
166 }
167
168 @Override
169 public double[] getData(int idx) {
170 final int row = idx / M[0].length;
171 final int col = idx % M[0].length;
172
173 return new double[] { row, col, M[row][col] };
174 }
175
176 @Override
177 public int numDimensions() {
178 return 3;
179 }
180
181 @Override
182 public int size() {
183 return M[0].length * M.length;
184 }
185
186 @Override
187 public double[][] createTemporaryArray(int size) {
188 return new double[size][3];
189 }
190
191 });
192 }
193
194 UM = new Matrix(m, maxOrder);
195 final double[][] U = UM.getArray();
196 SM = new Matrix(maxOrder, maxOrder);
197 final double[][] S = SM.getArray();
198 VM = new Matrix(maxOrder, n);
199 final double[][] V = VM.getArray();
200 for (int i = 0; i < maxOrder; i++) {
201 double un = 0;
202 double vn = 0;
203 for (int j = 0; j < m; j++) {
204 un += (Uprime[j][i] * Uprime[j][i]);
205 }
206 for (int j = 0; j < n; j++) {
207 vn += (Vprime[j][i] * Vprime[j][i]);
208 }
209
210 un = Math.sqrt(un);
211 vn = Math.sqrt(vn);
212
213 for (int j = 0; j < m; j++) {
214 U[j][i] = Uprime[j][i] / un;
215 }
216 for (int j = 0; j < n; j++) {
217 V[i][j] = Vprime[j][i] / vn;
218 }
219
220 S[i][i] = un * vn;
221 }
222 }
223
224 public static void main(String[] args) {
225 final Matrix m = new Matrix(new double[][] { { 0.5, 0.4 }, { 0.1, 0.7 } });
226
227 final GD_SVD2 gdsvd = new GD_SVD2(m, 2);
228
229
230
231
232
233
234 gdsvd.SM.print(5, 5);
235
236
237 final ThinSingularValueDecomposition tsvd = new ThinSingularValueDecomposition(m, 2);
238
239 System.out.println(Arrays.toString(tsvd.S));
240 }
241 }