001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
029 */ 030package org.openimaj.ml.neuralnet; 031 032import gnu.trove.map.hash.TIntIntHashMap; 033import gnu.trove.map.hash.TIntObjectHashMap; 034import gnu.trove.procedure.TIntObjectProcedure; 035import gov.sandia.cognition.algorithm.IterativeAlgorithm; 036import gov.sandia.cognition.algorithm.IterativeAlgorithmListener; 037import gov.sandia.cognition.io.CSVUtility; 038import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable; 039import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizerLiuStorey; 040import gov.sandia.cognition.learning.algorithm.regression.ParameterDifferentiableCostMinimizer; 041import gov.sandia.cognition.learning.data.DefaultInputOutputPair; 042import gov.sandia.cognition.learning.data.InputOutputPair; 043import gov.sandia.cognition.learning.function.scalar.AtanFunction; 044import gov.sandia.cognition.learning.function.vector.ThreeLayerFeedforwardNeuralNetwork; 045import gov.sandia.cognition.math.DifferentiableUnivariateScalarFunction; 046import gov.sandia.cognition.math.matrix.Vector; 047import gov.sandia.cognition.math.matrix.VectorEntry; 048import gov.sandia.cognition.math.matrix.VectorFactory; 049import gov.sandia.cognition.math.matrix.mtj.DenseVectorFactoryMTJ; 050 051import java.io.BufferedReader; 052import java.io.IOException; 053import java.io.InputStreamReader; 054import java.util.ArrayList; 055import java.util.List; 056 057import org.openimaj.util.pair.IndependentPair; 058 059import Jama.Matrix; 060 061/** 062 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 063 * 064 * Just some experiments using the Sandia Cognitive Foundry neural 065 * nets.
066 * 067 */ 068public class HandWritingNeuralNetSANDIA implements IterativeAlgorithmListener { 069 070 /** 071 * Default location of inputs 072 */ 073 public static final String INPUT_LOCATION = "/org/openimaj/ml/handwriting/inputs.csv"; 074 075 /** 076 * Default location of outputs 077 */ 078 public static final String OUTPUT_LOCATION = "/org/openimaj/ml/handwriting/outputs.csv"; 079 080 private Matrix xVals; 081 082 private Matrix yVals; 083 084 private ArrayList<InputOutputPair<Vector, Vector>> dataCollection; 085 086 private int maxExamples = 400; 087 private int maxTests = 10; 088 private int nHiddenLayer = 20; 089 090 private TIntIntHashMap examples; 091 private TIntObjectHashMap<List<IndependentPair<Vector, Vector>>> tests; 092 093 private GradientDescendable neuralNet; 094 095 private int totalTests = 0; 096 097 /** 098 * @throws IOException 099 * Load X input and y output from {@link #INPUT_LOCATION} and 100 * {@link #OUTPUT_LOCATION} 101 */ 102 public HandWritingNeuralNetSANDIA() throws IOException { 103 final BufferedReader xReader = new BufferedReader(new InputStreamReader( 104 HandWritingNeuralNetSANDIA.class.getResourceAsStream(INPUT_LOCATION))); 105 final BufferedReader yReader = new BufferedReader(new InputStreamReader( 106 HandWritingNeuralNetSANDIA.class.getResourceAsStream(OUTPUT_LOCATION))); 107 this.xVals = fromCSV(xReader, 5000); 108 this.yVals = fromCSV(yReader, 5000); 109 110 examples = new TIntIntHashMap(); 111 this.tests = new TIntObjectHashMap<List<IndependentPair<Vector, Vector>>>(); 112 prepareDataCollection(); 113 learnNeuralNet(); 114 testNeuralNet(); 115 // new HandWritingInputDisplay(xVals); 116 } 117 118 private void testNeuralNet() { 119 final double[][] xVals = new double[totalTests][]; 120 final int[] yVals = new int[totalTests]; 121 this.tests.forEachEntry(new TIntObjectProcedure<List<IndependentPair<Vector, Vector>>>() { 122 int done = 0; 123 DenseVectorFactoryMTJ fact = new DenseVectorFactoryMTJ(); 124 125 @Override 126 
public boolean execute(int number, List<IndependentPair<Vector, Vector>> xypairs) { 127 for (final IndependentPair<Vector, Vector> xyval : xypairs) { 128 final Vector guessed = neuralNet.evaluate(xyval.firstObject()); 129 int maxIndex = 0; 130 double maxValue = 0; 131 for (final VectorEntry vectorEntry : guessed) { 132 if (maxValue < vectorEntry.getValue()) 133 { 134 maxValue = vectorEntry.getValue(); 135 maxIndex = vectorEntry.getIndex(); 136 } 137 } 138 xVals[done] = fact.copyVector(xyval.firstObject()).getArray(); 139 yVals[done] = maxIndex; 140 done++; 141 } 142 return true; 143 } 144 }); 145 new HandWritingInputDisplay(xVals, yVals); 146 } 147 148 private void prepareDataCollection() { 149 this.dataCollection = new ArrayList<InputOutputPair<Vector, Vector>>(); 150 final double[][] xArr = this.xVals.getArray(); 151 final double[][] yArr = this.yVals.getArray(); 152 153 for (int i = 0; i < xArr.length; i++) { 154 final Vector xVector = VectorFactory.getDefault().copyArray(xArr[i]); 155 final double[] yValues = new double[10]; 156 final int number = (int) (yArr[i][0] % 10); 157 final int count = examples.adjustOrPutValue(number, 1, 1); 158 yValues[number] = 1; 159 final Vector yVector = VectorFactory.getDefault().copyValues(yValues); 160 if (this.maxExamples != -1 && count > maxExamples) { 161 if (count > maxTests + maxExamples) { 162 continue; 163 } 164 List<IndependentPair<Vector, Vector>> numberTest = this.tests.get(number); 165 if (numberTest == null) { 166 this.tests.put(number, numberTest = new ArrayList<IndependentPair<Vector, Vector>>()); 167 } 168 numberTest.add(IndependentPair.pair(xVector, yVector)); 169 totalTests++; 170 } 171 else { 172 this.dataCollection.add(DefaultInputOutputPair.create(xVector, yVector)); 173 } 174 175 } 176 } 177 178 private void learnNeuralNet() { 179 // ArrayList<Integer> nodesPerLayer = toArrayList( 180 // new 181 // Integer[]{this.xVals.getColumnDimension(),this.xVals.getColumnDimension()/4,10} 182 // ); 183 // 
ArrayList<DifferentiableUnivariateScalarFunction> squashFunctions = 184 // toArrayList( 185 // new DifferentiableUnivariateScalarFunction[]{new 186 // SigmoidFunction(),new SigmoidFunction()} 187 // ); 188 final ArrayList<Integer> nodesPerLayer = toArrayList( 189 new Integer[] { this.xVals.getColumnDimension(), nHiddenLayer, 10 } 190 ); 191 final ArrayList<DifferentiableUnivariateScalarFunction> squashFunctions = toArrayList( 192 new DifferentiableUnivariateScalarFunction[] { new AtanFunction(), new AtanFunction() } 193 ); 194 // DifferentiableFeedforwardNeuralNetwork nn = new 195 // DifferentiableFeedforwardNeuralNetwork( 196 // nodesPerLayer, 197 // squashFunctions, 198 // new Random() 199 // ); 200 final ThreeLayerFeedforwardNeuralNetwork nn = new ThreeLayerFeedforwardNeuralNetwork( 201 this.xVals.getColumnDimension(), nHiddenLayer, 10); 202 final ParameterDifferentiableCostMinimizer conjugateGradient = new ParameterDifferentiableCostMinimizer( 203 new FunctionMinimizerLiuStorey()); 204 conjugateGradient.setObjectToOptimize(nn); 205 // conjugateGradient.setCostFunction( new MeanSquaredErrorCostFunction() 206 // ); 207 conjugateGradient.addIterativeAlgorithmListener(this); 208 conjugateGradient.setMaxIterations(50); 209 // FletcherXuHybridEstimation minimiser = new 210 // FletcherXuHybridEstimation(); 211 // minimiser.setObjectToOptimize( nn ); 212 // minimiser.setMaxIterations(50); 213 neuralNet = conjugateGradient.learn(this.dataCollection); 214 } 215 216 private static <T> ArrayList<T> toArrayList(T[] values) { 217 final ArrayList<T> configList = new ArrayList<T>(); 218 for (final T t : values) { 219 configList.add(t); 220 } 221 return configList; 222 } 223 224 private Matrix fromCSV(BufferedReader bufferedReader, int nLines) throws IOException { 225 226 String[] lineValues = null; 227 double[][] outArr = null; 228 Matrix retMat = null; 229 int row = 0; 230 while ((lineValues = CSVUtility.nextNonEmptyLine(bufferedReader)) != null) { 231 if (outArr == null) { 
232 retMat = new Matrix(nLines, lineValues.length); 233 outArr = retMat.getArray(); 234 } 235 236 for (int col = 0; col < lineValues.length; col++) { 237 outArr[row][col] = Double.parseDouble(lineValues[col]); 238 } 239 row++; 240 } 241 return retMat; 242 } 243 244 public static void main(String[] args) throws IOException { 245 new HandWritingNeuralNetSANDIA(); 246 } 247 248 @Override 249 public void algorithmStarted(IterativeAlgorithm algorithm) { 250 System.out.println("Learning neural network"); 251 } 252 253 @Override 254 public void algorithmEnded(IterativeAlgorithm algorithm) { 255 System.out.println("Done Learning!"); 256 } 257 258 @Override 259 public void stepStarted(IterativeAlgorithm algorithm) { 260 System.out.println("... starting step: " + algorithm.getIteration()); 261 } 262 263 @Override 264 public void stepEnded(IterativeAlgorithm algorithm) { 265 System.out.println("... ending step: " + algorithm.getIteration()); 266 } 267}