1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package org.openimaj.workinprogress.sgdsvm;
31
32 import java.util.List;
33
34 import org.apache.commons.math.random.MersenneTwister;
35 import org.openimaj.feature.FloatFV;
36 import org.openimaj.feature.FloatFVComparison;
37 import org.openimaj.util.array.ArrayUtils;
38 import org.openimaj.util.array.SparseFloatArray;
39 import org.openimaj.util.array.SparseFloatArray.Entry;
40 import org.openimaj.util.array.SparseHashedFloatArray;
41
42 import gnu.trove.list.array.TDoubleArrayList;
43
/**
 * Linear SVM trained by plain stochastic gradient descent.
 * <p>
 * The weight vector is stored <em>scaled</em>: the field {@link #w} holds the
 * true weights multiplied by {@link #wDivisor}, so the effective weight vector
 * is {@code w / wDivisor}. This lets the L2 regularisation shrinkage step be a
 * single scalar update to {@code wDivisor} instead of a pass over the whole
 * vector (see {@link #trainOne}).
 * <p>
 * NOTE(review): the structure closely mirrors Léon Bottou's reference
 * {@code svmsgd} implementation — worth confirming against that code when
 * modifying the update rules.
 */
public class SvmSgd implements Cloneable {
	// Loss function used for both training gradients and test-time loss
	// reporting; hinge loss by default.
	Loss LOSS = LossFunctions.HingeLoss;
	// If true, learn an additive bias term wBias alongside the weights.
	boolean BIAS = true;
	// If true, the bias is also subject to L2 regularisation (see trainOne
	// and wnorm); off by default so the bias is unpenalised.
	boolean REGULARIZED_BIAS = false;

	/** L2 regularisation strength. */
	public double lambda;
	/** Initial learning rate; the effective rate decays with t (see train). */
	public double eta0;
	// Scaled weight vector: true weights are w / wDivisor.
	FloatFV w;
	// Scale factor applied (as a divisor) to w; grows during training and is
	// folded back into w by renorm() when it gets large.
	double wDivisor;
	// Bias term (unscaled — not affected by wDivisor).
	double wBias;
	// Number of training examples seen so far; drives the learning-rate decay.
	double t;

	/**
	 * Construct with the given dimensionality and regularisation strength;
	 * eta0 is left at 0 and should be set later (e.g. via determineEta0).
	 *
	 * @param dim    dimensionality of the feature vectors
	 * @param lambda L2 regularisation strength
	 */
	public SvmSgd(int dim, double lambda) {
		this(dim, lambda, 0);
	}

	/**
	 * Construct with the given dimensionality, regularisation strength and
	 * initial learning rate.
	 *
	 * @param dim    dimensionality of the feature vectors
	 * @param lambda L2 regularisation strength
	 * @param eta0   initial learning rate
	 */
	public SvmSgd(int dim, double lambda, double eta0) {
		this.lambda = lambda;
		this.eta0 = eta0;
		this.w = new FloatFV(dim);
		this.wDivisor = 1;
		this.wBias = 0;
		this.t = 0;
	}

	// Dot product between a dense vector and a sparse vector; iterates only
	// the non-zero entries of v2.
	private double dot(FloatFV v1, SparseFloatArray v2) {
		double d = 0;
		for (final Entry e : v2.entries()) {
			d += e.value * v1.values[e.index];
		}

		return d;
	}

	// Dense-dense dot product, delegated to the library inner-product
	// comparison.
	private double dot(FloatFV v1, FloatFV v2) {
		return FloatFVComparison.INNER_PRODUCT.compare(v1, v2);
	}

	// y += x * d, touching only the non-zero entries of the sparse x.
	private void add(FloatFV y, SparseFloatArray x, double d) {

		for (final Entry e : x.entries()) {
			y.values[e.index] += e.value * d;
		}
	}

	/**
	 * Fold the accumulated wDivisor scale back into the stored weights so
	 * that w holds the true weight values again (wDivisor == 1 afterwards).
	 * Called periodically from trainOne to keep the stored floats in a
	 * numerically safe range.
	 */
	public void renorm() {
		if (wDivisor != 1.0) {
			ArrayUtils.multiply(w.values, (float) (1.0 / wDivisor));

			wDivisor = 1.0;
		}
	}

	/**
	 * Squared L2 norm of the true (unscaled) weight vector, including the
	 * bias if it is regularised.
	 *
	 * @return the squared norm
	 */
	public double wnorm() {
		// w stores scaled weights, so divide the raw dot product by
		// wDivisor twice to recover ||w_true||^2.
		double norm = dot(w, w) / wDivisor / wDivisor;

		if (REGULARIZED_BIAS)
			norm += wBias * wBias;
		return norm;
	}

