1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  package org.openimaj.ml.linear.learner.perceptron;
31  
32  import java.util.Arrays;
33  import java.util.List;
34  
35  import org.openimaj.math.model.EstimatableModel;
36  import org.openimaj.ml.linear.learner.OnlineLearner;
37  import org.openimaj.util.pair.IndependentPair;
38  
39  
40  
41  
42  
43  public class SimplePerceptron implements OnlineLearner<double[], Integer>, EstimatableModel<double[], Integer> {
44  	private static final double DEFAULT_LEARNING_RATE = 0.01;
45  	private static final int DEFAULT_ITERATIONS = 1000;
46  	double alpha = DEFAULT_LEARNING_RATE;
47  	private double[] w;
48  	private int iterations = DEFAULT_ITERATIONS;
49  
50  	private SimplePerceptron(double[] w) {
51  		this.w = w;
52  	}
53  
54  	
55  
56  
57  	public SimplePerceptron() {
58  	}
59  
60  	@Override
61  	public void process(double[] pt, Integer clazz) {
62  		
63  		
64  		if (w == null) {
65  			initW(pt.length);
66  		}
67  		final int y = predict(pt);
68  		System.out.println("w: " + Arrays.toString(w));
69  		w[0] = w[0] + alpha * (clazz - y);
70  		for (int i = 0; i < pt.length; i++) {
71  			w[i + 1] = w[i + 1] + alpha * (clazz - y) * pt[i];
72  		}
73  		
74  	}
75  
76  	private void initW(int length) {
77  		w = new double[length + 1];
78  		w[0] = 1;
79  	}
80  
81  	@Override
82  	public Integer predict(double[] x) {
83  		if (w == null)
84  			return 0;
85  		return (w[0] + project(x)) > 0 ? 1 : 0;
86  	}
87  
88  	private double project(double[] x) {
89  		double sum = 0;
90  		for (int i = 0; i < x.length; i++) {
91  			sum += x[i] * w[i + 1];
92  		}
93  		return sum;
94  	}
95  
96  	@Override
97  	public boolean estimate(List<? extends IndependentPair<double[], Integer>> data) {
98  		this.w = new double[] { 1, 0, 0 };
99  
100 		for (int i = 0; i < iterations; i++) {
101 			iteration(data);
102 
103 			final double error = calculateError(data);
104 			if (error < 0.01)
105 				break;
106 		}
107 		return true;
108 	}
109 
110 	private void iteration(List<? extends IndependentPair<double[], Integer>> pts) {
111 		for (int i = 0; i < pts.size(); i++) {
112 			final IndependentPair<double[], Integer> pair = pts.get(i);
113 			process(pair.firstObject(), pair.secondObject());
114 		}
115 	}
116 
117 	@Override
118 	public int numItemsToEstimate() {
119 		return 1;
120 	}
121 
122 	protected double calculateError(List<? extends IndependentPair<double[], Integer>> pts) {
123 		double error = 0;
124 
125 		for (int i = 0; i < pts.size(); i++) {
126 			final IndependentPair<double[], Integer> pair = pts.get(i);
127 			error += Math.abs(predict(pts.get(i).firstObject()) - pair.secondObject());
128 		}
129 
130 		return error / pts.size();
131 	}
132 
133 	
134 
135 
136 
137 
138 
139 
140 
141 
142 
143 	public double[] computeHyperplanePoint(double[] x) {
144 		double total = w[0];
145 		int nanindex = -1;
146 		final double[] ret = new double[x.length];
147 		for (int i = 0; i < x.length; i++) {
148 			double value = x[i];
149 			if (nanindex != -1 && Double.isNaN(value)) {
150 				value = 0;
151 			}
152 			else if (Double.isNaN(value)) {
153 				nanindex = i;
154 				continue;
155 			}
156 			ret[i] = value;
157 			total += w[i + 1] * value;
158 		}
159 		if (nanindex != -1)
160 			ret[nanindex] = total / -w[nanindex + 1];
161 		return ret;
162 	}
163 
164 	@Override
165 	public SimplePerceptron clone() {
166 		return new SimplePerceptron(w);
167 	}
168 
169 	public double[] getWeights() {
170 		return this.w;
171 	}
172 }