View Javadoc

1   /**
2    * Copyright (c) 2011, The University of Southampton and the individual contributors.
3    * All rights reserved.
4    *
5    * Redistribution and use in source and binary forms, with or without modification,
6    * are permitted provided that the following conditions are met:
7    *
8    *   * 	Redistributions of source code must retain the above copyright notice,
9    * 	this list of conditions and the following disclaimer.
10   *
11   *   *	Redistributions in binary form must reproduce the above copyright notice,
12   * 	this list of conditions and the following disclaimer in the documentation
13   * 	and/or other materials provided with the distribution.
14   *
15   *   *	Neither the name of the University of Southampton nor the names of its
16   * 	contributors may be used to endorse or promote products derived from this
17   * 	software without specific prior written permission.
18   *
19   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21   * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22   * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
23   * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24   * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26   * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27   * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28   * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29   */
30  package org.openimaj.ml.neuralnet;
31  
32  import gov.sandia.cognition.math.matrix.Matrix;
33  import gov.sandia.cognition.math.matrix.MatrixFactory;
34  import gov.sandia.cognition.math.matrix.Vector;
35  import gov.sandia.cognition.math.matrix.mtj.DenseMatrixFactoryMTJ;
36  
37  import org.openimaj.data.RandomData;
38  import org.openimaj.image.DisplayUtilities;
39  import org.openimaj.image.FImage;
40  import org.openimaj.image.colour.ColourMap;
41  import org.openimaj.util.function.Function;
42  
43  /**
44   * Implement an online version of the backprop algorithm against an 2D
45   * 
46   * @author Sina Samangooei (ss@ecs.soton.ac.uk)
47   *
48   */
49  public class OnlineBackpropOneHidden {
50  
51  	private static final double LEARNRATE = 0.005;
52  	private Matrix weightsL1;
53  	private Matrix weightsL2;
54  	MatrixFactory<? extends Matrix> DMF = DenseMatrixFactoryMTJ.getDenseDefault();
55  	private Function<Double, Double> g;
56  	private Function<Matrix, Matrix> gMat;
57  	private Function<Double, Double> gPrime;
58  	private Function<Matrix, Matrix> gPrimeMat;
59  
60  	/**
61  	 * @param nInput
62  	 *            the number of input values
63  	 * @param nHidden
64  	 *            the number of hidden values
65  	 * @param nFinal
66  	 *            the number of final values
67  	 */
68  	public OnlineBackpropOneHidden(int nInput, int nHidden, int nFinal) {
69  		final double[][] weightsL1dat = RandomData.getRandomDoubleArray(nInput + 1, nHidden, -1, 1.);
70  		final double[][] weightsL2dat = RandomData.getRandomDoubleArray(nHidden + 1, nFinal, -1, 1.);
71  
72  		weightsL1 = DMF.copyArray(weightsL1dat);
73  		weightsL2 = DMF.copyArray(weightsL2dat);
74  		;
75  
76  		g = new Function<Double, Double>() {
77  
78  			@Override
79  			public Double apply(Double in) {
80  
81  				return 1. / (1 + Math.exp(-in));
82  			}
83  
84  		};
85  
86  		gPrime = new Function<Double, Double>() {
87  
88  			@Override
89  			public Double apply(Double in) {
90  
91  				return g.apply(in) * (1 - g.apply(in));
92  			}
93  
94  		};
95  
96  		gPrimeMat = new Function<Matrix, Matrix>() {
97  
98  			@Override
99  			public Matrix apply(Matrix in) {
100 				final Matrix out = DMF.copyMatrix(in);
101 				for (int i = 0; i < in.getNumRows(); i++) {
102 					for (int j = 0; j < in.getNumColumns(); j++) {
103 						out.setElement(i, j, gPrime.apply(in.getElement(i, j)));
104 					}
105 				}
106 				return out;
107 			}
108 
109 		};
110 
111 		gMat = new Function<Matrix, Matrix>() {
112 
113 			@Override
114 			public Matrix apply(Matrix in) {
115 				final Matrix out = DMF.copyMatrix(in);
116 				for (int i = 0; i < in.getNumRows(); i++) {
117 					for (int j = 0; j < in.getNumColumns(); j++) {
118 						out.setElement(i, j, g.apply(in.getElement(i, j)));
119 					}
120 				}
121 				return out;
122 			}
123 
124 		};
125 	}
126 
127 	public void update(double[] x, double[] y) {
128 		final Matrix X = prepareMatrix(x);
129 		final Matrix Y = DMF.copyArray(new double[][] { y });
130 
131 		final Matrix hiddenOutput = weightsL1.transpose().times(X); // nHiddenLayers
132 																	// x nInputs
133 																	// (usually
134 																	// 2 x 1)
135 		final Matrix gHiddenOutput = prepareMatrix(gMat.apply(hiddenOutput).getColumn(0)); // nHiddenLayers
136 																							// +
137 																							// 1
138 																							// x
139 																							// nInputs
140 																							// (usually
141 																							// 3x1)
142 		final Matrix gPrimeHiddenOutput = prepareMatrix(gPrimeMat.apply(hiddenOutput).getColumn(0)); // nHiddenLayers
143 																										// +
144 																										// 1
145 																										// x
146 																										// nInputs
147 																										// (usually
148 																										// 3x1)
149 		final Matrix finalOutput = weightsL2.transpose().times(gHiddenOutput);
150 		final Matrix finalOutputGPrime = gPrimeMat.apply(finalOutput); // nFinalLayers
151 																		// x
152 																		// nInputs
153 																		// (usually
154 																		// 1x1)
155 
156 		final Matrix errmat = Y.minus(finalOutput);
157 		final double err = errmat.sumOfColumns().sum();
158 
159 		Matrix dL2 = finalOutputGPrime.times(gHiddenOutput.transpose()).scale(err * LEARNRATE).transpose(); // should
160 																											// be
161 																											// nHiddenLayers
162 																											// +
163 																											// 1
164 																											// x
165 																											// nInputs
166 																											// (3
167 																											// x
168 																											// 1)
169 		Matrix dL1 = finalOutputGPrime.times(weightsL2.transpose().times(gPrimeHiddenOutput).times(X.transpose()))
170 				.scale(err * LEARNRATE).transpose();
171 
172 		dL1 = repmat(dL1, 1, weightsL1.getNumColumns());
173 		dL2 = repmat(dL2, 1, weightsL2.getNumColumns());
174 
175 		this.weightsL1.plusEquals(dL1);
176 		this.weightsL2.plusEquals(dL2);
177 
178 	}
179 
180 	private Matrix repmat(Matrix dL1, int nRows, int nCols) {
181 		final Matrix out = DMF.createMatrix(nRows * dL1.getNumRows(), nCols * dL1.getNumColumns());
182 		for (int i = 0; i < nRows; i++) {
183 			for (int j = 0; j < nCols; j++) {
184 				out.setSubMatrix(i * dL1.getNumRows(), j * dL1.getNumColumns(), dL1);
185 			}
186 		}
187 		return out;
188 	}
189 
190 	public Matrix predict(double[] x) {
191 		final Matrix X = prepareMatrix(x);
192 
193 		final Matrix hiddenTimes = weightsL1.transpose().times(X);
194 		final Matrix hiddenVal = prepareMatrix(gMat.apply(hiddenTimes).getColumn(0));
195 		final Matrix finalTimes = weightsL2.transpose().times(hiddenVal);
196 		final Matrix finalVal = gMat.apply(finalTimes);
197 
198 		return finalVal;
199 
200 	}
201 
202 	private Matrix prepareMatrix(Vector y) {
203 		final Matrix Y = DMF.createMatrix(1, y.getDimensionality() + 1);
204 		Y.setElement(0, 0, 1);
205 		Y.setSubMatrix(0, 1, DMF.copyRowVectors(y));
206 		return Y.transpose();
207 	}
208 
209 	private Matrix prepareMatrix(double[] y) {
210 		final Matrix Y = DMF.createMatrix(1, y.length + 1);
211 		Y.setElement(0, 0, 1);
212 		Y.setSubMatrix(0, 1, DMF.copyArray(new double[][] { y }));
213 		return Y.transpose();
214 	}
215 
216 	public static void main(String[] args) throws InterruptedException {
217 		final OnlineBackpropOneHidden bp = new OnlineBackpropOneHidden(2, 2, 1);
218 		FImage img = new FImage(200, 200);
219 		img = imagePredict(bp, img);
220 		final ColourMap m = ColourMap.Hot;
221 
222 		DisplayUtilities.displayName(m.apply(img), "xor");
223 		final int npixels = img.width * img.height;
224 		final int half = img.width / 2;
225 		final int[] pixels = RandomData.getUniqueRandomInts(npixels, 0, npixels);
226 		while (true) {
227 			// for (int i = 0; i < pixels.length; i++) {
228 			// int pixel = pixels[i];
229 			// int y = pixel / img.width;
230 			// int x = pixel - (y * img.width);
231 			// bp.update(new double[]{x < half ? -1 : 1,y < half ? -1 : 1},new
232 			// double[]{xorValue(half,x,y)});
233 			// // Thread.sleep(5);
234 			// }
235 			bp.update(new double[] { 0, 0 }, new double[] { 0 });
236 			bp.update(new double[] { 1, 1 }, new double[] { 0 });
237 			bp.update(new double[] { 0, 1 }, new double[] { 1 });
238 			bp.update(new double[] { 1, 0 }, new double[] { 1 });
239 			imagePredict(bp, img);
240 			DisplayUtilities.displayName(m.apply(img), "xor");
241 		}
242 	}
243 
244 	private static FImage imagePredict(OnlineBackpropOneHidden bp, FImage img) {
245 		final double[] pos = new double[2];
246 		final int half = img.width / 2;
247 		for (int y = 0; y < img.height; y++) {
248 			for (int x = 0; x < img.width; x++) {
249 				pos[0] = x < half ? 0 : 1;
250 				pos[1] = y < half ? 0 : 1;
251 				final float ret = (float) bp.predict(pos).getElement(0, 0);
252 				img.pixels[y][x] = ret;
253 			}
254 		}
255 		return img;
256 	}
257 }