	/**
	 * Score a single example and optionally accumulate loss / error counts.
	 *
	 * @param x     the example
	 * @param y     its label (sign is used for the error count)
	 * @param ploss if non-null, LOSS.loss(score, y) is added to ploss[0]
	 * @param pnerr if non-null, 1 is added to pnerr[0] when score*y &lt;= 0
	 *              (i.e. a misclassification; a score of exactly 0 counts
	 *              as an error)
	 * @return the raw decision score w.x + bias
	 */
	public double testOne(final SparseFloatArray x, double y, double[] ploss, double[] pnerr) {
		final double s = dot(w, x) / wDivisor + wBias;
		if (ploss != null)
			ploss[0] += LOSS.loss(s, y);
		if (pnerr != null)
			pnerr[0] += (s * y <= 0) ? 1 : 0;
		return s;
	}

	/**
	 * Perform one SGD step on a single example.
	 *
	 * @param x   the example
	 * @param y   its label
	 * @param eta the learning rate for this step
	 */
	public void trainOne(final SparseFloatArray x, double y, double eta) {
		// Score with the current (scaled) weights.
		final double s = dot(w, x) / wDivisor + wBias;

		// L2 shrinkage: multiplying the true weights by (1 - eta*lambda) is
		// equivalent to dividing wDivisor's reciprocal, i.e. growing
		// wDivisor. O(1) instead of touching every weight.
		wDivisor = wDivisor / (1 - eta * lambda);
		// Guard against wDivisor overflowing / losing float precision by
		// folding it back into w once it gets large.
		if (wDivisor > 1e5)
			renorm();

		// Gradient step: the true update is eta * d * x, but since w is
		// stored pre-multiplied by wDivisor the added amount must be scaled
		// up by wDivisor too.
		final double d = LOSS.dloss(s, y);
		if (d != 0)
			add(w, x, eta * d * wDivisor);

		if (BIAS) {
			// The bias learns at 1% of the weight learning rate — a
			// heuristic carried over from the reference implementation.
			final double etab = eta * 0.01;
			if (REGULARIZED_BIAS) {
				wBias *= (1 - etab * lambda);
			}
			wBias += etab * d;
		}
	}

	/**
	 * Clone this learner, deep-copying the weight vector so the clone can be
	 * trained independently (used by evaluateEta to probe learning rates
	 * without disturbing this instance's state).
	 */
	@Override
	protected SvmSgd clone() {
		SvmSgd clone;
		try {
			clone = (SvmSgd) super.clone();
		} catch (final CloneNotSupportedException e) {
			throw new RuntimeException(e);
		}
		clone.w = clone.w.clone();
		return clone;
	}

	/**
	 * Run one pass of SGD over examples imin..imax (inclusive), decaying the
	 * learning rate as eta0 / (1 + lambda*eta0*t) where t counts all
	 * examples ever seen by this instance. Prints progress to stdout.
	 *
	 * @param imin first index (inclusive)
	 * @param imax last index (inclusive)
	 * @param xp   the examples
	 * @param yp   the labels
	 */
	public void train(int imin, int imax, SparseFloatArray[] xp, double[] yp) {
		System.out.println("Training on [" + imin + ", " + imax + "].");
		assert (imin <= imax);
		assert (eta0 > 0);
		for (int i = imin; i <= imax; i++) {
			final double eta = eta0 / (1 + lambda * eta0 * t);
			trainOne(xp[i], yp[i], eta);
			t += 1;
		}

		System.out.format("wNorm=%.6f", wnorm());
		if (BIAS) {

			System.out.format(" wBias=%.6f", wBias);
		}
		System.out.println();

	}

	/**
	 * List-based overload of {@link #train(int, int, SparseFloatArray[], double[])};
	 * identical behaviour.
	 *
	 * @param imin first index (inclusive)
	 * @param imax last index (inclusive)
	 * @param xp   the examples
	 * @param yp   the labels
	 */
	public void train(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp) {
		System.out.println("Training on [" + imin + ", " + imax + "].");
		assert (imin <= imax);
		assert (eta0 > 0);
		for (int i = imin; i <= imax; i++) {
			final double eta = eta0 / (1 + lambda * eta0 * t);
			trainOne(xp.get(i), yp.get(i), eta);
			t += 1;
		}

		System.out.format("wNorm=%.6f", wnorm());
		if (BIAS) {

			System.out.format(" wBias=%.6f", wBias);
		}
		System.out.println();

	}

