View Javadoc

1   /**
2    * Copyright (c) 2011, The University of Southampton and the individual contributors.
3    * All rights reserved.
4    *
5    * Redistribution and use in source and binary forms, with or without modification,
6    * are permitted provided that the following conditions are met:
7    *
8    *   * 	Redistributions of source code must retain the above copyright notice,
9    * 	this list of conditions and the following disclaimer.
10   *
11   *   *	Redistributions in binary form must reproduce the above copyright notice,
12   * 	this list of conditions and the following disclaimer in the documentation
13   * 	and/or other materials provided with the distribution.
14   *
15   *   *	Neither the name of the University of Southampton nor the names of its
16   * 	contributors may be used to endorse or promote products derived from this
17   * 	software without specific prior written permission.
18   *
19   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21   * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22   * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
23   * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24   * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26   * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27   * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28   * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29   */
30  package org.openimaj.ml.neuralnet;
31  
32  import gnu.trove.map.hash.TIntIntHashMap;
33  import gnu.trove.map.hash.TIntObjectHashMap;
34  import gnu.trove.procedure.TIntObjectProcedure;
35  import gov.sandia.cognition.io.CSVUtility;
36  
37  import java.io.BufferedReader;
38  import java.io.File;
39  import java.io.FileWriter;
40  import java.io.IOException;
41  import java.io.InputStream;
42  import java.io.InputStreamReader;
43  import java.io.PrintWriter;
44  import java.util.ArrayList;
45  import java.util.Iterator;
46  import java.util.List;
47  
48  import org.encog.engine.network.activation.ActivationStep;
49  import org.encog.mathutil.rbf.RBFEnum;
50  import org.encog.ml.MLRegression;
51  import org.encog.ml.data.MLDataPair;
52  import org.encog.ml.data.MLDataSet;
53  import org.encog.ml.data.specific.CSVNeuralDataSet;
54  import org.encog.ml.svm.SVM;
55  import org.encog.ml.svm.training.SVMTrain;
56  import org.encog.ml.train.MLTrain;
57  import org.encog.neural.cpn.CPN;
58  import org.encog.neural.cpn.training.TrainInstar;
59  import org.encog.neural.cpn.training.TrainOutstar;
60  import org.encog.neural.data.basic.BasicNeuralData;
61  import org.encog.neural.neat.NEATPopulation;
62  import org.encog.neural.neat.training.NEATTraining;
63  import org.encog.neural.networks.training.CalculateScore;
64  import org.encog.neural.networks.training.TrainingSetScore;
65  import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
66  import org.encog.neural.rbf.RBFNetwork;
67  import org.encog.util.simple.EncogUtility;
68  import org.openimaj.util.pair.IndependentPair;
69  
70  import Jama.Matrix;
71  
72  /**
73   * @author Sina Samangooei (ss@ecs.soton.ac.uk)
74   * 
75   *         Just some experiments using the sandia cognitive foundary neural
76   *         nets.
77   * 
78   */
79  public class HandWritingNeuralNetENCOG {
80  
81  	/**
82  	 * Default location of inputs
83  	 */
84  	public static final String INPUT_LOCATION = "/org/openimaj/ml/handwriting/inputouput.csv";
85  
86  	private int maxTests = 10;
87  
88  	private TIntIntHashMap examples;
89  	private TIntObjectHashMap<List<IndependentPair<double[], double[]>>> tests;
90  
91  	private int totalTests = 0;
92  
93  	private MLRegression network;
94  
95  	private MLDataSet training;
96  
97  	/**
98  	 * @throws IOException
99  	 *             Load X input and y output from {@link #INPUT_LOCATION}
100 	 */
101 	public HandWritingNeuralNetENCOG() throws IOException {
102 
103 		examples = new TIntIntHashMap();
104 		this.tests = new TIntObjectHashMap<List<IndependentPair<double[], double[]>>>();
105 		prepareDataCollection();
106 		learnNeuralNet();
107 		testNeuralNet();
108 		// new HandWritingInputDisplay(this.training);
109 	}
110 
111 	private void testNeuralNet() {
112 		final double[][] xVals = new double[totalTests][];
113 		final int[] yVals = new int[totalTests];
114 		this.tests.forEachEntry(new TIntObjectProcedure<List<IndependentPair<double[], double[]>>>() {
115 			int done = 0;
116 
117 			@Override
118 			public boolean execute(int number, List<IndependentPair<double[], double[]>> xypairs) {
119 				for (final IndependentPair<double[], double[]> xyval : xypairs) {
120 					final double[] guessed = network.compute(new BasicNeuralData(xyval.firstObject())).getData(); // estimate
121 					int maxIndex = 0;
122 					double maxValue = 0;
123 					for (int i = 0; i < guessed.length; i++) {
124 						if (maxValue < guessed[i])
125 						{
126 							maxValue = guessed[i];
127 							maxIndex = i;
128 						}
129 					}
130 					xVals[done] = xyval.firstObject();
131 					yVals[done] = (maxIndex + 1) % 10;
132 					done++;
133 				}
134 				return true;
135 			}
136 		});
137 		new HandWritingInputDisplay(xVals, yVals);
138 	}
139 
140 	private void prepareDataCollection() throws IOException {
141 		final File tmp = File.createTempFile("data", ".csv");
142 		final InputStream stream = HandWritingNeuralNetENCOG.class.getResourceAsStream(INPUT_LOCATION);
143 		final BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
144 		String line = null;
145 		final PrintWriter writer = new PrintWriter(new FileWriter(tmp));
146 		while ((line = reader.readLine()) != null) {
147 			writer.println(line);
148 		}
149 		writer.close();
150 		reader.close();
151 		training = new CSVNeuralDataSet(tmp.getAbsolutePath(), 400, 10, false);
152 		final Iterator<MLDataPair> elementItr = this.training.iterator();
153 		for (; elementItr.hasNext();) {
154 			final MLDataPair type = elementItr.next();
155 			final double[] yData = type.getIdealArray();
156 			final double[] xData = type.getInputArray();
157 			int yIndex = 0;
158 			while (yData[yIndex] != 1)
159 				yIndex++;
160 			final int currentCount = this.examples.adjustOrPutValue(yIndex, 1, 1);
161 			if (currentCount < this.maxTests) {
162 
163 				List<IndependentPair<double[], double[]>> numberTest = this.tests.get(yIndex);
164 				if (numberTest == null) {
165 					this.tests.put(yIndex, numberTest = new ArrayList<IndependentPair<double[], double[]>>());
166 				}
167 				numberTest.add(IndependentPair.pair(xData, yData));
168 				totalTests++;
169 			}
170 		}
171 
172 	}
173 
174 	private void learnNeuralNet() {
175 		// this.network = EncogUtility.simpleFeedForward(400, 100, 0, 10,
176 		// false);
177 		// MLTrain train = new Backpropagation(this.network, this.training);
178 		// MLTrain train = new ResilientPropagation(this.network,
179 		// this.training);
180 
181 		// this.network = withNEAT();
182 		// this.network = withRBF();
183 		// this.network = withSVM();
184 		this.network = withResilieant();
185 		// this.network = withCPN();
186 	}
187 
188 	private MLRegression withNEAT() {
189 		final NEATPopulation pop = new NEATPopulation(400, 10, 1000);
190 		final CalculateScore score = new TrainingSetScore(this.training);
191 		// train the neural network
192 		final ActivationStep step = new ActivationStep();
193 		step.setCenter(0.5);
194 		pop.setOutputActivationFunction(step);
195 		final MLTrain train = new NEATTraining(score, pop);
196 		EncogUtility.trainToError(train, 0.01515);
197 		return (MLRegression) train.getMethod();
198 	}
199 
200 	private MLRegression withResilieant() {
201 		final MLTrain train = new ResilientPropagation(EncogUtility.simpleFeedForward(400, 100, 0, 10, false),
202 				this.training);
203 		EncogUtility.trainToError(train, 0.01515);
204 		return (MLRegression) train.getMethod();
205 	}
206 
207 	private MLRegression withSVM() {
208 		final MLTrain train = new SVMTrain(new SVM(400, true), this.training);
209 		EncogUtility.trainToError(train, 0.01515);
210 		return (MLRegression) train.getMethod();
211 	}
212 
213 	private MLRegression withRBF() {
214 		final MLRegression train = new RBFNetwork(400, 20, 10, RBFEnum.Gaussian);
215 		EncogUtility.trainToError(train, this.training, 0.01515);
216 		return train;
217 	}
218 
219 	private MLRegression withCPN() {
220 		final CPN result = new CPN(400, 1000, 10, 1);
221 		final MLTrain trainInstar = new TrainInstar(result, training, 0.1, false);
222 		EncogUtility.trainToError(trainInstar, 0.01515);
223 		final MLTrain trainOutstar = new TrainOutstar(result, training, 0.1);
224 		EncogUtility.trainToError(trainOutstar, 0.01515);
225 		return result;
226 	}
227 
228 	private static <T> ArrayList<T> toArrayList(T[] values) {
229 		final ArrayList<T> configList = new ArrayList<T>();
230 		for (final T t : values) {
231 			configList.add(t);
232 		}
233 		return configList;
234 	}
235 
236 	private Matrix fromCSV(BufferedReader bufferedReader, int nLines) throws IOException {
237 
238 		String[] lineValues = null;
239 		double[][] outArr = null;
240 		Matrix retMat = null;
241 		int row = 0;
242 		while ((lineValues = CSVUtility.nextNonEmptyLine(bufferedReader)) != null) {
243 			if (outArr == null) {
244 				retMat = new Matrix(nLines, lineValues.length);
245 				outArr = retMat.getArray();
246 			}
247 
248 			for (int col = 0; col < lineValues.length; col++) {
249 				outArr[row][col] = Double.parseDouble(lineValues[col]);
250 			}
251 			row++;
252 		}
253 		return retMat;
254 	}
255 
256 	public static void main(String[] args) throws IOException {
257 		new HandWritingNeuralNetENCOG();
258 	}
259 }