001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.ml.neuralnet; 031 032import gov.sandia.cognition.math.matrix.Matrix; 033import gov.sandia.cognition.math.matrix.MatrixFactory; 034import gov.sandia.cognition.math.matrix.Vector; 035import gov.sandia.cognition.math.matrix.mtj.DenseMatrixFactoryMTJ; 036 037import org.openimaj.data.RandomData; 038import org.openimaj.image.DisplayUtilities; 039import org.openimaj.image.FImage; 040import org.openimaj.image.colour.ColourMap; 041import org.openimaj.util.function.Function; 042 043/** 044 * Implement an online version of the backprop algorithm against an 2D 045 * 046 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 047 * 048 */ 049public class OnlineBackpropOneHidden { 050 051 private static final double LEARNRATE = 0.005; 052 private Matrix weightsL1; 053 private Matrix weightsL2; 054 MatrixFactory<? extends Matrix> DMF = DenseMatrixFactoryMTJ.getDenseDefault(); 055 private Function<Double, Double> g; 056 private Function<Matrix, Matrix> gMat; 057 private Function<Double, Double> gPrime; 058 private Function<Matrix, Matrix> gPrimeMat; 059 060 /** 061 * @param nInput 062 * the number of input values 063 * @param nHidden 064 * the number of hidden values 065 * @param nFinal 066 * the number of final values 067 */ 068 public OnlineBackpropOneHidden(int nInput, int nHidden, int nFinal) { 069 final double[][] weightsL1dat = RandomData.getRandomDoubleArray(nInput + 1, nHidden, -1, 1.); 070 final double[][] weightsL2dat = RandomData.getRandomDoubleArray(nHidden + 1, nFinal, -1, 1.); 071 072 weightsL1 = DMF.copyArray(weightsL1dat); 073 weightsL2 = DMF.copyArray(weightsL2dat); 074 ; 075 076 g = new Function<Double, Double>() { 077 078 @Override 079 public Double apply(Double in) { 080 081 return 1. / (1 + Math.exp(-in)); 082 } 083 084 }; 085 086 gPrime = new Function<Double, Double>() { 087 088 @Override 089 public Double apply(Double in) { 090 091 return g.apply(in) * (1 - g.apply(in)); 092 } 093 094 }; 095 096 gPrimeMat = new Function<Matrix, Matrix>() { 097 098 @Override 099 public Matrix apply(Matrix in) { 100 final Matrix out = DMF.copyMatrix(in); 101 for (int i = 0; i < in.getNumRows(); i++) { 102 for (int j = 0; j < in.getNumColumns(); j++) { 103 out.setElement(i, j, gPrime.apply(in.getElement(i, j))); 104 } 105 } 106 return out; 107 } 108 109 }; 110 111 gMat = new Function<Matrix, Matrix>() { 112 113 @Override 114 public Matrix apply(Matrix in) { 115 final Matrix out = DMF.copyMatrix(in); 116 for (int i = 0; i < in.getNumRows(); i++) { 117 for (int j = 0; j < in.getNumColumns(); j++) { 118 out.setElement(i, j, g.apply(in.getElement(i, j))); 119 } 120 } 121 return out; 122 } 123 124 }; 125 } 126 127 public void update(double[] x, double[] y) { 128 final Matrix X = prepareMatrix(x); 129 final Matrix Y = DMF.copyArray(new double[][] { y }); 130 131 final Matrix hiddenOutput = weightsL1.transpose().times(X); // nHiddenLayers 132 // x nInputs 133 // (usually 134 // 2 x 1) 135 final Matrix gHiddenOutput = prepareMatrix(gMat.apply(hiddenOutput).getColumn(0)); // nHiddenLayers 136 // + 137 // 1 138 // x 139 // nInputs 140 // (usually 141 // 3x1) 142 final Matrix gPrimeHiddenOutput = prepareMatrix(gPrimeMat.apply(hiddenOutput).getColumn(0)); // nHiddenLayers 143 // + 144 // 1 145 // x 146 // nInputs 147 // (usually 148 // 3x1) 149 final Matrix finalOutput = weightsL2.transpose().times(gHiddenOutput); 150 final Matrix finalOutputGPrime = gPrimeMat.apply(finalOutput); // nFinalLayers 151 // x 152 // nInputs 153 // (usually 154 // 1x1) 155 156 final Matrix errmat = Y.minus(finalOutput); 157 final double err = errmat.sumOfColumns().sum(); 158 159 Matrix dL2 = finalOutputGPrime.times(gHiddenOutput.transpose()).scale(err * LEARNRATE).transpose(); // should 160 // be 161 // nHiddenLayers 162 // + 163 // 1 164 // x 165 // nInputs 166 // (3 167 // x 168 // 1) 169 Matrix dL1 = finalOutputGPrime.times(weightsL2.transpose().times(gPrimeHiddenOutput).times(X.transpose())) 170 .scale(err * LEARNRATE).transpose(); 171 172 dL1 = repmat(dL1, 1, weightsL1.getNumColumns()); 173 dL2 = repmat(dL2, 1, weightsL2.getNumColumns()); 174 175 this.weightsL1.plusEquals(dL1); 176 this.weightsL2.plusEquals(dL2); 177 178 } 179 180 private Matrix repmat(Matrix dL1, int nRows, int nCols) { 181 final Matrix out = DMF.createMatrix(nRows * dL1.getNumRows(), nCols * dL1.getNumColumns()); 182 for (int i = 0; i < nRows; i++) { 183 for (int j = 0; j < nCols; j++) { 184 out.setSubMatrix(i * dL1.getNumRows(), j * dL1.getNumColumns(), dL1); 185 } 186 } 187 return out; 188 } 189 190 public Matrix predict(double[] x) { 191 final Matrix X = prepareMatrix(x); 192 193 final Matrix hiddenTimes = weightsL1.transpose().times(X); 194 final Matrix hiddenVal = prepareMatrix(gMat.apply(hiddenTimes).getColumn(0)); 195 final Matrix finalTimes = weightsL2.transpose().times(hiddenVal); 196 final Matrix finalVal = gMat.apply(finalTimes); 197 198 return finalVal; 199 200 } 201 202 private Matrix prepareMatrix(Vector y) { 203 final Matrix Y = DMF.createMatrix(1, y.getDimensionality() + 1); 204 Y.setElement(0, 0, 1); 205 Y.setSubMatrix(0, 1, DMF.copyRowVectors(y)); 206 return Y.transpose(); 207 } 208 209 private Matrix prepareMatrix(double[] y) { 210 final Matrix Y = DMF.createMatrix(1, y.length + 1); 211 Y.setElement(0, 0, 1); 212 Y.setSubMatrix(0, 1, DMF.copyArray(new double[][] { y })); 213 return Y.transpose(); 214 } 215 216 public static void main(String[] args) throws InterruptedException { 217 final OnlineBackpropOneHidden bp = new OnlineBackpropOneHidden(2, 2, 1); 218 FImage img = new FImage(200, 200); 219 img = imagePredict(bp, img); 220 final ColourMap m = ColourMap.Hot; 221 222 DisplayUtilities.displayName(m.apply(img), "xor"); 223 final int npixels = img.width * img.height; 224 final int half = img.width / 2; 225 final int[] pixels = RandomData.getUniqueRandomInts(npixels, 0, npixels); 226 while (true) { 227 // for (int i = 0; i < pixels.length; i++) { 228 // int pixel = pixels[i]; 229 // int y = pixel / img.width; 230 // int x = pixel - (y * img.width); 231 // bp.update(new double[]{x < half ? -1 : 1,y < half ? -1 : 1},new 232 // double[]{xorValue(half,x,y)}); 233 // // Thread.sleep(5); 234 // } 235 bp.update(new double[] { 0, 0 }, new double[] { 0 }); 236 bp.update(new double[] { 1, 1 }, new double[] { 0 }); 237 bp.update(new double[] { 0, 1 }, new double[] { 1 }); 238 bp.update(new double[] { 1, 0 }, new double[] { 1 }); 239 imagePredict(bp, img); 240 DisplayUtilities.displayName(m.apply(img), "xor"); 241 } 242 } 243 244 private static FImage imagePredict(OnlineBackpropOneHidden bp, FImage img) { 245 final double[] pos = new double[2]; 246 final int half = img.width / 2; 247 for (int y = 0; y < img.height; y++) { 248 for (int x = 0; x < img.width; x++) { 249 pos[0] = x < half ? 0 : 1; 250 pos[1] = y < half ? 0 : 1; 251 final float ret = (float) bp.predict(pos).getElement(0, 0); 252 img.pixels[y][x] = ret; 253 } 254 } 255 return img; 256 } 257}