	/**
	 * Evaluate on examples imin..imax (inclusive), printing mean loss,
	 * regularised cost and misclassification percentage to stdout.
	 *
	 * @param imin   first index (inclusive)
	 * @param imax   last index (inclusive)
	 * @param xp     the examples
	 * @param yp     the labels
	 * @param prefix string prepended to each output line (e.g. "training ")
	 */
	public void test(int imin, int imax, SparseFloatArray[] xp, double[] yp, String prefix) {

		System.out.println(prefix + "Testing on [" + imin + ", " + imax + "].");
		assert (imin <= imax);
		// Single-element arrays used as out-parameters for testOne's
		// accumulators.
		final double nerr[] = { 0 };
		final double loss[] = { 0 };
		for (int i = imin; i <= imax; i++)
			testOne(xp[i], yp[i], loss, nerr);
		nerr[0] = nerr[0] / (imax - imin + 1);
		loss[0] = loss[0] / (imax - imin + 1);
		// Regularised objective: mean loss + (lambda/2)||w||^2.
		final double cost = loss[0] + 0.5 * lambda * wnorm();

		System.out.println(prefix + "Loss=" + loss[0] + " Cost=" + cost + " Misclassification="
				+ String.format("%2.4f", 100 * nerr[0]) + "%");
	}

	/**
	 * List-based overload of
	 * {@link #test(int, int, SparseFloatArray[], double[], String)}; identical
	 * behaviour.
	 *
	 * @param imin   first index (inclusive)
	 * @param imax   last index (inclusive)
	 * @param xp     the examples
	 * @param yp     the labels
	 * @param prefix string prepended to each output line
	 */
	public void test(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp, String prefix) {

		System.out.println(prefix + "Testing on [" + imin + ", " + imax + "].");
		assert (imin <= imax);
		final double nerr[] = { 0 };
		final double loss[] = { 0 };
		for (int i = imin; i <= imax; i++)
			testOne(xp.get(i), yp.get(i), loss, nerr);
		nerr[0] = nerr[0] / (imax - imin + 1);
		loss[0] = loss[0] / (imax - imin + 1);
		final double cost = loss[0] + 0.5 * lambda * wnorm();

		System.out.println(prefix + "Loss=" + loss[0] + " Cost=" + cost + " Misclassification="
				+ String.format("%2.4f", 100 * nerr[0]) + "%");
	}

	/**
	 * Probe a candidate learning rate: train a throwaway clone for one pass
	 * over imin..imax with constant eta, then return the resulting
	 * regularised cost on the same examples. Does not modify this instance.
	 *
	 * @param imin first index (inclusive)
	 * @param imax last index (inclusive)
	 * @param xp   the examples
	 * @param yp   the labels
	 * @param eta  the constant learning rate to try
	 * @return the regularised cost achieved with this eta
	 */
	public double evaluateEta(int imin, int imax, SparseFloatArray[] xp, double[] yp, double eta) {
		final SvmSgd clone = this.clone();
		assert (imin <= imax);
		for (int i = imin; i <= imax; i++)
			clone.trainOne(xp[i], yp[i], eta);
		final double loss[] = { 0 };
		double cost = 0;
		for (int i = imin; i <= imax; i++)
			clone.testOne(xp[i], yp[i], loss, null);
		loss[0] = loss[0] / (imax - imin + 1);
		cost = loss[0] + 0.5 * lambda * clone.wnorm();

		System.out.println("Trying eta=" + eta + " yields cost " + cost);
		return cost;
	}

	/**
	 * List-based overload of
	 * {@link #evaluateEta(int, int, SparseFloatArray[], double[], double)};
	 * identical behaviour.
	 *
	 * @param imin first index (inclusive)
	 * @param imax last index (inclusive)
	 * @param xp   the examples
	 * @param yp   the labels
	 * @param eta  the constant learning rate to try
	 * @return the regularised cost achieved with this eta
	 */
	public double evaluateEta(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp, double eta) {
		final SvmSgd clone = this.clone();
		assert (imin <= imax);
		for (int i = imin; i <= imax; i++)
			clone.trainOne(xp.get(i), yp.get(i), eta);
		final double loss[] = { 0 };
		double cost = 0;
		for (int i = imin; i <= imax; i++)
			clone.testOne(xp.get(i), yp.get(i), loss, null);
		loss[0] = loss[0] / (imax - imin + 1);
		cost = loss[0] + 0.5 * lambda * clone.wnorm();

		System.out.println("Trying eta=" + eta + " yields cost " + cost);
		return cost;
	}

