/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.workinprogress.sgdsvm;

import java.util.List;

import org.apache.commons.math.random.MersenneTwister;
import org.openimaj.feature.FloatFV;
import org.openimaj.feature.FloatFVComparison;
import org.openimaj.util.array.ArrayUtils;
import org.openimaj.util.array.SparseFloatArray;
import org.openimaj.util.array.SparseFloatArray.Entry;
import org.openimaj.util.array.SparseHashedFloatArray;

import gnu.trove.list.array.TDoubleArrayList;

/**
 * Stochastic gradient descent trainer for a regularised linear SVM. The
 * structure of the training loop closely follows Léon Bottou's svmsgd C++
 * reference code, transcribed to Java.
 */
public class SvmSgd implements Cloneable {
    Loss LOSS = LossFunctions.HingeLoss;
    boolean BIAS = true;
    boolean REGULARIZED_BIAS = false;

    public double lambda;
    public double eta0;
    FloatFV w;
    double wDivisor;
    double wBias;
    double t;

    public SvmSgd(int dim, double lambda) {
        this(dim, lambda, 0);
    }

    public SvmSgd(int dim, double lambda, double eta0) {
        this.lambda = lambda;
        this.eta0 = eta0;
        this.w = new FloatFV(dim);
        this.wDivisor = 1;
        this.wBias = 0;
        this.t = 0;
    }

    private double dot(FloatFV v1, SparseFloatArray v2) {
        double d = 0;
        for (final Entry e : v2.entries()) {
            d += e.value * v1.values[e.index];
        }

        return d;
    }

    private double dot(FloatFV v1, FloatFV v2) {
        return FloatFVComparison.INNER_PRODUCT.compare(v1, v2);
    }

    private void add(FloatFV y, SparseFloatArray x, double d) {
        // y = y + x * d
        for (final Entry e : x.entries()) {
            y.values[e.index] += e.value * d;
        }
    }

    /** Renormalize the weights: fold wDivisor back into w. */
    public void renorm() {
        if (wDivisor != 1.0) {
            ArrayUtils.multiply(w.values, (float) (1.0 / wDivisor));
            wDivisor = 1.0;
        }
    }
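    /*
     * Implementation note (added for clarity): the effective weight vector is
     * w / wDivisor. The SGD regularisation step multiplies the effective
     * weights by (1 - eta * lambda); rather than scaling every component of w,
     * trainOne() divides wDivisor by that factor, since
     *
     *     (w / wDivisor) * (1 - eta * lambda) = w / (wDivisor / (1 - eta * lambda)).
     *
     * This makes the shrink an O(1) operation regardless of dimension;
     * renorm() folds the divisor back into w before it grows large enough to
     * cause floating-point trouble (the 1e5 threshold in trainOne()).
     */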
    /** Compute the squared norm of the (effective) weights. */
    public double wnorm() {
        double norm = dot(w, w) / wDivisor / wDivisor;

        if (REGULARIZED_BIAS)
            norm += wBias * wBias;
        return norm;
    }

    /** Compute the output for one example, accumulating loss and error counts. */
    public double testOne(final SparseFloatArray x, double y, double[] ploss, double[] pnerr) {
        final double s = dot(w, x) / wDivisor + wBias;
        if (ploss != null)
            ploss[0] += LOSS.loss(s, y);
        if (pnerr != null)
            pnerr[0] += (s * y <= 0) ? 1 : 0;
        return s;
    }

    /** Perform one iteration of the SGD algorithm with the specified gain. */
    public void trainOne(final SparseFloatArray x, double y, double eta) {
        final double s = dot(w, x) / wDivisor + wBias;

        // update for the regularization term
        wDivisor = wDivisor / (1 - eta * lambda);
        if (wDivisor > 1e5)
            renorm();

        // update for the loss term
        final double d = LOSS.dloss(s, y);
        if (d != 0)
            add(w, x, eta * d * wDivisor);

        // same for the bias, with a smaller gain
        if (BIAS) {
            final double etab = eta * 0.01;
            if (REGULARIZED_BIAS) {
                wBias *= (1 - etab * lambda);
            }
            wBias += etab * d;
        }
    }

    @Override
    protected SvmSgd clone() {
        SvmSgd clone;
        try {
            clone = (SvmSgd) super.clone();
        } catch (final CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
        clone.w = clone.w.clone();
        return clone;
    }

    /** Perform a training epoch. */
    public void train(int imin, int imax, SparseFloatArray[] xp, double[] yp) {
        System.out.println("Training on [" + imin + ", " + imax + "].");
        assert (imin <= imax);
        assert (eta0 > 0);
        for (int i = imin; i <= imax; i++) {
            final double eta = eta0 / (1 + lambda * eta0 * t);
            trainOne(xp[i], yp[i], eta);
            t += 1;
        }
        System.out.format("wNorm=%.6f", wnorm());
        if (BIAS) {
            System.out.format(" wBias=%.6f", wBias);
        }
        System.out.println();
    }

    /** Perform a training epoch. */
    public void train(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp) {
        System.out.println("Training on [" + imin + ", " + imax + "].");
        assert (imin <= imax);
        assert (eta0 > 0);
        for (int i = imin; i <= imax; i++) {
            final double eta = eta0 / (1 + lambda * eta0 * t);
            trainOne(xp.get(i), yp.get(i), eta);
            t += 1;
        }
        System.out.format("wNorm=%.6f", wnorm());
        if (BIAS) {
            System.out.format(" wBias=%.6f", wBias);
        }
        System.out.println();
    }
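    /*
     * Note on the gain used by the train() methods above: the learning rate
     * decays as
     *
     *     eta_t = eta0 / (1 + lambda * eta0 * t),
     *
     * so it starts at eta0 (calibrated by determineEta0()) and behaves like
     * 1 / (lambda * t) for large t -- the standard SGD schedule for a
     * lambda-strongly-convex SVM objective.
     */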
    /** Perform a test pass. */
    public void test(int imin, int imax, SparseFloatArray[] xp, double[] yp, String prefix) {
        System.out.println(prefix + "Testing on [" + imin + ", " + imax + "].");
        assert (imin <= imax);
        final double[] nerr = { 0 };
        final double[] loss = { 0 };
        for (int i = imin; i <= imax; i++)
            testOne(xp[i], yp[i], loss, nerr);
        nerr[0] = nerr[0] / (imax - imin + 1);
        loss[0] = loss[0] / (imax - imin + 1);
        final double cost = loss[0] + 0.5 * lambda * wnorm();
        System.out.println(prefix + "Loss=" + loss[0] + " Cost=" + cost + " Misclassification="
                + String.format("%2.4f", 100 * nerr[0]) + "%");
    }

    /** Perform a test pass. */
    public void test(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp, String prefix) {
        System.out.println(prefix + "Testing on [" + imin + ", " + imax + "].");
        assert (imin <= imax);
        final double[] nerr = { 0 };
        final double[] loss = { 0 };
        for (int i = imin; i <= imax; i++)
            testOne(xp.get(i), yp.get(i), loss, nerr);
        nerr[0] = nerr[0] / (imax - imin + 1);
        loss[0] = loss[0] / (imax - imin + 1);
        final double cost = loss[0] + 0.5 * lambda * wnorm();
        System.out.println(prefix + "Loss=" + loss[0] + " Cost=" + cost + " Misclassification="
                + String.format("%2.4f", 100 * nerr[0]) + "%");
    }

    /** Perform one epoch with fixed eta on a copy of the model and return the cost. */
    public double evaluateEta(int imin, int imax, SparseFloatArray[] xp, double[] yp, double eta) {
        final SvmSgd clone = this.clone(); // take a copy of the current state
        assert (imin <= imax);
        for (int i = imin; i <= imax; i++)
            clone.trainOne(xp[i], yp[i], eta);
        final double[] loss = { 0 };
        double cost = 0;
        for (int i = imin; i <= imax; i++)
            clone.testOne(xp[i], yp[i], loss, null);
        loss[0] = loss[0] / (imax - imin + 1);
        cost = loss[0] + 0.5 * lambda * clone.wnorm();
        System.out.println("Trying eta=" + eta + " yields cost " + cost);
        return cost;
    }

    /** Perform one epoch with fixed eta on a copy of the model and return the cost. */
    public double evaluateEta(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp, double eta) {
        final SvmSgd clone = this.clone(); // take a copy of the current state
        assert (imin <= imax);
        for (int i = imin; i <= imax; i++)
            clone.trainOne(xp.get(i), yp.get(i), eta);
        final double[] loss = { 0 };
        double cost = 0;
        for (int i = imin; i <= imax; i++)
            clone.testOne(xp.get(i), yp.get(i), loss, null);
        loss[0] = loss[0] / (imax - imin + 1);
        cost = loss[0] + 0.5 * lambda * clone.wnorm();
        System.out.println("Trying eta=" + eta + " yields cost " + cost);
        return cost;
    }
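    /*
     * The determineEta0() overloads below calibrate the initial gain with a
     * geometric bracketing search: starting from eta = 1, they compare the
     * one-epoch cost (computed by evaluateEta() on a throwaway clone) at eta
     * and 2 * eta, walk by factors of 2 in whichever direction lowers the
     * cost, and keep the smaller eta of the final bracket as a conservative
     * choice of eta0.
     */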
    public void determineEta0(int imin, int imax, SparseFloatArray[] xp, double[] yp) {
        final double factor = 2.0;
        double loEta = 1;
        double loCost = evaluateEta(imin, imax, xp, yp, loEta);
        double hiEta = loEta * factor;
        double hiCost = evaluateEta(imin, imax, xp, yp, hiEta);
        if (loCost < hiCost) {
            while (loCost < hiCost) {
                hiEta = loEta;
                hiCost = loCost;
                loEta = hiEta / factor;
                loCost = evaluateEta(imin, imax, xp, yp, loEta);
            }
        } else if (hiCost < loCost) {
            while (hiCost < loCost) {
                loEta = hiEta;
                loCost = hiCost;
                hiEta = loEta * factor;
                hiCost = evaluateEta(imin, imax, xp, yp, hiEta);
            }
        }
        eta0 = loEta;
        System.out.println("# Using eta0=" + eta0 + "\n");
    }

    public void determineEta0(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp) {
        final double factor = 2.0;
        double loEta = 1;
        double loCost = evaluateEta(imin, imax, xp, yp, loEta);
        double hiEta = loEta * factor;
        double hiCost = evaluateEta(imin, imax, xp, yp, hiEta);
        if (loCost < hiCost) {
            while (loCost < hiCost) {
                hiEta = loEta;
                hiCost = loCost;
                loEta = hiEta / factor;
                loCost = evaluateEta(imin, imax, xp, yp, loEta);
            }
        } else if (hiCost < loCost) {
            while (hiCost < loCost) {
                loEta = hiEta;
                loCost = hiCost;
                hiEta = loEta * factor;
                hiCost = evaluateEta(imin, imax, xp, yp, hiEta);
            }
        }
        eta0 = loEta;
        System.out.println("# Using eta0=" + eta0 + "\n");
    }

    /**
     * Demo: train on two 2d Gaussian blobs (class -1 centred at (-2, -2),
     * class +1 centred at (2, 2)) and report the training error.
     */
    public static void main(String[] args) {
        final MersenneTwister mt = new MersenneTwister();
        final SparseFloatArray[] tr = new SparseFloatArray[10000];
        final double[] clz = new double[tr.length];
        for (int i = 0; i < tr.length; i++) {
            tr[i] = new SparseHashedFloatArray(2);

            if (i < tr.length / 2) {
                tr[i].set(0, (float) (mt.nextGaussian() - 2));
                tr[i].set(1, (float) (mt.nextGaussian() - 2));
                clz[i] = -1;
            } else {
                tr[i].set(0, (float) (mt.nextGaussian() + 2));
                tr[i].set(1, (float) (mt.nextGaussian() + 2));
                clz[i] = 1;
            }
            System.out.println(tr[i].values()[0] + " " + clz[i]);
        }

        final SvmSgd svm = new SvmSgd(2, 1e-5);
        svm.BIAS = true;
        svm.REGULARIZED_BIAS = false;
        svm.determineEta0(0, tr.length - 1, tr, clz);
        for (int i = 0; i < 10; i++) {
            System.out.println();
            svm.train(0, tr.length - 1, tr, clz);
            svm.test(0, tr.length - 1, tr, clz, "training ");
            System.out.println(svm.w);
            System.out.println(svm.wBias);
            System.out.println(svm.wDivisor);
        }

        // svm.w.values[0] = 1f;
        // svm.w.values[1] = 1f;
        // svm.wDivisor = 1;
        // svm.wBias = 0;
        // svm.test(0, 999, tr, clz, "training ");
    }
}