001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.ml.neuralnet; 031 032import gnu.trove.map.hash.TIntIntHashMap; 033import gnu.trove.map.hash.TIntObjectHashMap; 034import gnu.trove.procedure.TIntObjectProcedure; 035import gov.sandia.cognition.io.CSVUtility; 036 037import java.io.BufferedReader; 038import java.io.File; 039import java.io.FileWriter; 040import java.io.IOException; 041import java.io.InputStream; 042import java.io.InputStreamReader; 043import java.io.PrintWriter; 044import java.util.ArrayList; 045import java.util.Iterator; 046import java.util.List; 047 048import org.encog.engine.network.activation.ActivationStep; 049import org.encog.mathutil.rbf.RBFEnum; 050import org.encog.ml.MLRegression; 051import org.encog.ml.data.MLDataPair; 052import org.encog.ml.data.MLDataSet; 053import org.encog.ml.data.specific.CSVNeuralDataSet; 054import org.encog.ml.svm.SVM; 055import org.encog.ml.svm.training.SVMTrain; 056import org.encog.ml.train.MLTrain; 057import org.encog.neural.cpn.CPN; 058import org.encog.neural.cpn.training.TrainInstar; 059import org.encog.neural.cpn.training.TrainOutstar; 060import org.encog.neural.data.basic.BasicNeuralData; 061import org.encog.neural.neat.NEATPopulation; 062import org.encog.neural.neat.training.NEATTraining; 063import org.encog.neural.networks.training.CalculateScore; 064import org.encog.neural.networks.training.TrainingSetScore; 065import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation; 066import org.encog.neural.rbf.RBFNetwork; 067import org.encog.util.simple.EncogUtility; 068import org.openimaj.util.pair.IndependentPair; 069 070import Jama.Matrix; 071 072/** 073 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 074 * 075 * Just some experiments using the sandia cognitive foundary neural 076 * nets. 077 * 078 */ 079public class HandWritingNeuralNetENCOG { 080 081 /** 082 * Default location of inputs 083 */ 084 public static final String INPUT_LOCATION = "/org/openimaj/ml/handwriting/inputouput.csv"; 085 086 private int maxTests = 10; 087 088 private TIntIntHashMap examples; 089 private TIntObjectHashMap<List<IndependentPair<double[], double[]>>> tests; 090 091 private int totalTests = 0; 092 093 private MLRegression network; 094 095 private MLDataSet training; 096 097 /** 098 * @throws IOException 099 * Load X input and y output from {@link #INPUT_LOCATION} 100 */ 101 public HandWritingNeuralNetENCOG() throws IOException { 102 103 examples = new TIntIntHashMap(); 104 this.tests = new TIntObjectHashMap<List<IndependentPair<double[], double[]>>>(); 105 prepareDataCollection(); 106 learnNeuralNet(); 107 testNeuralNet(); 108 // new HandWritingInputDisplay(this.training); 109 } 110 111 private void testNeuralNet() { 112 final double[][] xVals = new double[totalTests][]; 113 final int[] yVals = new int[totalTests]; 114 this.tests.forEachEntry(new TIntObjectProcedure<List<IndependentPair<double[], double[]>>>() { 115 int done = 0; 116 117 @Override 118 public boolean execute(int number, List<IndependentPair<double[], double[]>> xypairs) { 119 for (final IndependentPair<double[], double[]> xyval : xypairs) { 120 final double[] guessed = network.compute(new BasicNeuralData(xyval.firstObject())).getData(); // estimate 121 int maxIndex = 0; 122 double maxValue = 0; 123 for (int i = 0; i < guessed.length; i++) { 124 if (maxValue < guessed[i]) 125 { 126 maxValue = guessed[i]; 127 maxIndex = i; 128 } 129 } 130 xVals[done] = xyval.firstObject(); 131 yVals[done] = (maxIndex + 1) % 10; 132 done++; 133 } 134 return true; 135 } 136 }); 137 new HandWritingInputDisplay(xVals, yVals); 138 } 139 140 private void prepareDataCollection() throws IOException { 141 final File tmp = File.createTempFile("data", ".csv"); 142 final InputStream stream = HandWritingNeuralNetENCOG.class.getResourceAsStream(INPUT_LOCATION); 143 final BufferedReader reader = new BufferedReader(new InputStreamReader(stream)); 144 String line = null; 145 final PrintWriter writer = new PrintWriter(new FileWriter(tmp)); 146 while ((line = reader.readLine()) != null) { 147 writer.println(line); 148 } 149 writer.close(); 150 reader.close(); 151 training = new CSVNeuralDataSet(tmp.getAbsolutePath(), 400, 10, false); 152 final Iterator<MLDataPair> elementItr = this.training.iterator(); 153 for (; elementItr.hasNext();) { 154 final MLDataPair type = elementItr.next(); 155 final double[] yData = type.getIdealArray(); 156 final double[] xData = type.getInputArray(); 157 int yIndex = 0; 158 while (yData[yIndex] != 1) 159 yIndex++; 160 final int currentCount = this.examples.adjustOrPutValue(yIndex, 1, 1); 161 if (currentCount < this.maxTests) { 162 163 List<IndependentPair<double[], double[]>> numberTest = this.tests.get(yIndex); 164 if (numberTest == null) { 165 this.tests.put(yIndex, numberTest = new ArrayList<IndependentPair<double[], double[]>>()); 166 } 167 numberTest.add(IndependentPair.pair(xData, yData)); 168 totalTests++; 169 } 170 } 171 172 } 173 174 private void learnNeuralNet() { 175 // this.network = EncogUtility.simpleFeedForward(400, 100, 0, 10, 176 // false); 177 // MLTrain train = new Backpropagation(this.network, this.training); 178 // MLTrain train = new ResilientPropagation(this.network, 179 // this.training); 180 181 // this.network = withNEAT(); 182 // this.network = withRBF(); 183 // this.network = withSVM(); 184 this.network = withResilieant(); 185 // this.network = withCPN(); 186 } 187 188 private MLRegression withNEAT() { 189 final NEATPopulation pop = new NEATPopulation(400, 10, 1000); 190 final CalculateScore score = new TrainingSetScore(this.training); 191 // train the neural network 192 final ActivationStep step = new ActivationStep(); 193 step.setCenter(0.5); 194 pop.setOutputActivationFunction(step); 195 final MLTrain train = new NEATTraining(score, pop); 196 EncogUtility.trainToError(train, 0.01515); 197 return (MLRegression) train.getMethod(); 198 } 199 200 private MLRegression withResilieant() { 201 final MLTrain train = new ResilientPropagation(EncogUtility.simpleFeedForward(400, 100, 0, 10, false), 202 this.training); 203 EncogUtility.trainToError(train, 0.01515); 204 return (MLRegression) train.getMethod(); 205 } 206 207 private MLRegression withSVM() { 208 final MLTrain train = new SVMTrain(new SVM(400, true), this.training); 209 EncogUtility.trainToError(train, 0.01515); 210 return (MLRegression) train.getMethod(); 211 } 212 213 private MLRegression withRBF() { 214 final MLRegression train = new RBFNetwork(400, 20, 10, RBFEnum.Gaussian); 215 EncogUtility.trainToError(train, this.training, 0.01515); 216 return train; 217 } 218 219 private MLRegression withCPN() { 220 final CPN result = new CPN(400, 1000, 10, 1); 221 final MLTrain trainInstar = new TrainInstar(result, training, 0.1, false); 222 EncogUtility.trainToError(trainInstar, 0.01515); 223 final MLTrain trainOutstar = new TrainOutstar(result, training, 0.1); 224 EncogUtility.trainToError(trainOutstar, 0.01515); 225 return result; 226 } 227 228 private static <T> ArrayList<T> toArrayList(T[] values) { 229 final ArrayList<T> configList = new ArrayList<T>(); 230 for (final T t : values) { 231 configList.add(t); 232 } 233 return configList; 234 } 235 236 private Matrix fromCSV(BufferedReader bufferedReader, int nLines) throws IOException { 237 238 String[] lineValues = null; 239 double[][] outArr = null; 240 Matrix retMat = null; 241 int row = 0; 242 while ((lineValues = CSVUtility.nextNonEmptyLine(bufferedReader)) != null) { 243 if (outArr == null) { 244 retMat = new Matrix(nLines, lineValues.length); 245 outArr = retMat.getArray(); 246 } 247 248 for (int col = 0; col < lineValues.length; col++) { 249 outArr[row][col] = Double.parseDouble(lineValues[col]); 250 } 251 row++; 252 } 253 return retMat; 254 } 255 256 public static void main(String[] args) throws IOException { 257 new HandWritingNeuralNetENCOG(); 258 } 259}