001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.workinprogress.sgdsvm; 031 032import java.io.IOException; 033import java.util.ArrayList; 034import java.util.List; 035 036import org.openimaj.time.Timer; 037import org.openimaj.util.array.SparseFloatArray; 038 039import gnu.trove.list.array.TDoubleArrayList; 040 041public class SvmSgdMain { 042 Loss LOSS = LossFunctions.LogLoss; 043 boolean BIAS = true; 044 boolean REGULARIZED_BIAS = false; 045 046 String trainfile = null; 047 String testfile = null; 048 boolean normalize = true; 049 double lambda = 1e-5; 050 int epochs = 5; 051 int maxtrain = -1; 052 053 String NAM(String s) { 054 return String.format("%16s ", s); 055 } 056 057 String DEF(Object v) { 058 return " (default: " + v + ".)"; 059 } 060 061 void usage(String progname) { 062 System.err.println("Usage: " + progname + " [options] trainfile [testfile]"); 063 System.err.println("Options:"); 064 065 System.err.println(NAM("-lambda x") + "Regularization parameter" + DEF(lambda)); 066 System.err.println(NAM("-epochs n") + "Number of training epochs" + DEF(epochs)); 067 System.err.println(NAM("-dontnormalize") + "Do not normalize the L2 norm of patterns."); 068 System.err.println(NAM("-maxtrain n") + "Restrict training set to n examples."); 069 System.exit(10); 070 } 071 072 void 073 parse(String[] args) 074 { 075 for (int i = 0; i < args.length; i++) { 076 final String arg = args[i]; 077 if (arg.charAt(0) != '-') { 078 if (trainfile == null) 079 trainfile = arg; 080 else if (testfile == null) 081 testfile = arg; 082 else 083 usage(this.getClass().getName()); 084 } else { 085 // while (arg.charAt(0) == '-') 086 // arg += 1; 087 final String opt = arg; 088 if (opt == "lambda" && i + 1 < args.length) { 089 lambda = Double.parseDouble(args[++i]); 090 assert (lambda > 0 && lambda < 1e4); 091 } else if (opt == "epochs" && i + 1 < args.length) { 092 epochs = Integer.parseInt(args[++i]); 093 assert (epochs > 0 && epochs < 1e6); 094 } else if (opt == "dontnormalize") { 095 normalize = false; 096 } else if (opt == "maxtrain" && i + 1 < args.length) { 097 maxtrain = Integer.parseInt(args[++i]); 098 assert (maxtrain > 0); 099 } else { 100 System.err.println("Option " + args[i] + " not recognized."); 101 usage(this.getClass().getName()); 102 } 103 } 104 } 105 if (trainfile == null) 106 usage(this.getClass().getName()); 107 } 108 109 void 110 config(String progname) 111 { 112 System.out.print("# Running: " + progname); 113 System.out.print(" -lambda " + lambda); 114 System.out.print(" -epochs " + epochs); 115 if (!normalize) 116 System.out.print(" -dontnormalize"); 117 if (maxtrain > 0) 118 System.out.print(" -maxtrain " + maxtrain); 119 System.out.println(); 120 121 System.out.print( 122 "# Compiled with: " + "-DLOSS=" + LOSS + " -DBIAS=" + BIAS + "-DREGULARIZED_BIAS=" + REGULARIZED_BIAS); 123 } 124 125 // --- main function 126 int[] dims = { 0 }; 127 List<SparseFloatArray> xtrain = new ArrayList<>(); 128 TDoubleArrayList ytrain = new TDoubleArrayList(); 129 List<SparseFloatArray> xtest = new ArrayList<>(); 130 TDoubleArrayList ytest = new TDoubleArrayList(); 131 132 public static void main(String[] args) throws IOException { 133 final SvmSgdMain main = new SvmSgdMain(); 134 main.run(args); 135 } 136 137 void run(String[] args) throws IOException { 138 parse(args); 139 config(this.getClass().getName()); 140 if (trainfile != null) 141 load_datafile(trainfile, xtrain, ytrain, dims, normalize, maxtrain); 142 if (testfile != null) 143 load_datafile(testfile, xtest, ytest, dims, normalize); 144 System.out.println("# Number of features " + dims + "."); 145 146 // prepare svm 147 final int imin = 0; 148 final int imax = xtrain.size() - 1; 149 final int tmin = 0; 150 final int tmax = xtest.size() - 1; 151 152 final SvmSgd svm = new SvmSgd(dims[0], lambda); 153 svm.BIAS = BIAS; 154 svm.LOSS = LOSS; 155 svm.REGULARIZED_BIAS = REGULARIZED_BIAS; 156 157 final Timer timer = new Timer(); 158 // determine eta0 using sample 159 final int smin = 0; 160 final int smax = imin + Math.min(1000, imax); 161 timer.start(); 162 svm.determineEta0(smin, smax, xtrain, ytrain); 163 timer.stop(); 164 // train 165 for (int i = 0; i < epochs; i++) { 166 System.out.println("--------- Epoch " + i + 1 + "."); 167 timer.start(); 168 svm.train(imin, imax, xtrain, ytrain); 169 timer.stop(); 170 System.out.println("Total training time " + timer.duration() / 1000 + "secs."); 171 svm.test(imin, imax, xtrain, ytrain, "train:"); 172 if (tmax >= tmin) 173 svm.test(tmin, tmax, xtest, ytest, "test: "); 174 } 175 } 176 177 private static int load_datafile(String file, List<SparseFloatArray> x, TDoubleArrayList y, int[] dims, 178 boolean normalize) throws IOException 179 { 180 return load_datafile(file, x, y, dims, normalize, -1); 181 } 182 183 private static int load_datafile(String file, List<SparseFloatArray> x, TDoubleArrayList y, int[] dims, 184 boolean normalize, int maxn) throws IOException 185 { 186 final Loader loader = new Loader(file); 187 188 final int[] maxdim = { 0 }; 189 final int[] pcount = { 0 }, ncount = { 0 }; 190 loader.load(x, y, normalize, maxn, maxdim, pcount, ncount); 191 if (pcount[0] + ncount[0] > 0) 192 System.out.println("# Read " + pcount + "+" + ncount 193 + "=" + pcount + ncount + " examples " 194 + "from \"" + file + "\"."); 195 if (dims[0] < maxdim[0]) 196 dims[0] = maxdim[0]; 197 return pcount[0] + ncount[0]; 198 } 199 200}