View Javadoc

1   /**
2    * Copyright (c) 2011, The University of Southampton and the individual contributors.
3    * All rights reserved.
4    *
5    * Redistribution and use in source and binary forms, with or without modification,
6    * are permitted provided that the following conditions are met:
7    *
8    *   * 	Redistributions of source code must retain the above copyright notice,
9    * 	this list of conditions and the following disclaimer.
10   *
11   *   *	Redistributions in binary form must reproduce the above copyright notice,
12   * 	this list of conditions and the following disclaimer in the documentation
13   * 	and/or other materials provided with the distribution.
14   *
15   *   *	Neither the name of the University of Southampton nor the names of its
16   * 	contributors may be used to endorse or promote products derived from this
17   * 	software without specific prior written permission.
18   *
19   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21   * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22   * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
23   * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24   * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26   * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27   * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28   * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29   */
30  package org.openimaj.workinprogress.sgdsvm;
31  
32  import java.io.IOException;
33  import java.util.ArrayList;
34  import java.util.List;
35  
36  import org.openimaj.time.Timer;
37  import org.openimaj.util.array.SparseFloatArray;
38  
39  import gnu.trove.list.array.TDoubleArrayList;
40  
41  public class SvmSgdMain {
42  	Loss LOSS = LossFunctions.LogLoss;
43  	boolean BIAS = true;
44  	boolean REGULARIZED_BIAS = false;
45  
46  	String trainfile = null;
47  	String testfile = null;
48  	boolean normalize = true;
49  	double lambda = 1e-5;
50  	int epochs = 5;
51  	int maxtrain = -1;
52  
53  	String NAM(String s) {
54  		return String.format("%16s ", s);
55  	}
56  
57  	String DEF(Object v) {
58  		return " (default: " + v + ".)";
59  	}
60  
61  	void usage(String progname) {
62  		System.err.println("Usage: " + progname + " [options] trainfile [testfile]");
63  		System.err.println("Options:");
64  
65  		System.err.println(NAM("-lambda x") + "Regularization parameter" + DEF(lambda));
66  		System.err.println(NAM("-epochs n") + "Number of training epochs" + DEF(epochs));
67  		System.err.println(NAM("-dontnormalize") + "Do not normalize the L2 norm of patterns.");
68  		System.err.println(NAM("-maxtrain n") + "Restrict training set to n examples.");
69  		System.exit(10);
70  	}
71  
72  	void
73  			parse(String[] args)
74  	{
75  		for (int i = 0; i < args.length; i++) {
76  			final String arg = args[i];
77  			if (arg.charAt(0) != '-') {
78  				if (trainfile == null)
79  					trainfile = arg;
80  				else if (testfile == null)
81  					testfile = arg;
82  				else
83  					usage(this.getClass().getName());
84  			} else {
85  				// while (arg.charAt(0) == '-')
86  				// arg += 1;
87  				final String opt = arg;
88  				if (opt == "lambda" && i + 1 < args.length) {
89  					lambda = Double.parseDouble(args[++i]);
90  					assert (lambda > 0 && lambda < 1e4);
91  				} else if (opt == "epochs" && i + 1 < args.length) {
92  					epochs = Integer.parseInt(args[++i]);
93  					assert (epochs > 0 && epochs < 1e6);
94  				} else if (opt == "dontnormalize") {
95  					normalize = false;
96  				} else if (opt == "maxtrain" && i + 1 < args.length) {
97  					maxtrain = Integer.parseInt(args[++i]);
98  					assert (maxtrain > 0);
99  				} else {
100 					System.err.println("Option " + args[i] + " not recognized.");
101 					usage(this.getClass().getName());
102 				}
103 			}
104 		}
105 		if (trainfile == null)
106 			usage(this.getClass().getName());
107 	}
108 
109 	void
110 			config(String progname)
111 	{
112 		System.out.print("# Running: " + progname);
113 		System.out.print(" -lambda " + lambda);
114 		System.out.print(" -epochs " + epochs);
115 		if (!normalize)
116 			System.out.print(" -dontnormalize");
117 		if (maxtrain > 0)
118 			System.out.print(" -maxtrain " + maxtrain);
119 		System.out.println();
120 
121 		System.out.print(
122 				"# Compiled with: " + "-DLOSS=" + LOSS + " -DBIAS=" + BIAS + "-DREGULARIZED_BIAS=" + REGULARIZED_BIAS);
123 	}
124 
125 	// --- main function
126 	int[] dims = { 0 };
127 	List<SparseFloatArray> xtrain = new ArrayList<>();
128 	TDoubleArrayList ytrain = new TDoubleArrayList();
129 	List<SparseFloatArray> xtest = new ArrayList<>();
130 	TDoubleArrayList ytest = new TDoubleArrayList();
131 
132 	public static void main(String[] args) throws IOException {
133 		final SvmSgdMain main = new SvmSgdMain();
134 		main.run(args);
135 	}
136 
137 	void run(String[] args) throws IOException {
138 		parse(args);
139 		config(this.getClass().getName());
140 		if (trainfile != null)
141 			load_datafile(trainfile, xtrain, ytrain, dims, normalize, maxtrain);
142 		if (testfile != null)
143 			load_datafile(testfile, xtest, ytest, dims, normalize);
144 		System.out.println("# Number of features " + dims + ".");
145 
146 		// prepare svm
147 		final int imin = 0;
148 		final int imax = xtrain.size() - 1;
149 		final int tmin = 0;
150 		final int tmax = xtest.size() - 1;
151 
152 		final SvmSgd svm = new SvmSgd(dims[0], lambda);
153 		svm.BIAS = BIAS;
154 		svm.LOSS = LOSS;
155 		svm.REGULARIZED_BIAS = REGULARIZED_BIAS;
156 
157 		final Timer timer = new Timer();
158 		// determine eta0 using sample
159 		final int smin = 0;
160 		final int smax = imin + Math.min(1000, imax);
161 		timer.start();
162 		svm.determineEta0(smin, smax, xtrain, ytrain);
163 		timer.stop();
164 		// train
165 		for (int i = 0; i < epochs; i++) {
166 			System.out.println("--------- Epoch " + i + 1 + ".");
167 			timer.start();
168 			svm.train(imin, imax, xtrain, ytrain);
169 			timer.stop();
170 			System.out.println("Total training time " + timer.duration() / 1000 + "secs.");
171 			svm.test(imin, imax, xtrain, ytrain, "train:");
172 			if (tmax >= tmin)
173 				svm.test(tmin, tmax, xtest, ytest, "test: ");
174 		}
175 	}
176 
177 	private static int load_datafile(String file, List<SparseFloatArray> x, TDoubleArrayList y, int[] dims,
178 			boolean normalize) throws IOException
179 	{
180 		return load_datafile(file, x, y, dims, normalize, -1);
181 	}
182 
183 	private static int load_datafile(String file, List<SparseFloatArray> x, TDoubleArrayList y, int[] dims,
184 			boolean normalize, int maxn) throws IOException
185 	{
186 		final Loader loader = new Loader(file);
187 
188 		final int[] maxdim = { 0 };
189 		final int[] pcount = { 0 }, ncount = { 0 };
190 		loader.load(x, y, normalize, maxn, maxdim, pcount, ncount);
191 		if (pcount[0] + ncount[0] > 0)
192 			System.out.println("# Read " + pcount + "+" + ncount
193 					+ "=" + pcount + ncount + " examples "
194 					+ "from \"" + file + "\".");
195 		if (dims[0] < maxdim[0])
196 			dims[0] = maxdim[0];
197 		return pcount[0] + ncount[0];
198 	}
199 
200 }