1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package org.openimaj.ml.neuralnet;
31
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.procedure.TIntObjectProcedure;
import gov.sandia.cognition.algorithm.IterativeAlgorithm;
import gov.sandia.cognition.algorithm.IterativeAlgorithmListener;
import gov.sandia.cognition.io.CSVUtility;
import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizerLiuStorey;
import gov.sandia.cognition.learning.algorithm.regression.ParameterDifferentiableCostMinimizer;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.scalar.AtanFunction;
import gov.sandia.cognition.learning.function.vector.ThreeLayerFeedforwardNeuralNetwork;
import gov.sandia.cognition.math.DifferentiableUnivariateScalarFunction;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorEntry;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.mtj.DenseVectorFactoryMTJ;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;

import org.openimaj.util.pair.IndependentPair;

import Jama.Matrix;
60
61
62
63
64
65
66
67
68 public class HandWritingNeuralNetSANDIA implements IterativeAlgorithmListener {
69
70
71
72
73 public static final String INPUT_LOCATION = "/org/openimaj/ml/handwriting/inputs.csv";
74
75
76
77
78 public static final String OUTPUT_LOCATION = "/org/openimaj/ml/handwriting/outputs.csv";
79
80 private Matrix xVals;
81
82 private Matrix yVals;
83
84 private ArrayList<InputOutputPair<Vector, Vector>> dataCollection;
85
86 private int maxExamples = 400;
87 private int maxTests = 10;
88 private int nHiddenLayer = 20;
89
90 private TIntIntHashMap examples;
91 private TIntObjectHashMap<List<IndependentPair<Vector, Vector>>> tests;
92
93 private GradientDescendable neuralNet;
94
95 private int totalTests = 0;
96
97
98
99
100
101
102 public HandWritingNeuralNetSANDIA() throws IOException {
103 final BufferedReader xReader = new BufferedReader(new InputStreamReader(
104 HandWritingNeuralNetSANDIA.class.getResourceAsStream(INPUT_LOCATION)));
105 final BufferedReader yReader = new BufferedReader(new InputStreamReader(
106 HandWritingNeuralNetSANDIA.class.getResourceAsStream(OUTPUT_LOCATION)));
107 this.xVals = fromCSV(xReader, 5000);
108 this.yVals = fromCSV(yReader, 5000);
109
110 examples = new TIntIntHashMap();
111 this.tests = new TIntObjectHashMap<List<IndependentPair<Vector, Vector>>>();
112 prepareDataCollection();
113 learnNeuralNet();
114 testNeuralNet();
115
116 }
117
118 private void testNeuralNet() {
119 final double[][] xVals = new double[totalTests][];
120 final int[] yVals = new int[totalTests];
121 this.tests.forEachEntry(new TIntObjectProcedure<List<IndependentPair<Vector, Vector>>>() {
122 int done = 0;
123 DenseVectorFactoryMTJ fact = new DenseVectorFactoryMTJ();
124
125 @Override
126 public boolean execute(int number, List<IndependentPair<Vector, Vector>> xypairs) {
127 for (final IndependentPair<Vector, Vector> xyval : xypairs) {
128 final Vector guessed = neuralNet.evaluate(xyval.firstObject());
129 int maxIndex = 0;
130 double maxValue = 0;
131 for (final VectorEntry vectorEntry : guessed) {
132 if (maxValue < vectorEntry.getValue())
133 {
134 maxValue = vectorEntry.getValue();
135 maxIndex = vectorEntry.getIndex();
136 }
137 }
138 xVals[done] = fact.copyVector(xyval.firstObject()).getArray();
139 yVals[done] = maxIndex;
140 done++;
141 }
142 return true;
143 }
144 });
145 new HandWritingInputDisplay(xVals, yVals);
146 }
147
148 private void prepareDataCollection() {
149 this.dataCollection = new ArrayList<InputOutputPair<Vector, Vector>>();
150 final double[][] xArr = this.xVals.getArray();
151 final double[][] yArr = this.yVals.getArray();
152
153 for (int i = 0; i < xArr.length; i++) {
154 final Vector xVector = VectorFactory.getDefault().copyArray(xArr[i]);
155 final double[] yValues = new double[10];
156 final int number = (int) (yArr[i][0] % 10);
157 final int count = examples.adjustOrPutValue(number, 1, 1);
158 yValues[number] = 1;
159 final Vector yVector = VectorFactory.getDefault().copyValues(yValues);
160 if (this.maxExamples != -1 && count > maxExamples) {
161 if (count > maxTests + maxExamples) {
162 continue;
163 }
164 List<IndependentPair<Vector, Vector>> numberTest = this.tests.get(number);
165 if (numberTest == null) {
166 this.tests.put(number, numberTest = new ArrayList<IndependentPair<Vector, Vector>>());
167 }
168 numberTest.add(IndependentPair.pair(xVector, yVector));
169 totalTests++;
170 }
171 else {
172 this.dataCollection.add(DefaultInputOutputPair.create(xVector, yVector));
173 }
174
175 }
176 }
177
178 private void learnNeuralNet() {
179
180
181
182
183
184
185
186
187
188 final ArrayList<Integer> nodesPerLayer = toArrayList(
189 new Integer[] { this.xVals.getColumnDimension(), nHiddenLayer, 10 }
190 );
191 final ArrayList<DifferentiableUnivariateScalarFunction> squashFunctions = toArrayList(
192 new DifferentiableUnivariateScalarFunction[] { new AtanFunction(), new AtanFunction() }
193 );
194
195
196
197
198
199
200 final ThreeLayerFeedforwardNeuralNetwork nn = new ThreeLayerFeedforwardNeuralNetwork(
201 this.xVals.getColumnDimension(), nHiddenLayer, 10);
202 final ParameterDifferentiableCostMinimizer conjugateGradient = new ParameterDifferentiableCostMinimizer(
203 new FunctionMinimizerLiuStorey());
204 conjugateGradient.setObjectToOptimize(nn);
205
206
207 conjugateGradient.addIterativeAlgorithmListener(this);
208 conjugateGradient.setMaxIterations(50);
209
210
211
212
213 neuralNet = conjugateGradient.learn(this.dataCollection);
214 }
215
216 private static <T> ArrayList<T> toArrayList(T[] values) {
217 final ArrayList<T> configList = new ArrayList<T>();
218 for (final T t : values) {
219 configList.add(t);
220 }
221 return configList;
222 }
223
224 private Matrix fromCSV(BufferedReader bufferedReader, int nLines) throws IOException {
225
226 String[] lineValues = null;
227 double[][] outArr = null;
228 Matrix retMat = null;
229 int row = 0;
230 while ((lineValues = CSVUtility.nextNonEmptyLine(bufferedReader)) != null) {
231 if (outArr == null) {
232 retMat = new Matrix(nLines, lineValues.length);
233 outArr = retMat.getArray();
234 }
235
236 for (int col = 0; col < lineValues.length; col++) {
237 outArr[row][col] = Double.parseDouble(lineValues[col]);
238 }
239 row++;
240 }
241 return retMat;
242 }
243
244 public static void main(String[] args) throws IOException {
245 new HandWritingNeuralNetSANDIA();
246 }
247
248 @Override
249 public void algorithmStarted(IterativeAlgorithm algorithm) {
250 System.out.println("Learning neural network");
251 }
252
253 @Override
254 public void algorithmEnded(IterativeAlgorithm algorithm) {
255 System.out.println("Done Learning!");
256 }
257
258 @Override
259 public void stepStarted(IterativeAlgorithm algorithm) {
260 System.out.println("... starting step: " + algorithm.getIteration());
261 }
262
263 @Override
264 public void stepEnded(IterativeAlgorithm algorithm) {
265 System.out.println("... ending step: " + algorithm.getIteration());
266 }
267 }