1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package org.openimaj.ml.linear.learner.perceptron;
31
32 import java.util.Arrays;
33 import java.util.List;
34
35 import org.openimaj.math.model.EstimatableModel;
36 import org.openimaj.ml.linear.learner.OnlineLearner;
37 import org.openimaj.util.pair.IndependentPair;
38
39
40
41
42
43 public class SimplePerceptron implements OnlineLearner<double[], Integer>, EstimatableModel<double[], Integer> {
44 private static final double DEFAULT_LEARNING_RATE = 0.01;
45 private static final int DEFAULT_ITERATIONS = 1000;
46 double alpha = DEFAULT_LEARNING_RATE;
47 private double[] w;
48 private int iterations = DEFAULT_ITERATIONS;
49
50 private SimplePerceptron(double[] w) {
51 this.w = w;
52 }
53
54
55
56
57 public SimplePerceptron() {
58 }
59
60 @Override
61 public void process(double[] pt, Integer clazz) {
62
63
64 if (w == null) {
65 initW(pt.length);
66 }
67 final int y = predict(pt);
68 System.out.println("w: " + Arrays.toString(w));
69 w[0] = w[0] + alpha * (clazz - y);
70 for (int i = 0; i < pt.length; i++) {
71 w[i + 1] = w[i + 1] + alpha * (clazz - y) * pt[i];
72 }
73
74 }
75
76 private void initW(int length) {
77 w = new double[length + 1];
78 w[0] = 1;
79 }
80
81 @Override
82 public Integer predict(double[] x) {
83 if (w == null)
84 return 0;
85 return (w[0] + project(x)) > 0 ? 1 : 0;
86 }
87
88 private double project(double[] x) {
89 double sum = 0;
90 for (int i = 0; i < x.length; i++) {
91 sum += x[i] * w[i + 1];
92 }
93 return sum;
94 }
95
96 @Override
97 public boolean estimate(List<? extends IndependentPair<double[], Integer>> data) {
98 this.w = new double[] { 1, 0, 0 };
99
100 for (int i = 0; i < iterations; i++) {
101 iteration(data);
102
103 final double error = calculateError(data);
104 if (error < 0.01)
105 break;
106 }
107 return true;
108 }
109
110 private void iteration(List<? extends IndependentPair<double[], Integer>> pts) {
111 for (int i = 0; i < pts.size(); i++) {
112 final IndependentPair<double[], Integer> pair = pts.get(i);
113 process(pair.firstObject(), pair.secondObject());
114 }
115 }
116
117 @Override
118 public int numItemsToEstimate() {
119 return 1;
120 }
121
122 protected double calculateError(List<? extends IndependentPair<double[], Integer>> pts) {
123 double error = 0;
124
125 for (int i = 0; i < pts.size(); i++) {
126 final IndependentPair<double[], Integer> pair = pts.get(i);
127 error += Math.abs(predict(pts.get(i).firstObject()) - pair.secondObject());
128 }
129
130 return error / pts.size();
131 }
132
133
134
135
136
137
138
139
140
141
142
143 public double[] computeHyperplanePoint(double[] x) {
144 double total = w[0];
145 int nanindex = -1;
146 final double[] ret = new double[x.length];
147 for (int i = 0; i < x.length; i++) {
148 double value = x[i];
149 if (nanindex != -1 && Double.isNaN(value)) {
150 value = 0;
151 }
152 else if (Double.isNaN(value)) {
153 nanindex = i;
154 continue;
155 }
156 ret[i] = value;
157 total += w[i + 1] * value;
158 }
159 if (nanindex != -1)
160 ret[nanindex] = total / -w[nanindex + 1];
161 return ret;
162 }
163
164 @Override
165 public SimplePerceptron clone() {
166 return new SimplePerceptron(w);
167 }
168
169 public double[] getWeights() {
170 return this.w;
171 }
172 }