001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.neuralnet;
031
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.procedure.TIntObjectProcedure;
import gov.sandia.cognition.algorithm.IterativeAlgorithm;
import gov.sandia.cognition.algorithm.IterativeAlgorithmListener;
import gov.sandia.cognition.io.CSVUtility;
import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizerLiuStorey;
import gov.sandia.cognition.learning.algorithm.regression.ParameterDifferentiableCostMinimizer;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.scalar.AtanFunction;
import gov.sandia.cognition.learning.function.vector.ThreeLayerFeedforwardNeuralNetwork;
import gov.sandia.cognition.math.DifferentiableUnivariateScalarFunction;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorEntry;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.mtj.DenseVectorFactoryMTJ;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;

import org.openimaj.util.pair.IndependentPair;

import Jama.Matrix;
060
/**
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 * 
 *         Just some experiments using the Sandia Cognitive Foundry neural
 *         nets.
 * 
 */
068public class HandWritingNeuralNetSANDIA implements IterativeAlgorithmListener {
069
070        /**
071         * Default location of inputs
072         */
073        public static final String INPUT_LOCATION = "/org/openimaj/ml/handwriting/inputs.csv";
074
075        /**
076         * Default location of outputs
077         */
078        public static final String OUTPUT_LOCATION = "/org/openimaj/ml/handwriting/outputs.csv";
079
080        private Matrix xVals;
081
082        private Matrix yVals;
083
084        private ArrayList<InputOutputPair<Vector, Vector>> dataCollection;
085
086        private int maxExamples = 400;
087        private int maxTests = 10;
088        private int nHiddenLayer = 20;
089
090        private TIntIntHashMap examples;
091        private TIntObjectHashMap<List<IndependentPair<Vector, Vector>>> tests;
092
093        private GradientDescendable neuralNet;
094
095        private int totalTests = 0;
096
097        /**
098         * @throws IOException
099         *             Load X input and y output from {@link #INPUT_LOCATION} and
100         *             {@link #OUTPUT_LOCATION}
101         */
102        public HandWritingNeuralNetSANDIA() throws IOException {
103                final BufferedReader xReader = new BufferedReader(new InputStreamReader(
104                                HandWritingNeuralNetSANDIA.class.getResourceAsStream(INPUT_LOCATION)));
105                final BufferedReader yReader = new BufferedReader(new InputStreamReader(
106                                HandWritingNeuralNetSANDIA.class.getResourceAsStream(OUTPUT_LOCATION)));
107                this.xVals = fromCSV(xReader, 5000);
108                this.yVals = fromCSV(yReader, 5000);
109
110                examples = new TIntIntHashMap();
111                this.tests = new TIntObjectHashMap<List<IndependentPair<Vector, Vector>>>();
112                prepareDataCollection();
113                learnNeuralNet();
114                testNeuralNet();
115                // new HandWritingInputDisplay(xVals);
116        }
117
118        private void testNeuralNet() {
119                final double[][] xVals = new double[totalTests][];
120                final int[] yVals = new int[totalTests];
121                this.tests.forEachEntry(new TIntObjectProcedure<List<IndependentPair<Vector, Vector>>>() {
122                        int done = 0;
123                        DenseVectorFactoryMTJ fact = new DenseVectorFactoryMTJ();
124
125                        @Override
126                        public boolean execute(int number, List<IndependentPair<Vector, Vector>> xypairs) {
127                                for (final IndependentPair<Vector, Vector> xyval : xypairs) {
128                                        final Vector guessed = neuralNet.evaluate(xyval.firstObject());
129                                        int maxIndex = 0;
130                                        double maxValue = 0;
131                                        for (final VectorEntry vectorEntry : guessed) {
132                                                if (maxValue < vectorEntry.getValue())
133                                                {
134                                                        maxValue = vectorEntry.getValue();
135                                                        maxIndex = vectorEntry.getIndex();
136                                                }
137                                        }
138                                        xVals[done] = fact.copyVector(xyval.firstObject()).getArray();
139                                        yVals[done] = maxIndex;
140                                        done++;
141                                }
142                                return true;
143                        }
144                });
145                new HandWritingInputDisplay(xVals, yVals);
146        }
147
148        private void prepareDataCollection() {
149                this.dataCollection = new ArrayList<InputOutputPair<Vector, Vector>>();
150                final double[][] xArr = this.xVals.getArray();
151                final double[][] yArr = this.yVals.getArray();
152
153                for (int i = 0; i < xArr.length; i++) {
154                        final Vector xVector = VectorFactory.getDefault().copyArray(xArr[i]);
155                        final double[] yValues = new double[10];
156                        final int number = (int) (yArr[i][0] % 10);
157                        final int count = examples.adjustOrPutValue(number, 1, 1);
158                        yValues[number] = 1;
159                        final Vector yVector = VectorFactory.getDefault().copyValues(yValues);
160                        if (this.maxExamples != -1 && count > maxExamples) {
161                                if (count > maxTests + maxExamples) {
162                                        continue;
163                                }
164                                List<IndependentPair<Vector, Vector>> numberTest = this.tests.get(number);
165                                if (numberTest == null) {
166                                        this.tests.put(number, numberTest = new ArrayList<IndependentPair<Vector, Vector>>());
167                                }
168                                numberTest.add(IndependentPair.pair(xVector, yVector));
169                                totalTests++;
170                        }
171                        else {
172                                this.dataCollection.add(DefaultInputOutputPair.create(xVector, yVector));
173                        }
174
175                }
176        }
177
178        private void learnNeuralNet() {
179                // ArrayList<Integer> nodesPerLayer = toArrayList(
180                // new
181                // Integer[]{this.xVals.getColumnDimension(),this.xVals.getColumnDimension()/4,10}
182                // );
183                // ArrayList<DifferentiableUnivariateScalarFunction> squashFunctions =
184                // toArrayList(
185                // new DifferentiableUnivariateScalarFunction[]{new
186                // SigmoidFunction(),new SigmoidFunction()}
187                // );
188                final ArrayList<Integer> nodesPerLayer = toArrayList(
189                                new Integer[] { this.xVals.getColumnDimension(), nHiddenLayer, 10 }
190                                );
191                final ArrayList<DifferentiableUnivariateScalarFunction> squashFunctions = toArrayList(
192                                new DifferentiableUnivariateScalarFunction[] { new AtanFunction(), new AtanFunction() }
193                                );
194                // DifferentiableFeedforwardNeuralNetwork nn = new
195                // DifferentiableFeedforwardNeuralNetwork(
196                // nodesPerLayer,
197                // squashFunctions,
198                // new Random()
199                // );
200                final ThreeLayerFeedforwardNeuralNetwork nn = new ThreeLayerFeedforwardNeuralNetwork(
201                                this.xVals.getColumnDimension(), nHiddenLayer, 10);
202                final ParameterDifferentiableCostMinimizer conjugateGradient = new ParameterDifferentiableCostMinimizer(
203                                new FunctionMinimizerLiuStorey());
204                conjugateGradient.setObjectToOptimize(nn);
205                // conjugateGradient.setCostFunction( new MeanSquaredErrorCostFunction()
206                // );
207                conjugateGradient.addIterativeAlgorithmListener(this);
208                conjugateGradient.setMaxIterations(50);
209                // FletcherXuHybridEstimation minimiser = new
210                // FletcherXuHybridEstimation();
211                // minimiser.setObjectToOptimize( nn );
212                // minimiser.setMaxIterations(50);
213                neuralNet = conjugateGradient.learn(this.dataCollection);
214        }
215
216        private static <T> ArrayList<T> toArrayList(T[] values) {
217                final ArrayList<T> configList = new ArrayList<T>();
218                for (final T t : values) {
219                        configList.add(t);
220                }
221                return configList;
222        }
223
224        private Matrix fromCSV(BufferedReader bufferedReader, int nLines) throws IOException {
225
226                String[] lineValues = null;
227                double[][] outArr = null;
228                Matrix retMat = null;
229                int row = 0;
230                while ((lineValues = CSVUtility.nextNonEmptyLine(bufferedReader)) != null) {
231                        if (outArr == null) {
232                                retMat = new Matrix(nLines, lineValues.length);
233                                outArr = retMat.getArray();
234                        }
235
236                        for (int col = 0; col < lineValues.length; col++) {
237                                outArr[row][col] = Double.parseDouble(lineValues[col]);
238                        }
239                        row++;
240                }
241                return retMat;
242        }
243
244        public static void main(String[] args) throws IOException {
245                new HandWritingNeuralNetSANDIA();
246        }
247
248        @Override
249        public void algorithmStarted(IterativeAlgorithm algorithm) {
250                System.out.println("Learning neural network");
251        }
252
253        @Override
254        public void algorithmEnded(IterativeAlgorithm algorithm) {
255                System.out.println("Done Learning!");
256        }
257
258        @Override
259        public void stepStarted(IterativeAlgorithm algorithm) {
260                System.out.println("... starting step: " + algorithm.getIteration());
261        }
262
263        @Override
264        public void stepEnded(IterativeAlgorithm algorithm) {
265                System.out.println("... ending step: " + algorithm.getIteration());
266        }
267}