001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.workinprogress.sgdsvm;
031
032import java.io.IOException;
033import java.util.ArrayList;
034import java.util.List;
035
036import org.openimaj.time.Timer;
037import org.openimaj.util.array.SparseFloatArray;
038
039import gnu.trove.list.array.TDoubleArrayList;
040
041public class SvmSgdMain {
042        Loss LOSS = LossFunctions.LogLoss;
043        boolean BIAS = true;
044        boolean REGULARIZED_BIAS = false;
045
046        String trainfile = null;
047        String testfile = null;
048        boolean normalize = true;
049        double lambda = 1e-5;
050        int epochs = 5;
051        int maxtrain = -1;
052
053        String NAM(String s) {
054                return String.format("%16s ", s);
055        }
056
057        String DEF(Object v) {
058                return " (default: " + v + ".)";
059        }
060
061        void usage(String progname) {
062                System.err.println("Usage: " + progname + " [options] trainfile [testfile]");
063                System.err.println("Options:");
064
065                System.err.println(NAM("-lambda x") + "Regularization parameter" + DEF(lambda));
066                System.err.println(NAM("-epochs n") + "Number of training epochs" + DEF(epochs));
067                System.err.println(NAM("-dontnormalize") + "Do not normalize the L2 norm of patterns.");
068                System.err.println(NAM("-maxtrain n") + "Restrict training set to n examples.");
069                System.exit(10);
070        }
071
072        void
073                        parse(String[] args)
074        {
075                for (int i = 0; i < args.length; i++) {
076                        final String arg = args[i];
077                        if (arg.charAt(0) != '-') {
078                                if (trainfile == null)
079                                        trainfile = arg;
080                                else if (testfile == null)
081                                        testfile = arg;
082                                else
083                                        usage(this.getClass().getName());
084                        } else {
085                                // while (arg.charAt(0) == '-')
086                                // arg += 1;
087                                final String opt = arg;
088                                if (opt == "lambda" && i + 1 < args.length) {
089                                        lambda = Double.parseDouble(args[++i]);
090                                        assert (lambda > 0 && lambda < 1e4);
091                                } else if (opt == "epochs" && i + 1 < args.length) {
092                                        epochs = Integer.parseInt(args[++i]);
093                                        assert (epochs > 0 && epochs < 1e6);
094                                } else if (opt == "dontnormalize") {
095                                        normalize = false;
096                                } else if (opt == "maxtrain" && i + 1 < args.length) {
097                                        maxtrain = Integer.parseInt(args[++i]);
098                                        assert (maxtrain > 0);
099                                } else {
100                                        System.err.println("Option " + args[i] + " not recognized.");
101                                        usage(this.getClass().getName());
102                                }
103                        }
104                }
105                if (trainfile == null)
106                        usage(this.getClass().getName());
107        }
108
109        void
110                        config(String progname)
111        {
112                System.out.print("# Running: " + progname);
113                System.out.print(" -lambda " + lambda);
114                System.out.print(" -epochs " + epochs);
115                if (!normalize)
116                        System.out.print(" -dontnormalize");
117                if (maxtrain > 0)
118                        System.out.print(" -maxtrain " + maxtrain);
119                System.out.println();
120
121                System.out.print(
122                                "# Compiled with: " + "-DLOSS=" + LOSS + " -DBIAS=" + BIAS + "-DREGULARIZED_BIAS=" + REGULARIZED_BIAS);
123        }
124
125        // --- main function
126        int[] dims = { 0 };
127        List<SparseFloatArray> xtrain = new ArrayList<>();
128        TDoubleArrayList ytrain = new TDoubleArrayList();
129        List<SparseFloatArray> xtest = new ArrayList<>();
130        TDoubleArrayList ytest = new TDoubleArrayList();
131
132        public static void main(String[] args) throws IOException {
133                final SvmSgdMain main = new SvmSgdMain();
134                main.run(args);
135        }
136
137        void run(String[] args) throws IOException {
138                parse(args);
139                config(this.getClass().getName());
140                if (trainfile != null)
141                        load_datafile(trainfile, xtrain, ytrain, dims, normalize, maxtrain);
142                if (testfile != null)
143                        load_datafile(testfile, xtest, ytest, dims, normalize);
144                System.out.println("# Number of features " + dims + ".");
145
146                // prepare svm
147                final int imin = 0;
148                final int imax = xtrain.size() - 1;
149                final int tmin = 0;
150                final int tmax = xtest.size() - 1;
151
152                final SvmSgd svm = new SvmSgd(dims[0], lambda);
153                svm.BIAS = BIAS;
154                svm.LOSS = LOSS;
155                svm.REGULARIZED_BIAS = REGULARIZED_BIAS;
156
157                final Timer timer = new Timer();
158                // determine eta0 using sample
159                final int smin = 0;
160                final int smax = imin + Math.min(1000, imax);
161                timer.start();
162                svm.determineEta0(smin, smax, xtrain, ytrain);
163                timer.stop();
164                // train
165                for (int i = 0; i < epochs; i++) {
166                        System.out.println("--------- Epoch " + i + 1 + ".");
167                        timer.start();
168                        svm.train(imin, imax, xtrain, ytrain);
169                        timer.stop();
170                        System.out.println("Total training time " + timer.duration() / 1000 + "secs.");
171                        svm.test(imin, imax, xtrain, ytrain, "train:");
172                        if (tmax >= tmin)
173                                svm.test(tmin, tmax, xtest, ytest, "test: ");
174                }
175        }
176
177        private static int load_datafile(String file, List<SparseFloatArray> x, TDoubleArrayList y, int[] dims,
178                        boolean normalize) throws IOException
179        {
180                return load_datafile(file, x, y, dims, normalize, -1);
181        }
182
183        private static int load_datafile(String file, List<SparseFloatArray> x, TDoubleArrayList y, int[] dims,
184                        boolean normalize, int maxn) throws IOException
185        {
186                final Loader loader = new Loader(file);
187
188                final int[] maxdim = { 0 };
189                final int[] pcount = { 0 }, ncount = { 0 };
190                loader.load(x, y, normalize, maxn, maxdim, pcount, ncount);
191                if (pcount[0] + ncount[0] > 0)
192                        System.out.println("# Read " + pcount + "+" + ncount
193                                        + "=" + pcount + ncount + " examples "
194                                        + "from \"" + file + "\".");
195                if (dims[0] < maxdim[0])
196                        dims[0] = maxdim[0];
197                return pcount[0] + ncount[0];
198        }
199
200}