/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   *  Redistributions of source code must retain the above copyright notice,
 *      this list of conditions and the following disclaimer.
 *
 *   *  Redistributions in binary form must reproduce the above copyright notice,
 *      this list of conditions and the following disclaimer in the documentation
 *      and/or other materials provided with the distribution.
 *
 *   *  Neither the name of the University of Southampton nor the names of its
 *      contributors may be used to endorse or promote products derived from this
 *      software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.workinprogress.sgdsvm;

import java.util.List;

import org.apache.commons.math.random.MersenneTwister;
import org.openimaj.feature.FloatFV;
import org.openimaj.feature.FloatFVComparison;
import org.openimaj.util.array.ArrayUtils;
import org.openimaj.util.array.SparseFloatArray;
import org.openimaj.util.array.SparseFloatArray.Entry;
import org.openimaj.util.array.SparseHashedFloatArray;

import gnu.trove.list.array.TDoubleArrayList;

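/**
 * A linear SVM trained by stochastic gradient descent, apparently a port of
 * Léon Bottou's C++ <code>svmsgd</code> reference implementation. The loss
 * function defaults to the hinge loss but can be swapped through the
 * <code>LOSS</code> field.
 * <p>
 * The weight vector is represented implicitly as <code>w / wDivisor</code>,
 * so the multiplicative shrink applied by the L2 regularizer on every
 * iteration costs a single scalar update rather than a pass over the whole
 * weight vector. Training uses the decaying learning rate
 * <code>eta_t = eta0 / (1 + lambda * eta0 * t)</code>.
 * <p>
 * A minimal usage sketch, mirroring {@link #main(String[])} (here
 * <code>x</code>, <code>y</code>, <code>n</code> and <code>dim</code> stand
 * for the caller's own data and are not part of this class):
 *
 * <pre>
 * final SvmSgd svm = new SvmSgd(dim, 1e-5);
 * svm.determineEta0(0, n - 1, x, y); // pick the initial learning rate
 * for (int epoch = 0; epoch &lt; 10; epoch++) {
 * 	svm.train(0, n - 1, x, y); // one pass over the data
 * 	svm.test(0, n - 1, x, y, "training ");
 * }
 * </pre>
 */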
public class SvmSgd implements Cloneable {
	Loss LOSS = LossFunctions.HingeLoss;
	boolean BIAS = true;
	boolean REGULARIZED_BIAS = false;

	/** Regularization parameter (weight of the L2 penalty) */
	public double lambda;
	/** Initial learning rate; must be positive (or set via determineEta0) before training */
	public double eta0;
	FloatFV w; // weight vector, stored implicitly as w / wDivisor
	double wDivisor; // scalar divisor applied to w; see trainOne()
	double wBias; // bias term
	double t; // iteration counter driving the learning-rate schedule

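	/**
	 * Construct with the given feature dimensionality and regularization
	 * parameter. eta0 is left at zero and should be set with
	 * {@link #determineEta0(int, int, SparseFloatArray[], double[])} before
	 * training.
	 *
	 * @param dim feature dimensionality
	 * @param lambda regularization parameter
	 */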
	public SvmSgd(int dim, double lambda) {
		this(dim, lambda, 0);
	}

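	/**
	 * Construct with the given feature dimensionality, regularization
	 * parameter and initial learning rate.
	 *
	 * @param dim feature dimensionality
	 * @param lambda regularization parameter
	 * @param eta0 initial learning rate
	 */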
	public SvmSgd(int dim, double lambda, double eta0) {
		this.lambda = lambda;
		this.eta0 = eta0;
		this.w = new FloatFV(dim);
		this.wDivisor = 1;
		this.wBias = 0;
		this.t = 0;
	}

	/** Dot product between a dense and a sparse vector */
	private double dot(FloatFV v1, SparseFloatArray v2) {
		double d = 0;
		for (final Entry e : v2.entries()) {
			d += e.value * v1.values[e.index];
		}

		return d;
	}

	/** Dot product between two dense vectors */
	private double dot(FloatFV v1, FloatFV v2) {
		return FloatFVComparison.INNER_PRODUCT.compare(v1, v2);
	}

	/** Scaled addition: y = y + d*x */
	private void add(FloatFV y, SparseFloatArray x, double d) {
		for (final Entry e : x.entries()) {
			y.values[e.index] += e.value * d;
		}
	}

	/** Renormalize the weights so that wDivisor returns to 1 */
	public void renorm() {
		if (wDivisor != 1.0) {
			ArrayUtils.multiply(w.values, (float) (1.0 / wDivisor));
			wDivisor = 1.0;
		}
	}

	/** Compute the squared norm of the weights: ||w/wDivisor||^2, plus wBias^2 if the bias is regularized */
	public double wnorm() {
		double norm = dot(w, w) / wDivisor / wDivisor;

		if (REGULARIZED_BIAS)
			norm += wBias * wBias;
		return norm;
	}

	/**
	 * Compute the output for one example, accumulating the loss into
	 * ploss[0] and the error count into pnerr[0] when those arrays are
	 * non-null.
	 *
	 * @param x the example
	 * @param y the label (+1/-1)
	 * @param ploss single-element loss accumulator (may be null)
	 * @param pnerr single-element error-count accumulator (may be null)
	 * @return the score w.x + bias
	 */
	public double testOne(final SparseFloatArray x, double y, double[] ploss, double[] pnerr) {
		final double s = dot(w, x) / wDivisor + wBias;
		if (ploss != null)
			ploss[0] += LOSS.loss(s, y);
		if (pnerr != null)
			pnerr[0] += (s * y <= 0) ? 1 : 0;
		return s;
	}

	/**
	 * Perform one iteration of the SGD algorithm with the specified gain
	 * (learning rate) eta.
	 */
	public void trainOne(final SparseFloatArray x, double y, double eta) {
		final double s = dot(w, x) / wDivisor + wBias;

		// update for the regularization term: shrinking w by (1 - eta*lambda)
		// is performed implicitly by growing wDivisor, turning an O(dim)
		// scaling into an O(1) one
		wDivisor = wDivisor / (1 - eta * lambda);
		if (wDivisor > 1e5)
			renorm(); // fold the divisor back into w before it overflows

		// update for the loss term (the gradient step is scaled by wDivisor
		// to compensate for the implicit division)
		final double d = LOSS.dloss(s, y);
		if (d != 0)
			add(w, x, eta * d * wDivisor);

		// same for the bias, with a smaller gain
		if (BIAS) {
			final double etab = eta * 0.01;
			if (REGULARIZED_BIAS) {
				wBias *= (1 - etab * lambda);
			}
			wBias += etab * d;
		}
	}

	@Override
	public SvmSgd clone() {
		SvmSgd clone;
		try {
			clone = (SvmSgd) super.clone();
		} catch (final CloneNotSupportedException e) {
			throw new RuntimeException(e);
		}
		clone.w = clone.w.clone(); // deep-copy the weight vector
		return clone;
	}

	/**
	 * Perform a training epoch over examples imin..imax (inclusive), using
	 * the decaying learning rate eta_t = eta0 / (1 + lambda * eta0 * t).
	 */
	public void train(int imin, int imax, SparseFloatArray[] xp, double[] yp) {
		System.out.println("Training on [" + imin + ", " + imax + "].");
		assert (imin <= imax);
		assert (eta0 > 0);
		for (int i = imin; i <= imax; i++) {
			final double eta = eta0 / (1 + lambda * eta0 * t);
			trainOne(xp[i], yp[i], eta);
			t += 1;
		}
		System.out.format("wNorm=%.6f", wnorm());
		if (BIAS) {
			System.out.format(" wBias=%.6f", wBias);
		}
		System.out.println();
	}

	/**
	 * Perform a training epoch over examples imin..imax (inclusive); see
	 * {@link #train(int, int, SparseFloatArray[], double[])}.
	 */
	public void train(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp) {
		System.out.println("Training on [" + imin + ", " + imax + "].");
		assert (imin <= imax);
		assert (eta0 > 0);
		for (int i = imin; i <= imax; i++) {
			final double eta = eta0 / (1 + lambda * eta0 * t);
			trainOne(xp.get(i), yp.get(i), eta);
			t += 1;
		}
		System.out.format("wNorm=%.6f", wnorm());
		if (BIAS) {
			System.out.format(" wBias=%.6f", wBias);
		}
		System.out.println();
	}

	/**
	 * Perform a test pass over examples imin..imax (inclusive), printing the
	 * mean loss, the regularized cost and the misclassification rate.
	 */
	public void test(int imin, int imax, SparseFloatArray[] xp, double[] yp, String prefix) {
		System.out.println(prefix + "Testing on [" + imin + ", " + imax + "].");
		assert (imin <= imax);
		final double[] nerr = { 0 };
		final double[] loss = { 0 };
		for (int i = imin; i <= imax; i++)
			testOne(xp[i], yp[i], loss, nerr);
		nerr[0] = nerr[0] / (imax - imin + 1);
		loss[0] = loss[0] / (imax - imin + 1);
		final double cost = loss[0] + 0.5 * lambda * wnorm();
		System.out.println(prefix + "Loss=" + loss[0] + " Cost=" + cost + " Misclassification="
				+ String.format("%2.4f", 100 * nerr[0]) + "%");
	}

	/**
	 * Perform a test pass over examples imin..imax (inclusive); see
	 * {@link #test(int, int, SparseFloatArray[], double[], String)}.
	 */
	public void test(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp, String prefix) {
		System.out.println(prefix + "Testing on [" + imin + ", " + imax + "].");
		assert (imin <= imax);
		final double[] nerr = { 0 };
		final double[] loss = { 0 };
		for (int i = imin; i <= imax; i++)
			testOne(xp.get(i), yp.get(i), loss, nerr);
		nerr[0] = nerr[0] / (imax - imin + 1);
		loss[0] = loss[0] / (imax - imin + 1);
		final double cost = loss[0] + 0.5 * lambda * wnorm();
		System.out.println(prefix + "Loss=" + loss[0] + " Cost=" + cost + " Misclassification="
				+ String.format("%2.4f", 100 * nerr[0]) + "%");
	}

	/**
	 * Perform one epoch with fixed eta on a copy of the current state and
	 * return the resulting regularized cost.
	 */
	public double evaluateEta(int imin, int imax, SparseFloatArray[] xp, double[] yp, double eta) {
		final SvmSgd clone = this.clone(); // work on a copy of the current state
		assert (imin <= imax);
		for (int i = imin; i <= imax; i++)
			clone.trainOne(xp[i], yp[i], eta);
		final double[] loss = { 0 };
		for (int i = imin; i <= imax; i++)
			clone.testOne(xp[i], yp[i], loss, null);
		loss[0] = loss[0] / (imax - imin + 1);
		final double cost = loss[0] + 0.5 * lambda * clone.wnorm();
		System.out.println("Trying eta=" + eta + " yields cost " + cost);
		return cost;
	}

	/**
	 * Perform one epoch with fixed eta on a copy of the current state and
	 * return the resulting regularized cost; see
	 * {@link #evaluateEta(int, int, SparseFloatArray[], double[], double)}.
	 */
	public double evaluateEta(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp, double eta) {
		final SvmSgd clone = this.clone(); // work on a copy of the current state
		assert (imin <= imax);
		for (int i = imin; i <= imax; i++)
			clone.trainOne(xp.get(i), yp.get(i), eta);
		final double[] loss = { 0 };
		for (int i = imin; i <= imax; i++)
			clone.testOne(xp.get(i), yp.get(i), loss, null);
		loss[0] = loss[0] / (imax - imin + 1);
		final double cost = loss[0] + 0.5 * lambda * clone.wnorm();
		System.out.println("Trying eta=" + eta + " yields cost " + cost);
		return cost;
	}

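	/**
	 * Determine a good initial learning rate by a doubling/halving bracket
	 * search: starting from eta = 1, the rate is repeatedly halved (or
	 * doubled) by a factor of 2 until the single-epoch cost estimated by
	 * {@link #evaluateEta(int, int, SparseFloatArray[], double[], double)}
	 * stops improving, and the lower end of the final bracket is kept as
	 * eta0.
	 */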
	public void determineEta0(int imin, int imax, SparseFloatArray[] xp, double[] yp) {
		final double factor = 2.0;
		double loEta = 1;
		double loCost = evaluateEta(imin, imax, xp, yp, loEta);
		double hiEta = loEta * factor;
		double hiCost = evaluateEta(imin, imax, xp, yp, hiEta);
		if (loCost < hiCost)
			while (loCost < hiCost) {
				hiEta = loEta;
				hiCost = loCost;
				loEta = hiEta / factor;
				loCost = evaluateEta(imin, imax, xp, yp, loEta);
			}
		else if (hiCost < loCost)
			while (hiCost < loCost) {
				loEta = hiEta;
				loCost = hiCost;
				hiEta = loEta * factor;
				hiCost = evaluateEta(imin, imax, xp, yp, hiEta);
			}
		eta0 = loEta;
		System.out.println("# Using eta0=" + eta0);
	}

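	/**
	 * Determine a good initial learning rate; see
	 * {@link #determineEta0(int, int, SparseFloatArray[], double[])}.
	 */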
	public void determineEta0(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp) {
		final double factor = 2.0;
		double loEta = 1;
		double loCost = evaluateEta(imin, imax, xp, yp, loEta);
		double hiEta = loEta * factor;
		double hiCost = evaluateEta(imin, imax, xp, yp, hiEta);
		if (loCost < hiCost)
			while (loCost < hiCost) {
				hiEta = loEta;
				hiCost = loCost;
				loEta = hiEta / factor;
				loCost = evaluateEta(imin, imax, xp, yp, loEta);
			}
		else if (hiCost < loCost)
			while (hiCost < loCost) {
				loEta = hiEta;
				loCost = hiCost;
				hiEta = loEta * factor;
				hiCost = evaluateEta(imin, imax, xp, yp, hiEta);
			}
		eta0 = loEta;
		System.out.println("# Using eta0=" + eta0);
	}

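	/**
	 * Toy demonstration: trains on 10000 synthetic 2-d points drawn from two
	 * Gaussian clusters centred at (-2, -2) and (+2, +2), labelled -1 and +1
	 * respectively, then reports the training error after each of 10 epochs.
	 */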
	public static void main(String[] args) {
		final MersenneTwister mt = new MersenneTwister();
		final SparseFloatArray[] tr = new SparseFloatArray[10000];
		final double[] clz = new double[tr.length];
		for (int i = 0; i < tr.length; i++) {
			tr[i] = new SparseHashedFloatArray(2);

			// first half of the data: cluster at (-2, -2), label -1;
			// second half: cluster at (+2, +2), label +1
			if (i < tr.length / 2) {
				tr[i].set(0, (float) (mt.nextGaussian() - 2));
				tr[i].set(1, (float) (mt.nextGaussian() - 2));
				clz[i] = -1;
			} else {
				tr[i].set(0, (float) (mt.nextGaussian() + 2));
				tr[i].set(1, (float) (mt.nextGaussian() + 2));
				clz[i] = 1;
			}
			// debug: print the first coordinate of each point and its label
			System.out.println(tr[i].values()[0] + " " + clz[i]);
		}

		final SvmSgd svm = new SvmSgd(2, 1e-5);
		svm.BIAS = true;
		svm.REGULARIZED_BIAS = false;
		svm.determineEta0(0, tr.length - 1, tr, clz);
		for (int i = 0; i < 10; i++) {
			System.out.println();
			svm.train(0, tr.length - 1, tr, clz);
			svm.test(0, tr.length - 1, tr, clz, "training ");
			System.out.println(svm.w);
			System.out.println(svm.wBias);
			System.out.println(svm.wDivisor);
		}

		// svm.w.values[0] = 1f;
		// svm.w.values[1] = 1f;
		// svm.wDivisor = 1;
		// svm.wBias = 0;
		// svm.test(0, 999, tr, clz, "training ");
	}
}