001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.demos.ml.linear.data; 031 032import gov.sandia.cognition.learning.data.DefaultInputOutputPair; 033import gov.sandia.cognition.learning.data.InputOutputPair; 034import gov.sandia.cognition.learning.function.kernel.LinearKernel; 035import gov.sandia.cognition.math.matrix.VectorFactory; 036 037import java.io.File; 038import java.io.IOException; 039import java.io.PrintWriter; 040import java.util.ArrayList; 041import java.util.Arrays; 042import java.util.Collection; 043import java.util.List; 044 045import no.uib.cipr.matrix.Vector; 046 047import org.openimaj.image.DisplayUtilities; 048import org.openimaj.image.MBFImage; 049import org.openimaj.image.colour.ColourSpace; 050import org.openimaj.image.colour.RGBColour; 051import org.openimaj.math.geometry.line.Line2d; 052import org.openimaj.math.geometry.point.Point2d; 053import org.openimaj.math.geometry.point.Point2dImpl; 054import org.openimaj.math.geometry.shape.Circle; 055import org.openimaj.ml.linear.data.LinearPerceptronDataGenerator; 056import org.openimaj.ml.linear.kernel.LinearVectorKernel; 057import org.openimaj.ml.linear.learner.perceptron.DoubleArrayKernelPerceptron; 058import org.openimaj.ml.linear.learner.perceptron.MeanCenteredKernelPerceptron; 059import org.openimaj.ml.linear.learner.perceptron.MeanCenteredProjectron; 060import org.openimaj.ml.linear.learner.perceptron.PerceptronClass; 061import org.openimaj.ml.linear.learner.perceptron.SimplePerceptron; 062import org.openimaj.util.pair.IndependentPair; 063import org.openimaj.util.stream.Stream; 064 065/** 066 * 067 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 068 */ 069public class DrawLinearData { 070 071 private static final int TOTAL_DATA_ITEMS = 1000; 072 private static final int SEED = 1; 073 074 /** 075 * @param args 076 * @throws IOException 077 */ 078 public static void main(String[] args) throws IOException { 079 final LinearPerceptronDataGenerator dg = dataGen(); 080 Stream<IndependentPair<double[], PerceptronClass>> dataStream; 081 drawData(dg); 082 writeData(new File("/Users/ss/Experiments/perceptron/test.data")); 083 // dataStream = new 084 // LimitedDataStream<double[],PerceptronClass>(dataGen(),TOTAL_DATA_ITEMS); 085 dataStream = new RepeatingDataStream<double[], PerceptronClass>(dataGen(), TOTAL_DATA_ITEMS); 086 final MeanCenteredKernelPerceptron mkp = new MeanCenteredKernelPerceptron(new LinearVectorKernel()); 087 // MatrixKernelPerceptron mkp = new MarginMeanCenteredPerceptron(new 088 // LinearVectorKernel(),10000d); 089 // MatrixKernelPerceptron mkp = new MeanCenteredProjectron(new 090 // LinearVectorKernel()); 091 // MatrixKernelPerceptron mkp = new Projectron(new 092 // LinearVectorKernel()); 093 // MatrixKernelPerceptron mkp = new 094 // ThresholdMatrixKernelPerceptron(0.01, 0, new LinearVectorKernel()); 095 // MatrixKernelPerceptron mkp = new MatrixKernelPerceptron(new 096 // LinearVectorKernel()); 097 // SimplePerceptron mkp = new SimplePerceptron(); 098 leanrnPoints(mkp, dataStream); 099 // leanrnPointsProjectron(); 100 // leanrnCogFound(); 101 } 102 103 private static void writeData(File file) throws IOException { 104 final LinearPerceptronDataGenerator gen = dataGen(); 105 final File pf = file.getParentFile(); 106 if (!pf.exists()) 107 pf.mkdirs(); 108 final PrintWriter fw = new PrintWriter(file); 109 for (int i = 0; i < TOTAL_DATA_ITEMS; i++) { 110 final IndependentPair<double[], PerceptronClass> d = gen.generate(); 111 fw.println(Arrays.toString(d.firstObject())); 112 fw.println(d.secondObject() == PerceptronClass.TRUE ? 1 : 0); 113 } 114 fw.close(); 115 } 116 117 private static void drawData(LinearPerceptronDataGenerator dg) { 118 final Stream<IndependentPair<double[], PerceptronClass>> dataStream = new LimitedDataStream<double[], PerceptronClass>( 119 dg, TOTAL_DATA_ITEMS); 120 final Vector origin = dg.getOrigin(); 121 122 final Vector dir = dg.getPlane()[0]; 123 final Point2d lineStart = start(origin, dir); 124 final Point2d lineEnd = end(origin, dir); 125 final Line2d line = new Line2d(lineStart, lineEnd); 126 127 drawPoints(dataStream, line); 128 } 129 130 private static LinearPerceptronDataGenerator dataGen() { 131 final LinearPerceptronDataGenerator dg = new LinearPerceptronDataGenerator(300, 2, 0.3, SEED); 132 return dg; 133 } 134 135 private static void learnCogFound() { 136 final LinearPerceptronDataGenerator dg = dataGen(); 137 final gov.sandia.cognition.learning.algorithm.perceptron.kernel.KernelPerceptron<gov.sandia.cognition.math.matrix.Vector> mkp = new gov.sandia.cognition.learning.algorithm.perceptron.kernel.KernelPerceptron<gov.sandia.cognition.math.matrix.Vector>( 138 new LinearKernel()); 139 mkp.learn(createData()); 140 // System.out.println(mkp.getErrorCount()); 141 142 } 143 144 private static Collection<? extends InputOutputPair<? extends gov.sandia.cognition.math.matrix.Vector, Boolean>> 145 createData() 146 { 147 final List<InputOutputPair<gov.sandia.cognition.math.matrix.Vector, Boolean>> ret = new ArrayList<InputOutputPair<gov.sandia.cognition.math.matrix.Vector, Boolean>>(); 148 final LinearPerceptronDataGenerator dg = dataGen(); 149 for (int i = 0; i < TOTAL_DATA_ITEMS; i++) { 150 final IndependentPair<double[], PerceptronClass> pointClass = dg.generate(); 151 final double[] pc = pointClass.firstObject(); 152 final PerceptronClass pcc = pointClass.secondObject(); 153 final boolean bool = pcc.equals(PerceptronClass.TRUE); 154 final gov.sandia.cognition.math.matrix.Vector vec = VectorFactory.getDenseDefault().copyArray(pc); 155 final InputOutputPair<gov.sandia.cognition.math.matrix.Vector, Boolean> item = DefaultInputOutputPair.create( 156 vec, bool); 157 ret.add(item); 158 } 159 System.out.println("Data created"); 160 return ret; 161 } 162 163 private static void drawMkpLine(DoubleArrayKernelPerceptron mkp) { 164 final MBFImage img = new MBFImage(300, 300, ColourSpace.RGB); 165 166 final List<double[]> sup = mkp.getSupports(); 167 final List<Double> weights = mkp.getWeights(); 168 final double bias = mkp.getBias(); 169 System.out.println("Bias: " + bias); 170 double[] startD = null; 171 double[] endD = null; 172 173 double[] mean = new double[2]; 174 if (mkp instanceof MeanCenteredKernelPerceptron) { 175 mean = ((MeanCenteredKernelPerceptron) mkp).getMean(); 176 } else if (mkp instanceof MeanCenteredProjectron) { 177 mean = ((MeanCenteredProjectron) mkp).getMean(); 178 } 179 startD = LinearVectorKernel.getPlanePoint(sup, weights, bias, -mean[0], Double.NaN); 180 endD = LinearVectorKernel.getPlanePoint(sup, weights, bias, img.getWidth() - mean[0], Double.NaN); 181 startD[0] += mean[0]; 182 startD[1] += mean[1]; 183 endD[0] += mean[0]; 184 endD[1] += mean[1]; 185 drawLine(img, startD, endD); 186 } 187 188 private static void drawLine(MBFImage img, double[] startD, double[] endD) { 189 final Point2d lineStart = new Point2dImpl((float) startD[0], (float) startD[1]); 190 final Point2d lineEnd = new Point2dImpl((float) endD[0], (float) endD[1]); 191 192 final Line2d line = new Line2d(lineStart, lineEnd); 193 // System.out.println("Drawing: " + line); 194 img.drawLine(line, 3, RGBColour.GREEN); 195 // img.drawPoint(new Point2dImpl((float)origin.get(0),(float) 196 // origin.get(1)), RGBColour.RED, 5); 197 DisplayUtilities.displayName(img, "line"); 198 } 199 200 private static void leanrnPoints(SimplePerceptron mkp, Iterable<IndependentPair<double[], PerceptronClass>> iter) { 201 int errors = 0; 202 int i = 0; 203 for (final IndependentPair<double[], PerceptronClass> pointClass : iter) { 204 i++; 205 final double[] pc = pointClass.firstObject(); 206 final PerceptronClass cls = pointClass.getSecondObject(); 207 final int correctedClass = cls == PerceptronClass.TRUE ? 1 : 0; 208 final IndependentPair<double[], Integer> correctedPair = IndependentPair.pair(pc, correctedClass); 209 final boolean errorBefore = mkp.predict(correctedPair.firstObject()) != correctedPair.secondObject(); 210 mkp.process(pc, correctedClass); 211 if (errorBefore) { 212 errors++; 213 214 } 215 if (i % TOTAL_DATA_ITEMS == 0) { 216 if (errors == 0) { 217 break; 218 } else { 219 i = 0; 220 errors = 0; 221 } 222 } 223 } 224 drawSpLine(mkp); 225 } 226 227 private static void drawSpLine(SimplePerceptron mkp) { 228 final MBFImage img = new MBFImage(300, 300, ColourSpace.RGB); 229 final double[] startD = new double[] { 0, Double.NaN }; 230 final double[] endD = new double[] { img.getWidth(), Double.NaN }; 231 232 drawLine(img, mkp.computeHyperplanePoint(startD), mkp.computeHyperplanePoint(endD)); 233 } 234 235 private static void 236 leanrnPoints(DoubleArrayKernelPerceptron mkp, Iterable<IndependentPair<double[], PerceptronClass>> iter) 237 { 238 int i = 0; 239 int errors = 0; 240 for (final IndependentPair<double[], PerceptronClass> pointClass : iter) { 241 i++; 242 final double[] pc = pointClass.firstObject(); 243 final PerceptronClass cls = pointClass.getSecondObject(); 244 final int errorBefore = mkp.getErrors(); 245 mkp.process(pc, cls); 246 System.out.println("b: " + mkp.getBias() + " w: " 247 + Arrays.toString(LinearVectorKernel.getDirection(mkp.getSupports(), mkp.getWeights()))); 248 if (errorBefore != mkp.getErrors()) { 249 errors++; 250 } 251 if (i % TOTAL_DATA_ITEMS == 0) { 252 if (errors == 0) { 253 break; 254 } else { 255 i = 0; 256 errors = 0; 257 } 258 } 259 } 260 drawMkpLine(mkp); 261 System.out.println(mkp.getSupports().size()); 262 } 263 264 private static void drawPoints(Stream<IndependentPair<double[], PerceptronClass>> dataStream, Line2d line) { 265 final MBFImage img = new MBFImage(300, 300, ColourSpace.RGB); 266 267 img.drawLine(line, 3, RGBColour.BLUE); 268 269 for (final IndependentPair<double[], PerceptronClass> pointClass : dataStream) { 270 271 final double[] pc = pointClass.firstObject(); 272 final Point2dImpl point = new Point2dImpl((float) pc[0], (float) pc[1]); 273 final PerceptronClass cls = pointClass.getSecondObject(); 274 switch (cls) { 275 case TRUE: 276 img.drawShapeFilled(new Circle(point, 5), RGBColour.GREEN); 277 break; 278 case FALSE: 279 img.drawShape(new Circle(point, 5), 3, RGBColour.RED); 280 break; 281 case NONE: 282 throw new RuntimeException("NOPE"); 283 } 284 } 285 DisplayUtilities.displayName(img, "random"); 286 } 287 288 private static Point2d end(Vector origin, Vector dir) { 289 final Vector ret = origin.copy().add(10000, dir); 290 return new Point2dImpl((float) ret.get(0), (float) ret.get(1)); 291 } 292 293 private static Point2d start(Vector origin, Vector dir) { 294 final Vector ret = origin.copy().add(-10000, dir); 295 return new Point2dImpl((float) ret.get(0), (float) ret.get(1)); 296 } 297 298}