View Javadoc

1   /**
2    * Copyright (c) 2011, The University of Southampton and the individual contributors.
3    * All rights reserved.
4    *
5    * Redistribution and use in source and binary forms, with or without modification,
6    * are permitted provided that the following conditions are met:
7    *
8    *   * 	Redistributions of source code must retain the above copyright notice,
9    * 	this list of conditions and the following disclaimer.
10   *
11   *   *	Redistributions in binary form must reproduce the above copyright notice,
12   * 	this list of conditions and the following disclaimer in the documentation
13   * 	and/or other materials provided with the distribution.
14   *
15   *   *	Neither the name of the University of Southampton nor the names of its
16   * 	contributors may be used to endorse or promote products derived from this
17   * 	software without specific prior written permission.
18   *
19   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21   * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22   * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
23   * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24   * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26   * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27   * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28   * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29   */
30  package org.openimaj.workinprogress.sgdsvm;
31  
32  import java.util.List;
33  
34  import org.apache.commons.math.random.MersenneTwister;
35  import org.openimaj.feature.FloatFV;
36  import org.openimaj.feature.FloatFVComparison;
37  import org.openimaj.util.array.ArrayUtils;
38  import org.openimaj.util.array.SparseFloatArray;
39  import org.openimaj.util.array.SparseFloatArray.Entry;
40  import org.openimaj.util.array.SparseHashedFloatArray;
41  
42  import gnu.trove.list.array.TDoubleArrayList;
43  
44  public class SvmSgd implements Cloneable {
45  	Loss LOSS = LossFunctions.HingeLoss;
46  	boolean BIAS = true;
47  	boolean REGULARIZED_BIAS = false;
48  
49  	public double lambda;
50  	public double eta0;
51  	FloatFV w;
52  	double wDivisor;
53  	double wBias;
54  	double t;
55  
56  	public SvmSgd(int dim, double lambda) {
57  		this(dim, lambda, 0);
58  	}
59  
60  	public SvmSgd(int dim, double lambda, double eta0) {
61  		this.lambda = lambda;
62  		this.eta0 = eta0;
63  		this.w = new FloatFV(dim);
64  		this.wDivisor = 1;
65  		this.wBias = 0;
66  		this.t = 0;
67  	}
68  
69  	private double dot(FloatFV v1, SparseFloatArray v2) {
70  		double d = 0;
71  		for (final Entry e : v2.entries()) {
72  			d += e.value * v1.values[e.index];
73  		}
74  
75  		return d;
76  	}
77  
78  	private double dot(FloatFV v1, FloatFV v2) {
79  		return FloatFVComparison.INNER_PRODUCT.compare(v1, v2);
80  	}
81  
82  	private void add(FloatFV y, SparseFloatArray x, double d) {
83  		// w2 = w2 + x*w1
84  
85  		for (final Entry e : x.entries()) {
86  			y.values[e.index] += e.value * d;
87  		}
88  	}
89  
90  	/// Renormalize the weights
91  	public void renorm() {
92  		if (wDivisor != 1.0) {
93  			ArrayUtils.multiply(w.values, (float) (1.0 / wDivisor));
94  			// w.scale(1.0 / wDivisor);
95  			wDivisor = 1.0;
96  		}
97  	}
98  
99  	/// Compute the norm of the weights
100 	public double wnorm() {
101 		double norm = dot(w, w) / wDivisor / wDivisor;
102 
103 		if (REGULARIZED_BIAS)
104 			norm += wBias * wBias;
105 		return norm;
106 	}
107 
108 	/// Compute the output for one example.
109 	public double testOne(final SparseFloatArray x, double y, double[] ploss, double[] pnerr) {
110 		final double s = dot(w, x) / wDivisor + wBias;
111 		if (ploss != null)
112 			ploss[0] += LOSS.loss(s, y);
113 		if (pnerr != null)
114 			pnerr[0] += (s * y <= 0) ? 1 : 0;
115 		return s;
116 	}
117 
118 	/// Perform one iteration of the SGD algorithm with specified gains
119 	public void trainOne(final SparseFloatArray x, double y, double eta) {
120 		final double s = dot(w, x) / wDivisor + wBias;
121 		// update for regularization term
122 		wDivisor = wDivisor / (1 - eta * lambda);
123 		if (wDivisor > 1e5)
124 			renorm();
125 		// update for loss term
126 		final double d = LOSS.dloss(s, y);
127 		if (d != 0)
128 			add(w, x, eta * d * wDivisor);
129 
130 		// same for the bias
131 		if (BIAS) {
132 			final double etab = eta * 0.01;
133 			if (REGULARIZED_BIAS) {
134 				wBias *= (1 - etab * lambda);
135 			}
136 			wBias += etab * d;
137 		}
138 	}
139 
140 	@Override
141 	protected SvmSgd clone() {
142 		SvmSgd clone;
143 		try {
144 			clone = (SvmSgd) super.clone();
145 		} catch (final CloneNotSupportedException e) {
146 			throw new RuntimeException(e);
147 		}
148 		clone.w = clone.w.clone();
149 		return clone;
150 	}
151 
152 	/// Perform a training epoch
153 	public void train(int imin, int imax, SparseFloatArray[] xp, double[] yp) {
154 		System.out.println("Training on [" + imin + ", " + imax + "].");
155 		assert (imin <= imax);
156 		assert (eta0 > 0);
157 		for (int i = imin; i <= imax; i++) {
158 			final double eta = eta0 / (1 + lambda * eta0 * t);
159 			trainOne(xp[i], yp[i], eta);
160 			t += 1;
161 		}
162 		// cout << prefix << setprecision(6) << "wNorm=" << wnorm();
163 		System.out.format("wNorm=%.6f", wnorm());
164 		if (BIAS) {
165 			// cout << " wBias=" << wBias;
166 			System.out.format(" wBias=%.6f", wBias);
167 		}
168 		System.out.println();
169 		// cout << endl;
170 	}
171 
172 	/// Perform a training epoch
173 	public void train(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp) {
174 		System.out.println("Training on [" + imin + ", " + imax + "].");
175 		assert (imin <= imax);
176 		assert (eta0 > 0);
177 		for (int i = imin; i <= imax; i++) {
178 			final double eta = eta0 / (1 + lambda * eta0 * t);
179 			trainOne(xp.get(i), yp.get(i), eta);
180 			t += 1;
181 		}
182 		// cout << prefix << setprecision(6) << "wNorm=" << wnorm();
183 		System.out.format("wNorm=%.6f", wnorm());
184 		if (BIAS) {
185 			// cout << " wBias=" << wBias;
186 			System.out.format(" wBias=%.6f", wBias);
187 		}
188 		System.out.println();
189 		// cout << endl;
190 	}
191 
192 	/// Perform a test pass
193 	public void test(int imin, int imax, SparseFloatArray[] xp, double[] yp, String prefix) {
194 		// cout << prefix << "Testing on [" << imin << ", " << imax << "]." <<
195 		// endl;
196 		System.out.println(prefix + "Testing on [" + imin + ", " + imax + "].");
197 		assert (imin <= imax);
198 		final double nerr[] = { 0 };
199 		final double loss[] = { 0 };
200 		for (int i = imin; i <= imax; i++)
201 			testOne(xp[i], yp[i], loss, nerr);
202 		nerr[0] = nerr[0] / (imax - imin + 1);
203 		loss[0] = loss[0] / (imax - imin + 1);
204 		final double cost = loss[0] + 0.5 * lambda * wnorm();
205 		// cout << prefix
206 		// << "Loss=" << setprecision(12) << loss
207 		// << " Cost=" << setprecision(12) << cost
208 		// << " Misclassification=" << setprecision(4) << 100 * nerr << "%."
209 		// << endl;
210 		System.out.println(prefix + "Loss=" + loss[0] + " Cost=" + cost + " Misclassification="
211 				+ String.format("%2.4f", 100 * nerr[0]) + "%");
212 	}
213 
214 	/// Perform a test pass
215 	public void test(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp, String prefix) {
216 		// cout << prefix << "Testing on [" << imin << ", " << imax << "]." <<
217 		// endl;
218 		System.out.println(prefix + "Testing on [" + imin + ", " + imax + "].");
219 		assert (imin <= imax);
220 		final double nerr[] = { 0 };
221 		final double loss[] = { 0 };
222 		for (int i = imin; i <= imax; i++)
223 			testOne(xp.get(i), yp.get(i), loss, nerr);
224 		nerr[0] = nerr[0] / (imax - imin + 1);
225 		loss[0] = loss[0] / (imax - imin + 1);
226 		final double cost = loss[0] + 0.5 * lambda * wnorm();
227 		// cout << prefix
228 		// << "Loss=" << setprecision(12) << loss
229 		// << " Cost=" << setprecision(12) << cost
230 		// << " Misclassification=" << setprecision(4) << 100 * nerr << "%."
231 		// << endl;
232 		System.out.println(prefix + "Loss=" + loss[0] + " Cost=" + cost + " Misclassification="
233 				+ String.format("%2.4f", 100 * nerr[0]) + "%");
234 	}
235 
236 	/// Perform one epoch with fixed eta and return cost
237 	public double evaluateEta(int imin, int imax, SparseFloatArray[] xp, double[] yp, double eta) {
238 		final SvmSgd clone = this.clone(); // take a copy of the current state
239 		assert (imin <= imax);
240 		for (int i = imin; i <= imax; i++)
241 			clone.trainOne(xp[i], yp[i], eta);
242 		final double loss[] = { 0 };
243 		double cost = 0;
244 		for (int i = imin; i <= imax; i++)
245 			clone.testOne(xp[i], yp[i], loss, null);
246 		loss[0] = loss[0] / (imax - imin + 1);
247 		cost = loss[0] + 0.5 * lambda * clone.wnorm();
248 		// cout << "Trying eta=" << eta << " yields cost " << cost << endl;
249 		System.out.println("Trying eta=" + eta + " yields cost " + cost);
250 		return cost;
251 	}
252 
253 	/// Perform one epoch with fixed eta and return cost
254 	public double evaluateEta(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp, double eta) {
255 		final SvmSgd clone = this.clone(); // take a copy of the current state
256 		assert (imin <= imax);
257 		for (int i = imin; i <= imax; i++)
258 			clone.trainOne(xp.get(i), yp.get(i), eta);
259 		final double loss[] = { 0 };
260 		double cost = 0;
261 		for (int i = imin; i <= imax; i++)
262 			clone.testOne(xp.get(i), yp.get(i), loss, null);
263 		loss[0] = loss[0] / (imax - imin + 1);
264 		cost = loss[0] + 0.5 * lambda * clone.wnorm();
265 		// cout << "Trying eta=" << eta << " yields cost " << cost << endl;
266 		System.out.println("Trying eta=" + eta + " yields cost " + cost);
267 		return cost;
268 	}
269 
270 	public void determineEta0(int imin, int imax, SparseFloatArray[] xp, double[] yp) {
271 		final double factor = 2.0;
272 		double loEta = 1;
273 		double loCost = evaluateEta(imin, imax, xp, yp, loEta);
274 		double hiEta = loEta * factor;
275 		double hiCost = evaluateEta(imin, imax, xp, yp, hiEta);
276 		if (loCost < hiCost)
277 			while (loCost < hiCost) {
278 				hiEta = loEta;
279 				hiCost = loCost;
280 				loEta = hiEta / factor;
281 				loCost = evaluateEta(imin, imax, xp, yp, loEta);
282 			}
283 		else if (hiCost < loCost)
284 			while (hiCost < loCost) {
285 				loEta = hiEta;
286 				loCost = hiCost;
287 				hiEta = loEta * factor;
288 				hiCost = evaluateEta(imin, imax, xp, yp, hiEta);
289 			}
290 		eta0 = loEta;
291 		// cout << "# Using eta0=" << eta0 << endl;
292 		System.out.println("# Using eta0=" + eta0 + "\n");
293 	}
294 
295 	public void determineEta0(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp) {
296 		final double factor = 2.0;
297 		double loEta = 1;
298 		double loCost = evaluateEta(imin, imax, xp, yp, loEta);
299 		double hiEta = loEta * factor;
300 		double hiCost = evaluateEta(imin, imax, xp, yp, hiEta);
301 		if (loCost < hiCost)
302 			while (loCost < hiCost) {
303 				hiEta = loEta;
304 				hiCost = loCost;
305 				loEta = hiEta / factor;
306 				loCost = evaluateEta(imin, imax, xp, yp, loEta);
307 			}
308 		else if (hiCost < loCost)
309 			while (hiCost < loCost) {
310 				loEta = hiEta;
311 				loCost = hiCost;
312 				hiEta = loEta * factor;
313 				hiCost = evaluateEta(imin, imax, xp, yp, hiEta);
314 			}
315 		eta0 = loEta;
316 		// cout << "# Using eta0=" << eta0 << endl;
317 		System.out.println("# Using eta0=" + eta0 + "\n");
318 	}
319 
320 	public static void main(String[] args) {
321 		final MersenneTwister mt = new MersenneTwister();
322 		final SparseFloatArray[] tr = new SparseFloatArray[10000];
323 		final double[] clz = new double[tr.length];
324 		for (int i = 0; i < tr.length; i++) {
325 			tr[i] = new SparseHashedFloatArray(2);
326 
327 			if (i < tr.length / 2) {
328 				tr[i].set(0, (float) (mt.nextGaussian() - 2));
329 				tr[i].set(1, (float) (mt.nextGaussian() - 2));
330 				clz[i] = -1;
331 			} else {
332 				tr[i].set(0, (float) (mt.nextGaussian() + 2));
333 				tr[i].set(1, (float) (mt.nextGaussian() + 2));
334 				clz[i] = 1;
335 			}
336 			System.out.println(tr[i].values()[0] + " " + clz[i]);
337 		}
338 
339 		final SvmSgd svm = new SvmSgd(2, 1e-5);
340 		svm.BIAS = true;
341 		svm.REGULARIZED_BIAS = false;
342 		svm.determineEta0(0, tr.length - 1, tr, clz);
343 		for (int i = 0; i < 10; i++) {
344 			System.out.println();
345 			svm.train(0, tr.length - 1, tr, clz);
346 			svm.test(0, tr.length - 1, tr, clz, "training ");
347 			System.out.println(svm.w);
348 			System.out.println(svm.wBias);
349 			System.out.println(svm.wDivisor);
350 		}
351 
352 		// svm.w.values[0] = 1f;
353 		// svm.w.values[1] = 1f;
354 		// svm.wDivisor = 1;
355 		// svm.wBias = 0;
356 		// svm.test(0, 999, tr, clz, "training ");
357 	}
358 }