View Javadoc

1   /**
2    * Copyright (c) 2011, The University of Southampton and the individual contributors.
3    * All rights reserved.
4    *
5    * Redistribution and use in source and binary forms, with or without modification,
6    * are permitted provided that the following conditions are met:
7    *
8    *   * 	Redistributions of source code must retain the above copyright notice,
9    * 	this list of conditions and the following disclaimer.
10   *
11   *   *	Redistributions in binary form must reproduce the above copyright notice,
12   * 	this list of conditions and the following disclaimer in the documentation
13   * 	and/or other materials provided with the distribution.
14   *
15   *   *	Neither the name of the University of Southampton nor the names of its
16   * 	contributors may be used to endorse or promote products derived from this
17   * 	software without specific prior written permission.
18   *
19   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21   * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22   * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
23   * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24   * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26   * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27   * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28   * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29   */
30  package org.openimaj.ml.neuralnet;
31  
32  import gnu.trove.map.hash.TIntIntHashMap;
33  import gnu.trove.map.hash.TIntObjectHashMap;
34  import gnu.trove.procedure.TIntObjectProcedure;
35  import gov.sandia.cognition.algorithm.IterativeAlgorithm;
36  import gov.sandia.cognition.algorithm.IterativeAlgorithmListener;
37  import gov.sandia.cognition.io.CSVUtility;
38  import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
39  import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizerLiuStorey;
40  import gov.sandia.cognition.learning.algorithm.regression.ParameterDifferentiableCostMinimizer;
41  import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
42  import gov.sandia.cognition.learning.data.InputOutputPair;
43  import gov.sandia.cognition.learning.function.scalar.AtanFunction;
44  import gov.sandia.cognition.learning.function.vector.ThreeLayerFeedforwardNeuralNetwork;
45  import gov.sandia.cognition.math.DifferentiableUnivariateScalarFunction;
46  import gov.sandia.cognition.math.matrix.Vector;
47  import gov.sandia.cognition.math.matrix.VectorEntry;
48  import gov.sandia.cognition.math.matrix.VectorFactory;
49  import gov.sandia.cognition.math.matrix.mtj.DenseVectorFactoryMTJ;
50  
51  import java.io.BufferedReader;
52  import java.io.IOException;
53  import java.io.InputStreamReader;
54  import java.util.ArrayList;
55  import java.util.List;
56  
57  import org.openimaj.util.pair.IndependentPair;
58  
59  import Jama.Matrix;
60  
61  /**
62   * @author Sina Samangooei (ss@ecs.soton.ac.uk)
63   * 
64   *         Just some experiments using the sandia cognitive foundary neural
65   *         nets.
66   * 
67   */
68  public class HandWritingNeuralNetSANDIA implements IterativeAlgorithmListener {
69  
70  	/**
71  	 * Default location of inputs
72  	 */
73  	public static final String INPUT_LOCATION = "/org/openimaj/ml/handwriting/inputs.csv";
74  
75  	/**
76  	 * Default location of outputs
77  	 */
78  	public static final String OUTPUT_LOCATION = "/org/openimaj/ml/handwriting/outputs.csv";
79  
	// input examples loaded from INPUT_LOCATION; one example per row
	// (presumably a flattened image per row — TODO confirm against the CSV)
	private Matrix xVals;

	// expected outputs loaded from OUTPUT_LOCATION; one label value per row,
	// reduced modulo 10 to a digit in prepareDataCollection()
	private Matrix yVals;

	// training pairs (input vector, one-hot target vector) fed to the learner
	private ArrayList<InputOutputPair<Vector, Vector>> dataCollection;

	// per-digit cap on training examples; -1 disables the cap
	private int maxExamples = 400;
	// per-digit number of examples held out for testing
	private int maxTests = 10;
	// number of units in the hidden layer of the network
	private int nHiddenLayer = 20;

	// running count of examples seen so far for each digit
	private TIntIntHashMap examples;
	// held-out test pairs (input, one-hot target), keyed by digit
	private TIntObjectHashMap<List<IndependentPair<Vector, Vector>>> tests;

	// the trained network; assigned by learnNeuralNet()
	private GradientDescendable neuralNet;

	// total number of held-out test examples across all digits
	private int totalTests = 0;
96  
97  	/**
98  	 * @throws IOException
99  	 *             Load X input and y output from {@link #INPUT_LOCATION} and
100 	 *             {@link #OUTPUT_LOCATION}
101 	 */
102 	public HandWritingNeuralNetSANDIA() throws IOException {
103 		final BufferedReader xReader = new BufferedReader(new InputStreamReader(
104 				HandWritingNeuralNetSANDIA.class.getResourceAsStream(INPUT_LOCATION)));
105 		final BufferedReader yReader = new BufferedReader(new InputStreamReader(
106 				HandWritingNeuralNetSANDIA.class.getResourceAsStream(OUTPUT_LOCATION)));
107 		this.xVals = fromCSV(xReader, 5000);
108 		this.yVals = fromCSV(yReader, 5000);
109 
110 		examples = new TIntIntHashMap();
111 		this.tests = new TIntObjectHashMap<List<IndependentPair<Vector, Vector>>>();
112 		prepareDataCollection();
113 		learnNeuralNet();
114 		testNeuralNet();
115 		// new HandWritingInputDisplay(xVals);
116 	}
117 
118 	private void testNeuralNet() {
119 		final double[][] xVals = new double[totalTests][];
120 		final int[] yVals = new int[totalTests];
121 		this.tests.forEachEntry(new TIntObjectProcedure<List<IndependentPair<Vector, Vector>>>() {
122 			int done = 0;
123 			DenseVectorFactoryMTJ fact = new DenseVectorFactoryMTJ();
124 
125 			@Override
126 			public boolean execute(int number, List<IndependentPair<Vector, Vector>> xypairs) {
127 				for (final IndependentPair<Vector, Vector> xyval : xypairs) {
128 					final Vector guessed = neuralNet.evaluate(xyval.firstObject());
129 					int maxIndex = 0;
130 					double maxValue = 0;
131 					for (final VectorEntry vectorEntry : guessed) {
132 						if (maxValue < vectorEntry.getValue())
133 						{
134 							maxValue = vectorEntry.getValue();
135 							maxIndex = vectorEntry.getIndex();
136 						}
137 					}
138 					xVals[done] = fact.copyVector(xyval.firstObject()).getArray();
139 					yVals[done] = maxIndex;
140 					done++;
141 				}
142 				return true;
143 			}
144 		});
145 		new HandWritingInputDisplay(xVals, yVals);
146 	}
147 
148 	private void prepareDataCollection() {
149 		this.dataCollection = new ArrayList<InputOutputPair<Vector, Vector>>();
150 		final double[][] xArr = this.xVals.getArray();
151 		final double[][] yArr = this.yVals.getArray();
152 
153 		for (int i = 0; i < xArr.length; i++) {
154 			final Vector xVector = VectorFactory.getDefault().copyArray(xArr[i]);
155 			final double[] yValues = new double[10];
156 			final int number = (int) (yArr[i][0] % 10);
157 			final int count = examples.adjustOrPutValue(number, 1, 1);
158 			yValues[number] = 1;
159 			final Vector yVector = VectorFactory.getDefault().copyValues(yValues);
160 			if (this.maxExamples != -1 && count > maxExamples) {
161 				if (count > maxTests + maxExamples) {
162 					continue;
163 				}
164 				List<IndependentPair<Vector, Vector>> numberTest = this.tests.get(number);
165 				if (numberTest == null) {
166 					this.tests.put(number, numberTest = new ArrayList<IndependentPair<Vector, Vector>>());
167 				}
168 				numberTest.add(IndependentPair.pair(xVector, yVector));
169 				totalTests++;
170 			}
171 			else {
172 				this.dataCollection.add(DefaultInputOutputPair.create(xVector, yVector));
173 			}
174 
175 		}
176 	}
177 
178 	private void learnNeuralNet() {
179 		// ArrayList<Integer> nodesPerLayer = toArrayList(
180 		// new
181 		// Integer[]{this.xVals.getColumnDimension(),this.xVals.getColumnDimension()/4,10}
182 		// );
183 		// ArrayList<DifferentiableUnivariateScalarFunction> squashFunctions =
184 		// toArrayList(
185 		// new DifferentiableUnivariateScalarFunction[]{new
186 		// SigmoidFunction(),new SigmoidFunction()}
187 		// );
188 		final ArrayList<Integer> nodesPerLayer = toArrayList(
189 				new Integer[] { this.xVals.getColumnDimension(), nHiddenLayer, 10 }
190 				);
191 		final ArrayList<DifferentiableUnivariateScalarFunction> squashFunctions = toArrayList(
192 				new DifferentiableUnivariateScalarFunction[] { new AtanFunction(), new AtanFunction() }
193 				);
194 		// DifferentiableFeedforwardNeuralNetwork nn = new
195 		// DifferentiableFeedforwardNeuralNetwork(
196 		// nodesPerLayer,
197 		// squashFunctions,
198 		// new Random()
199 		// );
200 		final ThreeLayerFeedforwardNeuralNetwork nn = new ThreeLayerFeedforwardNeuralNetwork(
201 				this.xVals.getColumnDimension(), nHiddenLayer, 10);
202 		final ParameterDifferentiableCostMinimizer conjugateGradient = new ParameterDifferentiableCostMinimizer(
203 				new FunctionMinimizerLiuStorey());
204 		conjugateGradient.setObjectToOptimize(nn);
205 		// conjugateGradient.setCostFunction( new MeanSquaredErrorCostFunction()
206 		// );
207 		conjugateGradient.addIterativeAlgorithmListener(this);
208 		conjugateGradient.setMaxIterations(50);
209 		// FletcherXuHybridEstimation minimiser = new
210 		// FletcherXuHybridEstimation();
211 		// minimiser.setObjectToOptimize( nn );
212 		// minimiser.setMaxIterations(50);
213 		neuralNet = conjugateGradient.learn(this.dataCollection);
214 	}
215 
216 	private static <T> ArrayList<T> toArrayList(T[] values) {
217 		final ArrayList<T> configList = new ArrayList<T>();
218 		for (final T t : values) {
219 			configList.add(t);
220 		}
221 		return configList;
222 	}
223 
224 	private Matrix fromCSV(BufferedReader bufferedReader, int nLines) throws IOException {
225 
226 		String[] lineValues = null;
227 		double[][] outArr = null;
228 		Matrix retMat = null;
229 		int row = 0;
230 		while ((lineValues = CSVUtility.nextNonEmptyLine(bufferedReader)) != null) {
231 			if (outArr == null) {
232 				retMat = new Matrix(nLines, lineValues.length);
233 				outArr = retMat.getArray();
234 			}
235 
236 			for (int col = 0; col < lineValues.length; col++) {
237 				outArr[row][col] = Double.parseDouble(lineValues[col]);
238 			}
239 			row++;
240 		}
241 		return retMat;
242 	}
243 
244 	public static void main(String[] args) throws IOException {
245 		new HandWritingNeuralNetSANDIA();
246 	}
247 
248 	@Override
249 	public void algorithmStarted(IterativeAlgorithm algorithm) {
250 		System.out.println("Learning neural network");
251 	}
252 
253 	@Override
254 	public void algorithmEnded(IterativeAlgorithm algorithm) {
255 		System.out.println("Done Learning!");
256 	}
257 
258 	@Override
259 	public void stepStarted(IterativeAlgorithm algorithm) {
260 		System.out.println("... starting step: " + algorithm.getIteration());
261 	}
262 
263 	@Override
264 	public void stepEnded(IterativeAlgorithm algorithm) {
265 		System.out.println("... ending step: " + algorithm.getIteration());
266 	}
267 }