001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.neuralnet;
031
032import gnu.trove.map.hash.TIntIntHashMap;
033import gnu.trove.map.hash.TIntObjectHashMap;
034import gnu.trove.procedure.TIntObjectProcedure;
035import gov.sandia.cognition.io.CSVUtility;
036
037import java.io.BufferedReader;
038import java.io.File;
039import java.io.FileWriter;
040import java.io.IOException;
041import java.io.InputStream;
042import java.io.InputStreamReader;
043import java.io.PrintWriter;
044import java.util.ArrayList;
045import java.util.Iterator;
046import java.util.List;
047
048import org.encog.engine.network.activation.ActivationStep;
049import org.encog.mathutil.rbf.RBFEnum;
050import org.encog.ml.MLRegression;
051import org.encog.ml.data.MLDataPair;
052import org.encog.ml.data.MLDataSet;
053import org.encog.ml.data.specific.CSVNeuralDataSet;
054import org.encog.ml.svm.SVM;
055import org.encog.ml.svm.training.SVMTrain;
056import org.encog.ml.train.MLTrain;
057import org.encog.neural.cpn.CPN;
058import org.encog.neural.cpn.training.TrainInstar;
059import org.encog.neural.cpn.training.TrainOutstar;
060import org.encog.neural.data.basic.BasicNeuralData;
061import org.encog.neural.neat.NEATPopulation;
062import org.encog.neural.neat.training.NEATTraining;
063import org.encog.neural.networks.training.CalculateScore;
064import org.encog.neural.networks.training.TrainingSetScore;
065import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
066import org.encog.neural.rbf.RBFNetwork;
067import org.encog.util.simple.EncogUtility;
068import org.openimaj.util.pair.IndependentPair;
069
070import Jama.Matrix;
071
072/**
073 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
074 * 
075 *         Just some experiments using the sandia cognitive foundary neural
076 *         nets.
077 * 
078 */
079public class HandWritingNeuralNetENCOG {
080
081        /**
082         * Default location of inputs
083         */
084        public static final String INPUT_LOCATION = "/org/openimaj/ml/handwriting/inputouput.csv";
085
086        private int maxTests = 10;
087
088        private TIntIntHashMap examples;
089        private TIntObjectHashMap<List<IndependentPair<double[], double[]>>> tests;
090
091        private int totalTests = 0;
092
093        private MLRegression network;
094
095        private MLDataSet training;
096
097        /**
098         * @throws IOException
099         *             Load X input and y output from {@link #INPUT_LOCATION}
100         */
101        public HandWritingNeuralNetENCOG() throws IOException {
102
103                examples = new TIntIntHashMap();
104                this.tests = new TIntObjectHashMap<List<IndependentPair<double[], double[]>>>();
105                prepareDataCollection();
106                learnNeuralNet();
107                testNeuralNet();
108                // new HandWritingInputDisplay(this.training);
109        }
110
111        private void testNeuralNet() {
112                final double[][] xVals = new double[totalTests][];
113                final int[] yVals = new int[totalTests];
114                this.tests.forEachEntry(new TIntObjectProcedure<List<IndependentPair<double[], double[]>>>() {
115                        int done = 0;
116
117                        @Override
118                        public boolean execute(int number, List<IndependentPair<double[], double[]>> xypairs) {
119                                for (final IndependentPair<double[], double[]> xyval : xypairs) {
120                                        final double[] guessed = network.compute(new BasicNeuralData(xyval.firstObject())).getData(); // estimate
121                                        int maxIndex = 0;
122                                        double maxValue = 0;
123                                        for (int i = 0; i < guessed.length; i++) {
124                                                if (maxValue < guessed[i])
125                                                {
126                                                        maxValue = guessed[i];
127                                                        maxIndex = i;
128                                                }
129                                        }
130                                        xVals[done] = xyval.firstObject();
131                                        yVals[done] = (maxIndex + 1) % 10;
132                                        done++;
133                                }
134                                return true;
135                        }
136                });
137                new HandWritingInputDisplay(xVals, yVals);
138        }
139
140        private void prepareDataCollection() throws IOException {
141                final File tmp = File.createTempFile("data", ".csv");
142                final InputStream stream = HandWritingNeuralNetENCOG.class.getResourceAsStream(INPUT_LOCATION);
143                final BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
144                String line = null;
145                final PrintWriter writer = new PrintWriter(new FileWriter(tmp));
146                while ((line = reader.readLine()) != null) {
147                        writer.println(line);
148                }
149                writer.close();
150                reader.close();
151                training = new CSVNeuralDataSet(tmp.getAbsolutePath(), 400, 10, false);
152                final Iterator<MLDataPair> elementItr = this.training.iterator();
153                for (; elementItr.hasNext();) {
154                        final MLDataPair type = elementItr.next();
155                        final double[] yData = type.getIdealArray();
156                        final double[] xData = type.getInputArray();
157                        int yIndex = 0;
158                        while (yData[yIndex] != 1)
159                                yIndex++;
160                        final int currentCount = this.examples.adjustOrPutValue(yIndex, 1, 1);
161                        if (currentCount < this.maxTests) {
162
163                                List<IndependentPair<double[], double[]>> numberTest = this.tests.get(yIndex);
164                                if (numberTest == null) {
165                                        this.tests.put(yIndex, numberTest = new ArrayList<IndependentPair<double[], double[]>>());
166                                }
167                                numberTest.add(IndependentPair.pair(xData, yData));
168                                totalTests++;
169                        }
170                }
171
172        }
173
174        private void learnNeuralNet() {
175                // this.network = EncogUtility.simpleFeedForward(400, 100, 0, 10,
176                // false);
177                // MLTrain train = new Backpropagation(this.network, this.training);
178                // MLTrain train = new ResilientPropagation(this.network,
179                // this.training);
180
181                // this.network = withNEAT();
182                // this.network = withRBF();
183                // this.network = withSVM();
184                this.network = withResilieant();
185                // this.network = withCPN();
186        }
187
188        private MLRegression withNEAT() {
189                final NEATPopulation pop = new NEATPopulation(400, 10, 1000);
190                final CalculateScore score = new TrainingSetScore(this.training);
191                // train the neural network
192                final ActivationStep step = new ActivationStep();
193                step.setCenter(0.5);
194                pop.setOutputActivationFunction(step);
195                final MLTrain train = new NEATTraining(score, pop);
196                EncogUtility.trainToError(train, 0.01515);
197                return (MLRegression) train.getMethod();
198        }
199
200        private MLRegression withResilieant() {
201                final MLTrain train = new ResilientPropagation(EncogUtility.simpleFeedForward(400, 100, 0, 10, false),
202                                this.training);
203                EncogUtility.trainToError(train, 0.01515);
204                return (MLRegression) train.getMethod();
205        }
206
207        private MLRegression withSVM() {
208                final MLTrain train = new SVMTrain(new SVM(400, true), this.training);
209                EncogUtility.trainToError(train, 0.01515);
210                return (MLRegression) train.getMethod();
211        }
212
213        private MLRegression withRBF() {
214                final MLRegression train = new RBFNetwork(400, 20, 10, RBFEnum.Gaussian);
215                EncogUtility.trainToError(train, this.training, 0.01515);
216                return train;
217        }
218
219        private MLRegression withCPN() {
220                final CPN result = new CPN(400, 1000, 10, 1);
221                final MLTrain trainInstar = new TrainInstar(result, training, 0.1, false);
222                EncogUtility.trainToError(trainInstar, 0.01515);
223                final MLTrain trainOutstar = new TrainOutstar(result, training, 0.1);
224                EncogUtility.trainToError(trainOutstar, 0.01515);
225                return result;
226        }
227
228        private static <T> ArrayList<T> toArrayList(T[] values) {
229                final ArrayList<T> configList = new ArrayList<T>();
230                for (final T t : values) {
231                        configList.add(t);
232                }
233                return configList;
234        }
235
236        private Matrix fromCSV(BufferedReader bufferedReader, int nLines) throws IOException {
237
238                String[] lineValues = null;
239                double[][] outArr = null;
240                Matrix retMat = null;
241                int row = 0;
242                while ((lineValues = CSVUtility.nextNonEmptyLine(bufferedReader)) != null) {
243                        if (outArr == null) {
244                                retMat = new Matrix(nLines, lineValues.length);
245                                outArr = retMat.getArray();
246                        }
247
248                        for (int col = 0; col < lineValues.length; col++) {
249                                outArr[row][col] = Double.parseDouble(lineValues[col]);
250                        }
251                        row++;
252                }
253                return retMat;
254        }
255
256        public static void main(String[] args) throws IOException {
257                new HandWritingNeuralNetENCOG();
258        }
259}