001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.linear.learner.perceptron;
031
032import java.util.Arrays;
033import java.util.List;
034
035import org.openimaj.math.model.EstimatableModel;
036import org.openimaj.ml.linear.learner.OnlineLearner;
037import org.openimaj.util.pair.IndependentPair;
038
039/**
040 * 
041 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
042 */
043public class SimplePerceptron implements OnlineLearner<double[], Integer>, EstimatableModel<double[], Integer> {
044        private static final double DEFAULT_LEARNING_RATE = 0.01;
045        private static final int DEFAULT_ITERATIONS = 1000;
046        double alpha = DEFAULT_LEARNING_RATE;
047        private double[] w;
048        private int iterations = DEFAULT_ITERATIONS;
049
050        private SimplePerceptron(double[] w) {
051                this.w = w;
052        }
053
054        /**
055         * 
056         */
057        public SimplePerceptron() {
058        }
059
060        @Override
061        public void process(double[] pt, Integer clazz) {
062                // System.out.println("Testing: " + Arrays.toString(pt) + " = " +
063                // clazz);
064                if (w == null) {
065                        initW(pt.length);
066                }
067                final int y = predict(pt);
068                System.out.println("w: " + Arrays.toString(w));
069                w[0] = w[0] + alpha * (clazz - y);
070                for (int i = 0; i < pt.length; i++) {
071                        w[i + 1] = w[i + 1] + alpha * (clazz - y) * pt[i];
072                }
073                // System.out.println("neww: " + Arrays.toString(w));
074        }
075
076        private void initW(int length) {
077                w = new double[length + 1];
078                w[0] = 1;
079        }
080
081        @Override
082        public Integer predict(double[] x) {
083                if (w == null)
084                        return 0;
085                return (w[0] + project(x)) > 0 ? 1 : 0;
086        }
087
088        private double project(double[] x) {
089                double sum = 0;
090                for (int i = 0; i < x.length; i++) {
091                        sum += x[i] * w[i + 1];
092                }
093                return sum;
094        }
095
096        @Override
097        public boolean estimate(List<? extends IndependentPair<double[], Integer>> data) {
098                this.w = new double[] { 1, 0, 0 };
099
100                for (int i = 0; i < iterations; i++) {
101                        iteration(data);
102
103                        final double error = calculateError(data);
104                        if (error < 0.01)
105                                break;
106                }
107                return true;
108        }
109
110        private void iteration(List<? extends IndependentPair<double[], Integer>> pts) {
111                for (int i = 0; i < pts.size(); i++) {
112                        final IndependentPair<double[], Integer> pair = pts.get(i);
113                        process(pair.firstObject(), pair.secondObject());
114                }
115        }
116
117        @Override
118        public int numItemsToEstimate() {
119                return 1;
120        }
121
122        protected double calculateError(List<? extends IndependentPair<double[], Integer>> pts) {
123                double error = 0;
124
125                for (int i = 0; i < pts.size(); i++) {
126                        final IndependentPair<double[], Integer> pair = pts.get(i);
127                        error += Math.abs(predict(pts.get(i).firstObject()) - pair.secondObject());
128                }
129
130                return error / pts.size();
131        }
132
133        /**
134         * Compute NaN-coordinate of a point on the hyperplane given
135         * non-NaN-coordinates. Only one x coordinate may be nan. If more NaN are
136         * seen after the first they are assumed to be 0
137         * 
138         * @param x
139         *            the coordinates, only one may be NaN, all others must be
140         *            provided
141         * @return the y-ordinate
142         */
143        public double[] computeHyperplanePoint(double[] x) {
144                double total = w[0];
145                int nanindex = -1;
146                final double[] ret = new double[x.length];
147                for (int i = 0; i < x.length; i++) {
148                        double value = x[i];
149                        if (nanindex != -1 && Double.isNaN(value)) {
150                                value = 0;
151                        }
152                        else if (Double.isNaN(value)) {
153                                nanindex = i;
154                                continue;
155                        }
156                        ret[i] = value;
157                        total += w[i + 1] * value;
158                }
159                if (nanindex != -1)
160                        ret[nanindex] = total / -w[nanindex + 1];
161                return ret;
162        }
163
164        @Override
165        public SimplePerceptron clone() {
166                return new SimplePerceptron(w);
167        }
168
169        public double[] getWeights() {
170                return this.w;
171        }
172}