	/**
	 * Choose eta0 by a geometric bracket search: starting from eta=1, keep
	 * halving (or doubling) by {@code factor} while the cost from
	 * evaluateEta keeps improving, then take the lower rate of the final
	 * bracket. Sets {@link #eta0} and prints the choice.
	 *
	 * @param imin first index (inclusive) of the sample used for probing
	 * @param imax last index (inclusive)
	 * @param xp   the examples
	 * @param yp   the labels
	 */
	public void determineEta0(int imin, int imax, SparseFloatArray[] xp, double[] yp) {
		final double factor = 2.0;
		double loEta = 1;
		double loCost = evaluateEta(imin, imax, xp, yp, loEta);
		double hiEta = loEta * factor;
		double hiCost = evaluateEta(imin, imax, xp, yp, hiEta);
		if (loCost < hiCost)
			// Lower rate is better: walk downward until cost stops improving.
			while (loCost < hiCost) {
				hiEta = loEta;
				hiCost = loCost;
				loEta = hiEta / factor;
				loCost = evaluateEta(imin, imax, xp, yp, loEta);
			}
		else if (hiCost < loCost)
			// Higher rate is better: walk upward until cost stops improving.
			while (hiCost < loCost) {
				loEta = hiEta;
				loCost = hiCost;
				hiEta = loEta * factor;
				hiCost = evaluateEta(imin, imax, xp, yp, hiEta);
			}
		eta0 = loEta;

		System.out.println("# Using eta0=" + eta0 + "\n");
	}

	/**
	 * List-based overload of
	 * {@link #determineEta0(int, int, SparseFloatArray[], double[])};
	 * identical behaviour.
	 *
	 * @param imin first index (inclusive)
	 * @param imax last index (inclusive)
	 * @param xp   the examples
	 * @param yp   the labels
	 */
	public void determineEta0(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp) {
		final double factor = 2.0;
		double loEta = 1;
		double loCost = evaluateEta(imin, imax, xp, yp, loEta);
		double hiEta = loEta * factor;
		double hiCost = evaluateEta(imin, imax, xp, yp, hiEta);
		if (loCost < hiCost)
			while (loCost < hiCost) {
				hiEta = loEta;
				hiCost = loCost;
				loEta = hiEta / factor;
				loCost = evaluateEta(imin, imax, xp, yp, loEta);
			}
		else if (hiCost < loCost)
			while (hiCost < loCost) {
				loEta = hiEta;
				loCost = hiCost;
				hiEta = loEta * factor;
				hiCost = evaluateEta(imin, imax, xp, yp, hiEta);
			}
		eta0 = loEta;

		System.out.println("# Using eta0=" + eta0 + "\n");
	}

	/**
	 * Smoke test: trains on a synthetic 2D two-Gaussian problem (first half
	 * of the data centred at (-2,-2) labelled -1, second half at (+2,+2)
	 * labelled +1) and prints per-epoch training statistics.
	 */
	public static void main(String[] args) {
		final MersenneTwister mt = new MersenneTwister();
		final SparseFloatArray[] tr = new SparseFloatArray[10000];
		final double[] clz = new double[tr.length];
		for (int i = 0; i < tr.length; i++) {
			tr[i] = new SparseHashedFloatArray(2);

			if (i < tr.length / 2) {
				tr[i].set(0, (float) (mt.nextGaussian() - 2));
				tr[i].set(1, (float) (mt.nextGaussian() - 2));
				clz[i] = -1;
			} else {
				tr[i].set(0, (float) (mt.nextGaussian() + 2));
				tr[i].set(1, (float) (mt.nextGaussian() + 2));
				clz[i] = 1;
			}
			// Prints only the first coordinate alongside the label.
			System.out.println(tr[i].values()[0] + " " + clz[i]);
		}

		final SvmSgd svm = new SvmSgd(2, 1e-5);
		svm.BIAS = true;
		svm.REGULARIZED_BIAS = false;
		svm.determineEta0(0, tr.length - 1, tr, clz);
		// 10 epochs over the full training set, reporting after each.
		for (int i = 0; i < 10; i++) {
			System.out.println();
			svm.train(0, tr.length - 1, tr, clz);
			svm.test(0, tr.length - 1, tr, clz, "training ");
			System.out.println(svm.w);
			System.out.println(svm.wBias);
			System.out.println(svm.wDivisor);
		}

	}
}