View Javadoc

1   /**
2    * Copyright (c) 2011, The University of Southampton and the individual contributors.
3    * All rights reserved.
4    *
5    * Redistribution and use in source and binary forms, with or without modification,
6    * are permitted provided that the following conditions are met:
7    *
8    *   * 	Redistributions of source code must retain the above copyright notice,
9    * 	this list of conditions and the following disclaimer.
10   *
11   *   *	Redistributions in binary form must reproduce the above copyright notice,
12   * 	this list of conditions and the following disclaimer in the documentation
13   * 	and/or other materials provided with the distribution.
14   *
15   *   *	Neither the name of the University of Southampton nor the names of its
16   * 	contributors may be used to endorse or promote products derived from this
17   * 	software without specific prior written permission.
18   *
19   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21   * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22   * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
23   * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24   * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26   * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27   * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28   * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29   */
30  package org.openimaj.demos.ml.linear.data;
31  
32  import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
33  import gov.sandia.cognition.learning.data.InputOutputPair;
34  import gov.sandia.cognition.learning.function.kernel.LinearKernel;
35  import gov.sandia.cognition.math.matrix.VectorFactory;
36  
37  import java.io.File;
38  import java.io.IOException;
39  import java.io.PrintWriter;
40  import java.util.ArrayList;
41  import java.util.Arrays;
42  import java.util.Collection;
43  import java.util.List;
44  
45  import no.uib.cipr.matrix.Vector;
46  
47  import org.openimaj.image.DisplayUtilities;
48  import org.openimaj.image.MBFImage;
49  import org.openimaj.image.colour.ColourSpace;
50  import org.openimaj.image.colour.RGBColour;
51  import org.openimaj.math.geometry.line.Line2d;
52  import org.openimaj.math.geometry.point.Point2d;
53  import org.openimaj.math.geometry.point.Point2dImpl;
54  import org.openimaj.math.geometry.shape.Circle;
55  import org.openimaj.ml.linear.data.LinearPerceptronDataGenerator;
56  import org.openimaj.ml.linear.kernel.LinearVectorKernel;
57  import org.openimaj.ml.linear.learner.perceptron.DoubleArrayKernelPerceptron;
58  import org.openimaj.ml.linear.learner.perceptron.MeanCenteredKernelPerceptron;
59  import org.openimaj.ml.linear.learner.perceptron.MeanCenteredProjectron;
60  import org.openimaj.ml.linear.learner.perceptron.PerceptronClass;
61  import org.openimaj.ml.linear.learner.perceptron.SimplePerceptron;
62  import org.openimaj.util.pair.IndependentPair;
63  import org.openimaj.util.stream.Stream;
64  
65  /**
66   *
67   * @author Sina Samangooei (ss@ecs.soton.ac.uk)
68   */
69  public class DrawLinearData {
70  
71  	private static final int TOTAL_DATA_ITEMS = 1000;
72  	private static final int SEED = 1;
73  
74  	/**
75  	 * @param args
76  	 * @throws IOException
77  	 */
78  	public static void main(String[] args) throws IOException {
79  		final LinearPerceptronDataGenerator dg = dataGen();
80  		Stream<IndependentPair<double[], PerceptronClass>> dataStream;
81  		drawData(dg);
82  		writeData(new File("/Users/ss/Experiments/perceptron/test.data"));
83  		// dataStream = new
84  		// LimitedDataStream<double[],PerceptronClass>(dataGen(),TOTAL_DATA_ITEMS);
85  		dataStream = new RepeatingDataStream<double[], PerceptronClass>(dataGen(), TOTAL_DATA_ITEMS);
86  		final MeanCenteredKernelPerceptron mkp = new MeanCenteredKernelPerceptron(new LinearVectorKernel());
87  		// MatrixKernelPerceptron mkp = new MarginMeanCenteredPerceptron(new
88  		// LinearVectorKernel(),10000d);
89  		// MatrixKernelPerceptron mkp = new MeanCenteredProjectron(new
90  		// LinearVectorKernel());
91  		// MatrixKernelPerceptron mkp = new Projectron(new
92  		// LinearVectorKernel());
93  		// MatrixKernelPerceptron mkp = new
94  		// ThresholdMatrixKernelPerceptron(0.01, 0, new LinearVectorKernel());
95  		// MatrixKernelPerceptron mkp = new MatrixKernelPerceptron(new
96  		// LinearVectorKernel());
97  		// SimplePerceptron mkp = new SimplePerceptron();
98  		leanrnPoints(mkp, dataStream);
99  		// leanrnPointsProjectron();
100 		// leanrnCogFound();
101 	}
102 
103 	private static void writeData(File file) throws IOException {
104 		final LinearPerceptronDataGenerator gen = dataGen();
105 		final File pf = file.getParentFile();
106 		if (!pf.exists())
107 			pf.mkdirs();
108 		final PrintWriter fw = new PrintWriter(file);
109 		for (int i = 0; i < TOTAL_DATA_ITEMS; i++) {
110 			final IndependentPair<double[], PerceptronClass> d = gen.generate();
111 			fw.println(Arrays.toString(d.firstObject()));
112 			fw.println(d.secondObject() == PerceptronClass.TRUE ? 1 : 0);
113 		}
114 		fw.close();
115 	}
116 
117 	private static void drawData(LinearPerceptronDataGenerator dg) {
118 		final Stream<IndependentPair<double[], PerceptronClass>> dataStream = new LimitedDataStream<double[], PerceptronClass>(
119 				dg, TOTAL_DATA_ITEMS);
120 		final Vector origin = dg.getOrigin();
121 
122 		final Vector dir = dg.getPlane()[0];
123 		final Point2d lineStart = start(origin, dir);
124 		final Point2d lineEnd = end(origin, dir);
125 		final Line2d line = new Line2d(lineStart, lineEnd);
126 
127 		drawPoints(dataStream, line);
128 	}
129 
130 	private static LinearPerceptronDataGenerator dataGen() {
131 		final LinearPerceptronDataGenerator dg = new LinearPerceptronDataGenerator(300, 2, 0.3, SEED);
132 		return dg;
133 	}
134 
135 	private static void learnCogFound() {
136 		final LinearPerceptronDataGenerator dg = dataGen();
137 		final gov.sandia.cognition.learning.algorithm.perceptron.kernel.KernelPerceptron<gov.sandia.cognition.math.matrix.Vector> mkp = new gov.sandia.cognition.learning.algorithm.perceptron.kernel.KernelPerceptron<gov.sandia.cognition.math.matrix.Vector>(
138 				new LinearKernel());
139 		mkp.learn(createData());
140 		// System.out.println(mkp.getErrorCount());
141 
142 	}
143 
144 	private static Collection<? extends InputOutputPair<? extends gov.sandia.cognition.math.matrix.Vector, Boolean>>
145 	createData()
146 	{
147 		final List<InputOutputPair<gov.sandia.cognition.math.matrix.Vector, Boolean>> ret = new ArrayList<InputOutputPair<gov.sandia.cognition.math.matrix.Vector, Boolean>>();
148 		final LinearPerceptronDataGenerator dg = dataGen();
149 		for (int i = 0; i < TOTAL_DATA_ITEMS; i++) {
150 			final IndependentPair<double[], PerceptronClass> pointClass = dg.generate();
151 			final double[] pc = pointClass.firstObject();
152 			final PerceptronClass pcc = pointClass.secondObject();
153 			final boolean bool = pcc.equals(PerceptronClass.TRUE);
154 			final gov.sandia.cognition.math.matrix.Vector vec = VectorFactory.getDenseDefault().copyArray(pc);
155 			final InputOutputPair<gov.sandia.cognition.math.matrix.Vector, Boolean> item = DefaultInputOutputPair.create(
156 					vec, bool);
157 			ret.add(item);
158 		}
159 		System.out.println("Data created");
160 		return ret;
161 	}
162 
163 	private static void drawMkpLine(DoubleArrayKernelPerceptron mkp) {
164 		final MBFImage img = new MBFImage(300, 300, ColourSpace.RGB);
165 
166 		final List<double[]> sup = mkp.getSupports();
167 		final List<Double> weights = mkp.getWeights();
168 		final double bias = mkp.getBias();
169 		System.out.println("Bias: " + bias);
170 		double[] startD = null;
171 		double[] endD = null;
172 
173 		double[] mean = new double[2];
174 		if (mkp instanceof MeanCenteredKernelPerceptron) {
175 			mean = ((MeanCenteredKernelPerceptron) mkp).getMean();
176 		} else if (mkp instanceof MeanCenteredProjectron) {
177 			mean = ((MeanCenteredProjectron) mkp).getMean();
178 		}
179 		startD = LinearVectorKernel.getPlanePoint(sup, weights, bias, -mean[0], Double.NaN);
180 		endD = LinearVectorKernel.getPlanePoint(sup, weights, bias, img.getWidth() - mean[0], Double.NaN);
181 		startD[0] += mean[0];
182 		startD[1] += mean[1];
183 		endD[0] += mean[0];
184 		endD[1] += mean[1];
185 		drawLine(img, startD, endD);
186 	}
187 
188 	private static void drawLine(MBFImage img, double[] startD, double[] endD) {
189 		final Point2d lineStart = new Point2dImpl((float) startD[0], (float) startD[1]);
190 		final Point2d lineEnd = new Point2dImpl((float) endD[0], (float) endD[1]);
191 
192 		final Line2d line = new Line2d(lineStart, lineEnd);
193 		// System.out.println("Drawing: " + line);
194 		img.drawLine(line, 3, RGBColour.GREEN);
195 		// img.drawPoint(new Point2dImpl((float)origin.get(0),(float)
196 		// origin.get(1)), RGBColour.RED, 5);
197 		DisplayUtilities.displayName(img, "line");
198 	}
199 
200 	private static void leanrnPoints(SimplePerceptron mkp, Iterable<IndependentPair<double[], PerceptronClass>> iter) {
201 		int errors = 0;
202 		int i = 0;
203 		for (final IndependentPair<double[], PerceptronClass> pointClass : iter) {
204 			i++;
205 			final double[] pc = pointClass.firstObject();
206 			final PerceptronClass cls = pointClass.getSecondObject();
207 			final int correctedClass = cls == PerceptronClass.TRUE ? 1 : 0;
208 			final IndependentPair<double[], Integer> correctedPair = IndependentPair.pair(pc, correctedClass);
209 			final boolean errorBefore = mkp.predict(correctedPair.firstObject()) != correctedPair.secondObject();
210 			mkp.process(pc, correctedClass);
211 			if (errorBefore) {
212 				errors++;
213 
214 			}
215 			if (i % TOTAL_DATA_ITEMS == 0) {
216 				if (errors == 0) {
217 					break;
218 				} else {
219 					i = 0;
220 					errors = 0;
221 				}
222 			}
223 		}
224 		drawSpLine(mkp);
225 	}
226 
227 	private static void drawSpLine(SimplePerceptron mkp) {
228 		final MBFImage img = new MBFImage(300, 300, ColourSpace.RGB);
229 		final double[] startD = new double[] { 0, Double.NaN };
230 		final double[] endD = new double[] { img.getWidth(), Double.NaN };
231 
232 		drawLine(img, mkp.computeHyperplanePoint(startD), mkp.computeHyperplanePoint(endD));
233 	}
234 
235 	private static void
236 	leanrnPoints(DoubleArrayKernelPerceptron mkp, Iterable<IndependentPair<double[], PerceptronClass>> iter)
237 	{
238 		int i = 0;
239 		int errors = 0;
240 		for (final IndependentPair<double[], PerceptronClass> pointClass : iter) {
241 			i++;
242 			final double[] pc = pointClass.firstObject();
243 			final PerceptronClass cls = pointClass.getSecondObject();
244 			final int errorBefore = mkp.getErrors();
245 			mkp.process(pc, cls);
246 			System.out.println("b: " + mkp.getBias() + " w: "
247 					+ Arrays.toString(LinearVectorKernel.getDirection(mkp.getSupports(), mkp.getWeights())));
248 			if (errorBefore != mkp.getErrors()) {
249 				errors++;
250 			}
251 			if (i % TOTAL_DATA_ITEMS == 0) {
252 				if (errors == 0) {
253 					break;
254 				} else {
255 					i = 0;
256 					errors = 0;
257 				}
258 			}
259 		}
260 		drawMkpLine(mkp);
261 		System.out.println(mkp.getSupports().size());
262 	}
263 
264 	private static void drawPoints(Stream<IndependentPair<double[], PerceptronClass>> dataStream, Line2d line) {
265 		final MBFImage img = new MBFImage(300, 300, ColourSpace.RGB);
266 
267 		img.drawLine(line, 3, RGBColour.BLUE);
268 
269 		for (final IndependentPair<double[], PerceptronClass> pointClass : dataStream) {
270 
271 			final double[] pc = pointClass.firstObject();
272 			final Point2dImpl point = new Point2dImpl((float) pc[0], (float) pc[1]);
273 			final PerceptronClass cls = pointClass.getSecondObject();
274 			switch (cls) {
275 			case TRUE:
276 				img.drawShapeFilled(new Circle(point, 5), RGBColour.GREEN);
277 				break;
278 			case FALSE:
279 				img.drawShape(new Circle(point, 5), 3, RGBColour.RED);
280 				break;
281 			case NONE:
282 				throw new RuntimeException("NOPE");
283 			}
284 		}
285 		DisplayUtilities.displayName(img, "random");
286 	}
287 
288 	private static Point2d end(Vector origin, Vector dir) {
289 		final Vector ret = origin.copy().add(10000, dir);
290 		return new Point2dImpl((float) ret.get(0), (float) ret.get(1));
291 	}
292 
293 	private static Point2d start(Vector origin, Vector dir) {
294 		final Vector ret = origin.copy().add(-10000, dir);
295 		return new Point2dImpl((float) ret.get(0), (float) ret.get(1));
296 	}
297 
298 }