001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.neuralnet;
031
032import gov.sandia.cognition.math.matrix.Matrix;
033import gov.sandia.cognition.math.matrix.MatrixFactory;
034import gov.sandia.cognition.math.matrix.Vector;
035import gov.sandia.cognition.math.matrix.mtj.DenseMatrixFactoryMTJ;
036
037import org.openimaj.data.RandomData;
038import org.openimaj.image.DisplayUtilities;
039import org.openimaj.image.FImage;
040import org.openimaj.image.colour.ColourMap;
041import org.openimaj.util.function.Function;
042
043/**
044 * Implement an online version of the backprop algorithm against an 2D
045 * 
046 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
047 *
048 */
049public class OnlineBackpropOneHidden {
050
051        private static final double LEARNRATE = 0.005;
052        private Matrix weightsL1;
053        private Matrix weightsL2;
054        MatrixFactory<? extends Matrix> DMF = DenseMatrixFactoryMTJ.getDenseDefault();
055        private Function<Double, Double> g;
056        private Function<Matrix, Matrix> gMat;
057        private Function<Double, Double> gPrime;
058        private Function<Matrix, Matrix> gPrimeMat;
059
060        /**
061         * @param nInput
062         *            the number of input values
063         * @param nHidden
064         *            the number of hidden values
065         * @param nFinal
066         *            the number of final values
067         */
068        public OnlineBackpropOneHidden(int nInput, int nHidden, int nFinal) {
069                final double[][] weightsL1dat = RandomData.getRandomDoubleArray(nInput + 1, nHidden, -1, 1.);
070                final double[][] weightsL2dat = RandomData.getRandomDoubleArray(nHidden + 1, nFinal, -1, 1.);
071
072                weightsL1 = DMF.copyArray(weightsL1dat);
073                weightsL2 = DMF.copyArray(weightsL2dat);
074                ;
075
076                g = new Function<Double, Double>() {
077
078                        @Override
079                        public Double apply(Double in) {
080
081                                return 1. / (1 + Math.exp(-in));
082                        }
083
084                };
085
086                gPrime = new Function<Double, Double>() {
087
088                        @Override
089                        public Double apply(Double in) {
090
091                                return g.apply(in) * (1 - g.apply(in));
092                        }
093
094                };
095
096                gPrimeMat = new Function<Matrix, Matrix>() {
097
098                        @Override
099                        public Matrix apply(Matrix in) {
100                                final Matrix out = DMF.copyMatrix(in);
101                                for (int i = 0; i < in.getNumRows(); i++) {
102                                        for (int j = 0; j < in.getNumColumns(); j++) {
103                                                out.setElement(i, j, gPrime.apply(in.getElement(i, j)));
104                                        }
105                                }
106                                return out;
107                        }
108
109                };
110
111                gMat = new Function<Matrix, Matrix>() {
112
113                        @Override
114                        public Matrix apply(Matrix in) {
115                                final Matrix out = DMF.copyMatrix(in);
116                                for (int i = 0; i < in.getNumRows(); i++) {
117                                        for (int j = 0; j < in.getNumColumns(); j++) {
118                                                out.setElement(i, j, g.apply(in.getElement(i, j)));
119                                        }
120                                }
121                                return out;
122                        }
123
124                };
125        }
126
127        public void update(double[] x, double[] y) {
128                final Matrix X = prepareMatrix(x);
129                final Matrix Y = DMF.copyArray(new double[][] { y });
130
131                final Matrix hiddenOutput = weightsL1.transpose().times(X); // nHiddenLayers
132                                                                                                                                        // x nInputs
133                                                                                                                                        // (usually
134                                                                                                                                        // 2 x 1)
135                final Matrix gHiddenOutput = prepareMatrix(gMat.apply(hiddenOutput).getColumn(0)); // nHiddenLayers
136                                                                                                                                                                                        // +
137                                                                                                                                                                                        // 1
138                                                                                                                                                                                        // x
139                                                                                                                                                                                        // nInputs
140                                                                                                                                                                                        // (usually
141                                                                                                                                                                                        // 3x1)
142                final Matrix gPrimeHiddenOutput = prepareMatrix(gPrimeMat.apply(hiddenOutput).getColumn(0)); // nHiddenLayers
143                                                                                                                                                                                                                // +
144                                                                                                                                                                                                                // 1
145                                                                                                                                                                                                                // x
146                                                                                                                                                                                                                // nInputs
147                                                                                                                                                                                                                // (usually
148                                                                                                                                                                                                                // 3x1)
149                final Matrix finalOutput = weightsL2.transpose().times(gHiddenOutput);
150                final Matrix finalOutputGPrime = gPrimeMat.apply(finalOutput); // nFinalLayers
151                                                                                                                                                // x
152                                                                                                                                                // nInputs
153                                                                                                                                                // (usually
154                                                                                                                                                // 1x1)
155
156                final Matrix errmat = Y.minus(finalOutput);
157                final double err = errmat.sumOfColumns().sum();
158
159                Matrix dL2 = finalOutputGPrime.times(gHiddenOutput.transpose()).scale(err * LEARNRATE).transpose(); // should
160                                                                                                                                                                                                                        // be
161                                                                                                                                                                                                                        // nHiddenLayers
162                                                                                                                                                                                                                        // +
163                                                                                                                                                                                                                        // 1
164                                                                                                                                                                                                                        // x
165                                                                                                                                                                                                                        // nInputs
166                                                                                                                                                                                                                        // (3
167                                                                                                                                                                                                                        // x
168                                                                                                                                                                                                                        // 1)
169                Matrix dL1 = finalOutputGPrime.times(weightsL2.transpose().times(gPrimeHiddenOutput).times(X.transpose()))
170                                .scale(err * LEARNRATE).transpose();
171
172                dL1 = repmat(dL1, 1, weightsL1.getNumColumns());
173                dL2 = repmat(dL2, 1, weightsL2.getNumColumns());
174
175                this.weightsL1.plusEquals(dL1);
176                this.weightsL2.plusEquals(dL2);
177
178        }
179
180        private Matrix repmat(Matrix dL1, int nRows, int nCols) {
181                final Matrix out = DMF.createMatrix(nRows * dL1.getNumRows(), nCols * dL1.getNumColumns());
182                for (int i = 0; i < nRows; i++) {
183                        for (int j = 0; j < nCols; j++) {
184                                out.setSubMatrix(i * dL1.getNumRows(), j * dL1.getNumColumns(), dL1);
185                        }
186                }
187                return out;
188        }
189
190        public Matrix predict(double[] x) {
191                final Matrix X = prepareMatrix(x);
192
193                final Matrix hiddenTimes = weightsL1.transpose().times(X);
194                final Matrix hiddenVal = prepareMatrix(gMat.apply(hiddenTimes).getColumn(0));
195                final Matrix finalTimes = weightsL2.transpose().times(hiddenVal);
196                final Matrix finalVal = gMat.apply(finalTimes);
197
198                return finalVal;
199
200        }
201
202        private Matrix prepareMatrix(Vector y) {
203                final Matrix Y = DMF.createMatrix(1, y.getDimensionality() + 1);
204                Y.setElement(0, 0, 1);
205                Y.setSubMatrix(0, 1, DMF.copyRowVectors(y));
206                return Y.transpose();
207        }
208
209        private Matrix prepareMatrix(double[] y) {
210                final Matrix Y = DMF.createMatrix(1, y.length + 1);
211                Y.setElement(0, 0, 1);
212                Y.setSubMatrix(0, 1, DMF.copyArray(new double[][] { y }));
213                return Y.transpose();
214        }
215
216        public static void main(String[] args) throws InterruptedException {
217                final OnlineBackpropOneHidden bp = new OnlineBackpropOneHidden(2, 2, 1);
218                FImage img = new FImage(200, 200);
219                img = imagePredict(bp, img);
220                final ColourMap m = ColourMap.Hot;
221
222                DisplayUtilities.displayName(m.apply(img), "xor");
223                final int npixels = img.width * img.height;
224                final int half = img.width / 2;
225                final int[] pixels = RandomData.getUniqueRandomInts(npixels, 0, npixels);
226                while (true) {
227                        // for (int i = 0; i < pixels.length; i++) {
228                        // int pixel = pixels[i];
229                        // int y = pixel / img.width;
230                        // int x = pixel - (y * img.width);
231                        // bp.update(new double[]{x < half ? -1 : 1,y < half ? -1 : 1},new
232                        // double[]{xorValue(half,x,y)});
233                        // // Thread.sleep(5);
234                        // }
235                        bp.update(new double[] { 0, 0 }, new double[] { 0 });
236                        bp.update(new double[] { 1, 1 }, new double[] { 0 });
237                        bp.update(new double[] { 0, 1 }, new double[] { 1 });
238                        bp.update(new double[] { 1, 0 }, new double[] { 1 });
239                        imagePredict(bp, img);
240                        DisplayUtilities.displayName(m.apply(img), "xor");
241                }
242        }
243
244        private static FImage imagePredict(OnlineBackpropOneHidden bp, FImage img) {
245                final double[] pos = new double[2];
246                final int half = img.width / 2;
247                for (int y = 0; y < img.height; y++) {
248                        for (int x = 0; x < img.width; x++) {
249                                pos[0] = x < half ? 0 : 1;
250                                pos[1] = y < half ? 0 : 1;
251                                final float ret = (float) bp.predict(pos).getElement(0, 0);
252                                img.pixels[y][x] = ret;
253                        }
254                }
255                return img;
256        }
257}