1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package org.openimaj.ml.neuralnet;
31
32 import gov.sandia.cognition.math.matrix.Matrix;
33 import gov.sandia.cognition.math.matrix.MatrixFactory;
34 import gov.sandia.cognition.math.matrix.Vector;
35 import gov.sandia.cognition.math.matrix.mtj.DenseMatrixFactoryMTJ;
36
37 import org.openimaj.data.RandomData;
38 import org.openimaj.image.DisplayUtilities;
39 import org.openimaj.image.FImage;
40 import org.openimaj.image.colour.ColourMap;
41 import org.openimaj.util.function.Function;
42
43
44
45
46
47
48
49 public class OnlineBackpropOneHidden {
50
51 private static final double LEARNRATE = 0.005;
52 private Matrix weightsL1;
53 private Matrix weightsL2;
54 MatrixFactory<? extends Matrix> DMF = DenseMatrixFactoryMTJ.getDenseDefault();
55 private Function<Double, Double> g;
56 private Function<Matrix, Matrix> gMat;
57 private Function<Double, Double> gPrime;
58 private Function<Matrix, Matrix> gPrimeMat;
59
60
61
62
63
64
65
66
67
68 public OnlineBackpropOneHidden(int nInput, int nHidden, int nFinal) {
69 final double[][] weightsL1dat = RandomData.getRandomDoubleArray(nInput + 1, nHidden, -1, 1.);
70 final double[][] weightsL2dat = RandomData.getRandomDoubleArray(nHidden + 1, nFinal, -1, 1.);
71
72 weightsL1 = DMF.copyArray(weightsL1dat);
73 weightsL2 = DMF.copyArray(weightsL2dat);
74 ;
75
76 g = new Function<Double, Double>() {
77
78 @Override
79 public Double apply(Double in) {
80
81 return 1. / (1 + Math.exp(-in));
82 }
83
84 };
85
86 gPrime = new Function<Double, Double>() {
87
88 @Override
89 public Double apply(Double in) {
90
91 return g.apply(in) * (1 - g.apply(in));
92 }
93
94 };
95
96 gPrimeMat = new Function<Matrix, Matrix>() {
97
98 @Override
99 public Matrix apply(Matrix in) {
100 final Matrix out = DMF.copyMatrix(in);
101 for (int i = 0; i < in.getNumRows(); i++) {
102 for (int j = 0; j < in.getNumColumns(); j++) {
103 out.setElement(i, j, gPrime.apply(in.getElement(i, j)));
104 }
105 }
106 return out;
107 }
108
109 };
110
111 gMat = new Function<Matrix, Matrix>() {
112
113 @Override
114 public Matrix apply(Matrix in) {
115 final Matrix out = DMF.copyMatrix(in);
116 for (int i = 0; i < in.getNumRows(); i++) {
117 for (int j = 0; j < in.getNumColumns(); j++) {
118 out.setElement(i, j, g.apply(in.getElement(i, j)));
119 }
120 }
121 return out;
122 }
123
124 };
125 }
126
127 public void update(double[] x, double[] y) {
128 final Matrix X = prepareMatrix(x);
129 final Matrix Y = DMF.copyArray(new double[][] { y });
130
131 final Matrix hiddenOutput = weightsL1.transpose().times(X);
132
133
134
135 final Matrix gHiddenOutput = prepareMatrix(gMat.apply(hiddenOutput).getColumn(0));
136
137
138
139
140
141
142 final Matrix gPrimeHiddenOutput = prepareMatrix(gPrimeMat.apply(hiddenOutput).getColumn(0));
143
144
145
146
147
148
149 final Matrix finalOutput = weightsL2.transpose().times(gHiddenOutput);
150 final Matrix finalOutputGPrime = gPrimeMat.apply(finalOutput);
151
152
153
154
155
156 final Matrix errmat = Y.minus(finalOutput);
157 final double err = errmat.sumOfColumns().sum();
158
159 Matrix dL2 = finalOutputGPrime.times(gHiddenOutput.transpose()).scale(err * LEARNRATE).transpose();
160
161
162
163
164
165
166
167
168
169 Matrix dL1 = finalOutputGPrime.times(weightsL2.transpose().times(gPrimeHiddenOutput).times(X.transpose()))
170 .scale(err * LEARNRATE).transpose();
171
172 dL1 = repmat(dL1, 1, weightsL1.getNumColumns());
173 dL2 = repmat(dL2, 1, weightsL2.getNumColumns());
174
175 this.weightsL1.plusEquals(dL1);
176 this.weightsL2.plusEquals(dL2);
177
178 }
179
180 private Matrix repmat(Matrix dL1, int nRows, int nCols) {
181 final Matrix out = DMF.createMatrix(nRows * dL1.getNumRows(), nCols * dL1.getNumColumns());
182 for (int i = 0; i < nRows; i++) {
183 for (int j = 0; j < nCols; j++) {
184 out.setSubMatrix(i * dL1.getNumRows(), j * dL1.getNumColumns(), dL1);
185 }
186 }
187 return out;
188 }
189
190 public Matrix predict(double[] x) {
191 final Matrix X = prepareMatrix(x);
192
193 final Matrix hiddenTimes = weightsL1.transpose().times(X);
194 final Matrix hiddenVal = prepareMatrix(gMat.apply(hiddenTimes).getColumn(0));
195 final Matrix finalTimes = weightsL2.transpose().times(hiddenVal);
196 final Matrix finalVal = gMat.apply(finalTimes);
197
198 return finalVal;
199
200 }
201
202 private Matrix prepareMatrix(Vector y) {
203 final Matrix Y = DMF.createMatrix(1, y.getDimensionality() + 1);
204 Y.setElement(0, 0, 1);
205 Y.setSubMatrix(0, 1, DMF.copyRowVectors(y));
206 return Y.transpose();
207 }
208
209 private Matrix prepareMatrix(double[] y) {
210 final Matrix Y = DMF.createMatrix(1, y.length + 1);
211 Y.setElement(0, 0, 1);
212 Y.setSubMatrix(0, 1, DMF.copyArray(new double[][] { y }));
213 return Y.transpose();
214 }
215
216 public static void main(String[] args) throws InterruptedException {
217 final OnlineBackpropOneHidden bp = new OnlineBackpropOneHidden(2, 2, 1);
218 FImage img = new FImage(200, 200);
219 img = imagePredict(bp, img);
220 final ColourMap m = ColourMap.Hot;
221
222 DisplayUtilities.displayName(m.apply(img), "xor");
223 final int npixels = img.width * img.height;
224 final int half = img.width / 2;
225 final int[] pixels = RandomData.getUniqueRandomInts(npixels, 0, npixels);
226 while (true) {
227
228
229
230
231
232
233
234
235 bp.update(new double[] { 0, 0 }, new double[] { 0 });
236 bp.update(new double[] { 1, 1 }, new double[] { 0 });
237 bp.update(new double[] { 0, 1 }, new double[] { 1 });
238 bp.update(new double[] { 1, 0 }, new double[] { 1 });
239 imagePredict(bp, img);
240 DisplayUtilities.displayName(m.apply(img), "xor");
241 }
242 }
243
244 private static FImage imagePredict(OnlineBackpropOneHidden bp, FImage img) {
245 final double[] pos = new double[2];
246 final int half = img.width / 2;
247 for (int y = 0; y < img.height; y++) {
248 for (int x = 0; x < img.width; x++) {
249 pos[0] = x < half ? 0 : 1;
250 pos[1] = y < half ? 0 : 1;
251 final float ret = (float) bp.predict(pos).getElement(0, 0);
252 img.pixels[y][x] = ret;
253 }
254 }
255 return img;
256 }
257 }