1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package org.openimaj.demos.ml.linear.data;
31
32 import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
33 import gov.sandia.cognition.learning.data.InputOutputPair;
34 import gov.sandia.cognition.learning.function.kernel.LinearKernel;
35 import gov.sandia.cognition.math.matrix.VectorFactory;
36
37 import java.io.File;
38 import java.io.IOException;
39 import java.io.PrintWriter;
40 import java.util.ArrayList;
41 import java.util.Arrays;
42 import java.util.Collection;
43 import java.util.List;
44
45 import no.uib.cipr.matrix.Vector;
46
47 import org.openimaj.image.DisplayUtilities;
48 import org.openimaj.image.MBFImage;
49 import org.openimaj.image.colour.ColourSpace;
50 import org.openimaj.image.colour.RGBColour;
51 import org.openimaj.math.geometry.line.Line2d;
52 import org.openimaj.math.geometry.point.Point2d;
53 import org.openimaj.math.geometry.point.Point2dImpl;
54 import org.openimaj.math.geometry.shape.Circle;
55 import org.openimaj.ml.linear.data.LinearPerceptronDataGenerator;
56 import org.openimaj.ml.linear.kernel.LinearVectorKernel;
57 import org.openimaj.ml.linear.learner.perceptron.DoubleArrayKernelPerceptron;
58 import org.openimaj.ml.linear.learner.perceptron.MeanCenteredKernelPerceptron;
59 import org.openimaj.ml.linear.learner.perceptron.MeanCenteredProjectron;
60 import org.openimaj.ml.linear.learner.perceptron.PerceptronClass;
61 import org.openimaj.ml.linear.learner.perceptron.SimplePerceptron;
62 import org.openimaj.util.pair.IndependentPair;
63 import org.openimaj.util.stream.Stream;
64
65
66
67
68
69 public class DrawLinearData {
70
71 private static final int TOTAL_DATA_ITEMS = 1000;
72 private static final int SEED = 1;
73
74
75
76
77
78 public static void main(String[] args) throws IOException {
79 final LinearPerceptronDataGenerator dg = dataGen();
80 Stream<IndependentPair<double[], PerceptronClass>> dataStream;
81 drawData(dg);
82 writeData(new File("/Users/ss/Experiments/perceptron/test.data"));
83
84
85 dataStream = new RepeatingDataStream<double[], PerceptronClass>(dataGen(), TOTAL_DATA_ITEMS);
86 final MeanCenteredKernelPerceptron mkp = new MeanCenteredKernelPerceptron(new LinearVectorKernel());
87
88
89
90
91
92
93
94
95
96
97
98 leanrnPoints(mkp, dataStream);
99
100
101 }
102
103 private static void writeData(File file) throws IOException {
104 final LinearPerceptronDataGenerator gen = dataGen();
105 final File pf = file.getParentFile();
106 if (!pf.exists())
107 pf.mkdirs();
108 final PrintWriter fw = new PrintWriter(file);
109 for (int i = 0; i < TOTAL_DATA_ITEMS; i++) {
110 final IndependentPair<double[], PerceptronClass> d = gen.generate();
111 fw.println(Arrays.toString(d.firstObject()));
112 fw.println(d.secondObject() == PerceptronClass.TRUE ? 1 : 0);
113 }
114 fw.close();
115 }
116
117 private static void drawData(LinearPerceptronDataGenerator dg) {
118 final Stream<IndependentPair<double[], PerceptronClass>> dataStream = new LimitedDataStream<double[], PerceptronClass>(
119 dg, TOTAL_DATA_ITEMS);
120 final Vector origin = dg.getOrigin();
121
122 final Vector dir = dg.getPlane()[0];
123 final Point2d lineStart = start(origin, dir);
124 final Point2d lineEnd = end(origin, dir);
125 final Line2d line = new Line2d(lineStart, lineEnd);
126
127 drawPoints(dataStream, line);
128 }
129
130 private static LinearPerceptronDataGenerator dataGen() {
131 final LinearPerceptronDataGenerator dg = new LinearPerceptronDataGenerator(300, 2, 0.3, SEED);
132 return dg;
133 }
134
135 private static void learnCogFound() {
136 final LinearPerceptronDataGenerator dg = dataGen();
137 final gov.sandia.cognition.learning.algorithm.perceptron.kernel.KernelPerceptron<gov.sandia.cognition.math.matrix.Vector> mkp = new gov.sandia.cognition.learning.algorithm.perceptron.kernel.KernelPerceptron<gov.sandia.cognition.math.matrix.Vector>(
138 new LinearKernel());
139 mkp.learn(createData());
140
141
142 }
143
144 private static Collection<? extends InputOutputPair<? extends gov.sandia.cognition.math.matrix.Vector, Boolean>>
145 createData()
146 {
147 final List<InputOutputPair<gov.sandia.cognition.math.matrix.Vector, Boolean>> ret = new ArrayList<InputOutputPair<gov.sandia.cognition.math.matrix.Vector, Boolean>>();
148 final LinearPerceptronDataGenerator dg = dataGen();
149 for (int i = 0; i < TOTAL_DATA_ITEMS; i++) {
150 final IndependentPair<double[], PerceptronClass> pointClass = dg.generate();
151 final double[] pc = pointClass.firstObject();
152 final PerceptronClass pcc = pointClass.secondObject();
153 final boolean bool = pcc.equals(PerceptronClass.TRUE);
154 final gov.sandia.cognition.math.matrix.Vector vec = VectorFactory.getDenseDefault().copyArray(pc);
155 final InputOutputPair<gov.sandia.cognition.math.matrix.Vector, Boolean> item = DefaultInputOutputPair.create(
156 vec, bool);
157 ret.add(item);
158 }
159 System.out.println("Data created");
160 return ret;
161 }
162
163 private static void drawMkpLine(DoubleArrayKernelPerceptron mkp) {
164 final MBFImage img = new MBFImage(300, 300, ColourSpace.RGB);
165
166 final List<double[]> sup = mkp.getSupports();
167 final List<Double> weights = mkp.getWeights();
168 final double bias = mkp.getBias();
169 System.out.println("Bias: " + bias);
170 double[] startD = null;
171 double[] endD = null;
172
173 double[] mean = new double[2];
174 if (mkp instanceof MeanCenteredKernelPerceptron) {
175 mean = ((MeanCenteredKernelPerceptron) mkp).getMean();
176 } else if (mkp instanceof MeanCenteredProjectron) {
177 mean = ((MeanCenteredProjectron) mkp).getMean();
178 }
179 startD = LinearVectorKernel.getPlanePoint(sup, weights, bias, -mean[0], Double.NaN);
180 endD = LinearVectorKernel.getPlanePoint(sup, weights, bias, img.getWidth() - mean[0], Double.NaN);
181 startD[0] += mean[0];
182 startD[1] += mean[1];
183 endD[0] += mean[0];
184 endD[1] += mean[1];
185 drawLine(img, startD, endD);
186 }
187
188 private static void drawLine(MBFImage img, double[] startD, double[] endD) {
189 final Point2d lineStart = new Point2dImpl((float) startD[0], (float) startD[1]);
190 final Point2d lineEnd = new Point2dImpl((float) endD[0], (float) endD[1]);
191
192 final Line2d line = new Line2d(lineStart, lineEnd);
193
194 img.drawLine(line, 3, RGBColour.GREEN);
195
196
197 DisplayUtilities.displayName(img, "line");
198 }
199
200 private static void leanrnPoints(SimplePerceptron mkp, Iterable<IndependentPair<double[], PerceptronClass>> iter) {
201 int errors = 0;
202 int i = 0;
203 for (final IndependentPair<double[], PerceptronClass> pointClass : iter) {
204 i++;
205 final double[] pc = pointClass.firstObject();
206 final PerceptronClass cls = pointClass.getSecondObject();
207 final int correctedClass = cls == PerceptronClass.TRUE ? 1 : 0;
208 final IndependentPair<double[], Integer> correctedPair = IndependentPair.pair(pc, correctedClass);
209 final boolean errorBefore = mkp.predict(correctedPair.firstObject()) != correctedPair.secondObject();
210 mkp.process(pc, correctedClass);
211 if (errorBefore) {
212 errors++;
213
214 }
215 if (i % TOTAL_DATA_ITEMS == 0) {
216 if (errors == 0) {
217 break;
218 } else {
219 i = 0;
220 errors = 0;
221 }
222 }
223 }
224 drawSpLine(mkp);
225 }
226
227 private static void drawSpLine(SimplePerceptron mkp) {
228 final MBFImage img = new MBFImage(300, 300, ColourSpace.RGB);
229 final double[] startD = new double[] { 0, Double.NaN };
230 final double[] endD = new double[] { img.getWidth(), Double.NaN };
231
232 drawLine(img, mkp.computeHyperplanePoint(startD), mkp.computeHyperplanePoint(endD));
233 }
234
235 private static void
236 leanrnPoints(DoubleArrayKernelPerceptron mkp, Iterable<IndependentPair<double[], PerceptronClass>> iter)
237 {
238 int i = 0;
239 int errors = 0;
240 for (final IndependentPair<double[], PerceptronClass> pointClass : iter) {
241 i++;
242 final double[] pc = pointClass.firstObject();
243 final PerceptronClass cls = pointClass.getSecondObject();
244 final int errorBefore = mkp.getErrors();
245 mkp.process(pc, cls);
246 System.out.println("b: " + mkp.getBias() + " w: "
247 + Arrays.toString(LinearVectorKernel.getDirection(mkp.getSupports(), mkp.getWeights())));
248 if (errorBefore != mkp.getErrors()) {
249 errors++;
250 }
251 if (i % TOTAL_DATA_ITEMS == 0) {
252 if (errors == 0) {
253 break;
254 } else {
255 i = 0;
256 errors = 0;
257 }
258 }
259 }
260 drawMkpLine(mkp);
261 System.out.println(mkp.getSupports().size());
262 }
263
264 private static void drawPoints(Stream<IndependentPair<double[], PerceptronClass>> dataStream, Line2d line) {
265 final MBFImage img = new MBFImage(300, 300, ColourSpace.RGB);
266
267 img.drawLine(line, 3, RGBColour.BLUE);
268
269 for (final IndependentPair<double[], PerceptronClass> pointClass : dataStream) {
270
271 final double[] pc = pointClass.firstObject();
272 final Point2dImpl point = new Point2dImpl((float) pc[0], (float) pc[1]);
273 final PerceptronClass cls = pointClass.getSecondObject();
274 switch (cls) {
275 case TRUE:
276 img.drawShapeFilled(new Circle(point, 5), RGBColour.GREEN);
277 break;
278 case FALSE:
279 img.drawShape(new Circle(point, 5), 3, RGBColour.RED);
280 break;
281 case NONE:
282 throw new RuntimeException("NOPE");
283 }
284 }
285 DisplayUtilities.displayName(img, "random");
286 }
287
288 private static Point2d end(Vector origin, Vector dir) {
289 final Vector ret = origin.copy().add(10000, dir);
290 return new Point2dImpl((float) ret.get(0), (float) ret.get(1));
291 }
292
293 private static Point2d start(Vector origin, Vector dir) {
294 final Vector ret = origin.copy().add(-10000, dir);
295 return new Point2dImpl((float) ret.get(0), (float) ret.get(1));
296 }
297
298 }