001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.demos.ml.linear.data;
031
032import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
033import gov.sandia.cognition.learning.data.InputOutputPair;
034import gov.sandia.cognition.learning.function.kernel.LinearKernel;
035import gov.sandia.cognition.math.matrix.VectorFactory;
036
037import java.io.File;
038import java.io.IOException;
039import java.io.PrintWriter;
040import java.util.ArrayList;
041import java.util.Arrays;
042import java.util.Collection;
043import java.util.List;
044
045import no.uib.cipr.matrix.Vector;
046
047import org.openimaj.image.DisplayUtilities;
048import org.openimaj.image.MBFImage;
049import org.openimaj.image.colour.ColourSpace;
050import org.openimaj.image.colour.RGBColour;
051import org.openimaj.math.geometry.line.Line2d;
052import org.openimaj.math.geometry.point.Point2d;
053import org.openimaj.math.geometry.point.Point2dImpl;
054import org.openimaj.math.geometry.shape.Circle;
055import org.openimaj.ml.linear.data.LinearPerceptronDataGenerator;
056import org.openimaj.ml.linear.kernel.LinearVectorKernel;
057import org.openimaj.ml.linear.learner.perceptron.DoubleArrayKernelPerceptron;
058import org.openimaj.ml.linear.learner.perceptron.MeanCenteredKernelPerceptron;
059import org.openimaj.ml.linear.learner.perceptron.MeanCenteredProjectron;
060import org.openimaj.ml.linear.learner.perceptron.PerceptronClass;
061import org.openimaj.ml.linear.learner.perceptron.SimplePerceptron;
062import org.openimaj.util.pair.IndependentPair;
063import org.openimaj.util.stream.Stream;
064
065/**
066 *
067 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
068 */
069public class DrawLinearData {
070
071        private static final int TOTAL_DATA_ITEMS = 1000;
072        private static final int SEED = 1;
073
074        /**
075         * @param args
076         * @throws IOException
077         */
078        public static void main(String[] args) throws IOException {
079                final LinearPerceptronDataGenerator dg = dataGen();
080                Stream<IndependentPair<double[], PerceptronClass>> dataStream;
081                drawData(dg);
082                writeData(new File("/Users/ss/Experiments/perceptron/test.data"));
083                // dataStream = new
084                // LimitedDataStream<double[],PerceptronClass>(dataGen(),TOTAL_DATA_ITEMS);
085                dataStream = new RepeatingDataStream<double[], PerceptronClass>(dataGen(), TOTAL_DATA_ITEMS);
086                final MeanCenteredKernelPerceptron mkp = new MeanCenteredKernelPerceptron(new LinearVectorKernel());
087                // MatrixKernelPerceptron mkp = new MarginMeanCenteredPerceptron(new
088                // LinearVectorKernel(),10000d);
089                // MatrixKernelPerceptron mkp = new MeanCenteredProjectron(new
090                // LinearVectorKernel());
091                // MatrixKernelPerceptron mkp = new Projectron(new
092                // LinearVectorKernel());
093                // MatrixKernelPerceptron mkp = new
094                // ThresholdMatrixKernelPerceptron(0.01, 0, new LinearVectorKernel());
095                // MatrixKernelPerceptron mkp = new MatrixKernelPerceptron(new
096                // LinearVectorKernel());
097                // SimplePerceptron mkp = new SimplePerceptron();
098                leanrnPoints(mkp, dataStream);
099                // leanrnPointsProjectron();
100                // leanrnCogFound();
101        }
102
103        private static void writeData(File file) throws IOException {
104                final LinearPerceptronDataGenerator gen = dataGen();
105                final File pf = file.getParentFile();
106                if (!pf.exists())
107                        pf.mkdirs();
108                final PrintWriter fw = new PrintWriter(file);
109                for (int i = 0; i < TOTAL_DATA_ITEMS; i++) {
110                        final IndependentPair<double[], PerceptronClass> d = gen.generate();
111                        fw.println(Arrays.toString(d.firstObject()));
112                        fw.println(d.secondObject() == PerceptronClass.TRUE ? 1 : 0);
113                }
114                fw.close();
115        }
116
117        private static void drawData(LinearPerceptronDataGenerator dg) {
118                final Stream<IndependentPair<double[], PerceptronClass>> dataStream = new LimitedDataStream<double[], PerceptronClass>(
119                                dg, TOTAL_DATA_ITEMS);
120                final Vector origin = dg.getOrigin();
121
122                final Vector dir = dg.getPlane()[0];
123                final Point2d lineStart = start(origin, dir);
124                final Point2d lineEnd = end(origin, dir);
125                final Line2d line = new Line2d(lineStart, lineEnd);
126
127                drawPoints(dataStream, line);
128        }
129
130        private static LinearPerceptronDataGenerator dataGen() {
131                final LinearPerceptronDataGenerator dg = new LinearPerceptronDataGenerator(300, 2, 0.3, SEED);
132                return dg;
133        }
134
135        private static void learnCogFound() {
136                final LinearPerceptronDataGenerator dg = dataGen();
137                final gov.sandia.cognition.learning.algorithm.perceptron.kernel.KernelPerceptron<gov.sandia.cognition.math.matrix.Vector> mkp = new gov.sandia.cognition.learning.algorithm.perceptron.kernel.KernelPerceptron<gov.sandia.cognition.math.matrix.Vector>(
138                                new LinearKernel());
139                mkp.learn(createData());
140                // System.out.println(mkp.getErrorCount());
141
142        }
143
144        private static Collection<? extends InputOutputPair<? extends gov.sandia.cognition.math.matrix.Vector, Boolean>>
145        createData()
146        {
147                final List<InputOutputPair<gov.sandia.cognition.math.matrix.Vector, Boolean>> ret = new ArrayList<InputOutputPair<gov.sandia.cognition.math.matrix.Vector, Boolean>>();
148                final LinearPerceptronDataGenerator dg = dataGen();
149                for (int i = 0; i < TOTAL_DATA_ITEMS; i++) {
150                        final IndependentPair<double[], PerceptronClass> pointClass = dg.generate();
151                        final double[] pc = pointClass.firstObject();
152                        final PerceptronClass pcc = pointClass.secondObject();
153                        final boolean bool = pcc.equals(PerceptronClass.TRUE);
154                        final gov.sandia.cognition.math.matrix.Vector vec = VectorFactory.getDenseDefault().copyArray(pc);
155                        final InputOutputPair<gov.sandia.cognition.math.matrix.Vector, Boolean> item = DefaultInputOutputPair.create(
156                                        vec, bool);
157                        ret.add(item);
158                }
159                System.out.println("Data created");
160                return ret;
161        }
162
163        private static void drawMkpLine(DoubleArrayKernelPerceptron mkp) {
164                final MBFImage img = new MBFImage(300, 300, ColourSpace.RGB);
165
166                final List<double[]> sup = mkp.getSupports();
167                final List<Double> weights = mkp.getWeights();
168                final double bias = mkp.getBias();
169                System.out.println("Bias: " + bias);
170                double[] startD = null;
171                double[] endD = null;
172
173                double[] mean = new double[2];
174                if (mkp instanceof MeanCenteredKernelPerceptron) {
175                        mean = ((MeanCenteredKernelPerceptron) mkp).getMean();
176                } else if (mkp instanceof MeanCenteredProjectron) {
177                        mean = ((MeanCenteredProjectron) mkp).getMean();
178                }
179                startD = LinearVectorKernel.getPlanePoint(sup, weights, bias, -mean[0], Double.NaN);
180                endD = LinearVectorKernel.getPlanePoint(sup, weights, bias, img.getWidth() - mean[0], Double.NaN);
181                startD[0] += mean[0];
182                startD[1] += mean[1];
183                endD[0] += mean[0];
184                endD[1] += mean[1];
185                drawLine(img, startD, endD);
186        }
187
188        private static void drawLine(MBFImage img, double[] startD, double[] endD) {
189                final Point2d lineStart = new Point2dImpl((float) startD[0], (float) startD[1]);
190                final Point2d lineEnd = new Point2dImpl((float) endD[0], (float) endD[1]);
191
192                final Line2d line = new Line2d(lineStart, lineEnd);
193                // System.out.println("Drawing: " + line);
194                img.drawLine(line, 3, RGBColour.GREEN);
195                // img.drawPoint(new Point2dImpl((float)origin.get(0),(float)
196                // origin.get(1)), RGBColour.RED, 5);
197                DisplayUtilities.displayName(img, "line");
198        }
199
200        private static void leanrnPoints(SimplePerceptron mkp, Iterable<IndependentPair<double[], PerceptronClass>> iter) {
201                int errors = 0;
202                int i = 0;
203                for (final IndependentPair<double[], PerceptronClass> pointClass : iter) {
204                        i++;
205                        final double[] pc = pointClass.firstObject();
206                        final PerceptronClass cls = pointClass.getSecondObject();
207                        final int correctedClass = cls == PerceptronClass.TRUE ? 1 : 0;
208                        final IndependentPair<double[], Integer> correctedPair = IndependentPair.pair(pc, correctedClass);
209                        final boolean errorBefore = mkp.predict(correctedPair.firstObject()) != correctedPair.secondObject();
210                        mkp.process(pc, correctedClass);
211                        if (errorBefore) {
212                                errors++;
213
214                        }
215                        if (i % TOTAL_DATA_ITEMS == 0) {
216                                if (errors == 0) {
217                                        break;
218                                } else {
219                                        i = 0;
220                                        errors = 0;
221                                }
222                        }
223                }
224                drawSpLine(mkp);
225        }
226
227        private static void drawSpLine(SimplePerceptron mkp) {
228                final MBFImage img = new MBFImage(300, 300, ColourSpace.RGB);
229                final double[] startD = new double[] { 0, Double.NaN };
230                final double[] endD = new double[] { img.getWidth(), Double.NaN };
231
232                drawLine(img, mkp.computeHyperplanePoint(startD), mkp.computeHyperplanePoint(endD));
233        }
234
235        private static void
236        leanrnPoints(DoubleArrayKernelPerceptron mkp, Iterable<IndependentPair<double[], PerceptronClass>> iter)
237        {
238                int i = 0;
239                int errors = 0;
240                for (final IndependentPair<double[], PerceptronClass> pointClass : iter) {
241                        i++;
242                        final double[] pc = pointClass.firstObject();
243                        final PerceptronClass cls = pointClass.getSecondObject();
244                        final int errorBefore = mkp.getErrors();
245                        mkp.process(pc, cls);
246                        System.out.println("b: " + mkp.getBias() + " w: "
247                                        + Arrays.toString(LinearVectorKernel.getDirection(mkp.getSupports(), mkp.getWeights())));
248                        if (errorBefore != mkp.getErrors()) {
249                                errors++;
250                        }
251                        if (i % TOTAL_DATA_ITEMS == 0) {
252                                if (errors == 0) {
253                                        break;
254                                } else {
255                                        i = 0;
256                                        errors = 0;
257                                }
258                        }
259                }
260                drawMkpLine(mkp);
261                System.out.println(mkp.getSupports().size());
262        }
263
264        private static void drawPoints(Stream<IndependentPair<double[], PerceptronClass>> dataStream, Line2d line) {
265                final MBFImage img = new MBFImage(300, 300, ColourSpace.RGB);
266
267                img.drawLine(line, 3, RGBColour.BLUE);
268
269                for (final IndependentPair<double[], PerceptronClass> pointClass : dataStream) {
270
271                        final double[] pc = pointClass.firstObject();
272                        final Point2dImpl point = new Point2dImpl((float) pc[0], (float) pc[1]);
273                        final PerceptronClass cls = pointClass.getSecondObject();
274                        switch (cls) {
275                        case TRUE:
276                                img.drawShapeFilled(new Circle(point, 5), RGBColour.GREEN);
277                                break;
278                        case FALSE:
279                                img.drawShape(new Circle(point, 5), 3, RGBColour.RED);
280                                break;
281                        case NONE:
282                                throw new RuntimeException("NOPE");
283                        }
284                }
285                DisplayUtilities.displayName(img, "random");
286        }
287
288        private static Point2d end(Vector origin, Vector dir) {
289                final Vector ret = origin.copy().add(10000, dir);
290                return new Point2dImpl((float) ret.get(0), (float) ret.get(1));
291        }
292
293        private static Point2d start(Vector origin, Vector dir) {
294                final Vector ret = origin.copy().add(-10000, dir);
295                return new Point2dImpl((float) ret.get(0), (float) ret.get(1));
296        }
297
298}