1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package org.openimaj.workinprogress.sgdsvm;
31
32 import java.io.IOException;
33 import java.util.ArrayList;
34 import java.util.List;
35
36 import org.openimaj.time.Timer;
37 import org.openimaj.util.array.SparseFloatArray;
38
39 import gnu.trove.list.array.TDoubleArrayList;
40
41 public class SvmSgdMain {
42 Loss LOSS = LossFunctions.LogLoss;
43 boolean BIAS = true;
44 boolean REGULARIZED_BIAS = false;
45
46 String trainfile = null;
47 String testfile = null;
48 boolean normalize = true;
49 double lambda = 1e-5;
50 int epochs = 5;
51 int maxtrain = -1;
52
53 String NAM(String s) {
54 return String.format("%16s ", s);
55 }
56
57 String DEF(Object v) {
58 return " (default: " + v + ".)";
59 }
60
61 void usage(String progname) {
62 System.err.println("Usage: " + progname + " [options] trainfile [testfile]");
63 System.err.println("Options:");
64
65 System.err.println(NAM("-lambda x") + "Regularization parameter" + DEF(lambda));
66 System.err.println(NAM("-epochs n") + "Number of training epochs" + DEF(epochs));
67 System.err.println(NAM("-dontnormalize") + "Do not normalize the L2 norm of patterns.");
68 System.err.println(NAM("-maxtrain n") + "Restrict training set to n examples.");
69 System.exit(10);
70 }
71
72 void
73 parse(String[] args)
74 {
75 for (int i = 0; i < args.length; i++) {
76 final String arg = args[i];
77 if (arg.charAt(0) != '-') {
78 if (trainfile == null)
79 trainfile = arg;
80 else if (testfile == null)
81 testfile = arg;
82 else
83 usage(this.getClass().getName());
84 } else {
85
86
87 final String opt = arg;
88 if (opt == "lambda" && i + 1 < args.length) {
89 lambda = Double.parseDouble(args[++i]);
90 assert (lambda > 0 && lambda < 1e4);
91 } else if (opt == "epochs" && i + 1 < args.length) {
92 epochs = Integer.parseInt(args[++i]);
93 assert (epochs > 0 && epochs < 1e6);
94 } else if (opt == "dontnormalize") {
95 normalize = false;
96 } else if (opt == "maxtrain" && i + 1 < args.length) {
97 maxtrain = Integer.parseInt(args[++i]);
98 assert (maxtrain > 0);
99 } else {
100 System.err.println("Option " + args[i] + " not recognized.");
101 usage(this.getClass().getName());
102 }
103 }
104 }
105 if (trainfile == null)
106 usage(this.getClass().getName());
107 }
108
109 void
110 config(String progname)
111 {
112 System.out.print("# Running: " + progname);
113 System.out.print(" -lambda " + lambda);
114 System.out.print(" -epochs " + epochs);
115 if (!normalize)
116 System.out.print(" -dontnormalize");
117 if (maxtrain > 0)
118 System.out.print(" -maxtrain " + maxtrain);
119 System.out.println();
120
121 System.out.print(
122 "# Compiled with: " + "-DLOSS=" + LOSS + " -DBIAS=" + BIAS + "-DREGULARIZED_BIAS=" + REGULARIZED_BIAS);
123 }
124
125
126 int[] dims = { 0 };
127 List<SparseFloatArray> xtrain = new ArrayList<>();
128 TDoubleArrayList ytrain = new TDoubleArrayList();
129 List<SparseFloatArray> xtest = new ArrayList<>();
130 TDoubleArrayList ytest = new TDoubleArrayList();
131
132 public static void main(String[] args) throws IOException {
133 final SvmSgdMain main = new SvmSgdMain();
134 main.run(args);
135 }
136
137 void run(String[] args) throws IOException {
138 parse(args);
139 config(this.getClass().getName());
140 if (trainfile != null)
141 load_datafile(trainfile, xtrain, ytrain, dims, normalize, maxtrain);
142 if (testfile != null)
143 load_datafile(testfile, xtest, ytest, dims, normalize);
144 System.out.println("# Number of features " + dims + ".");
145
146
147 final int imin = 0;
148 final int imax = xtrain.size() - 1;
149 final int tmin = 0;
150 final int tmax = xtest.size() - 1;
151
152 final SvmSgd svm = new SvmSgd(dims[0], lambda);
153 svm.BIAS = BIAS;
154 svm.LOSS = LOSS;
155 svm.REGULARIZED_BIAS = REGULARIZED_BIAS;
156
157 final Timer timer = new Timer();
158
159 final int smin = 0;
160 final int smax = imin + Math.min(1000, imax);
161 timer.start();
162 svm.determineEta0(smin, smax, xtrain, ytrain);
163 timer.stop();
164
165 for (int i = 0; i < epochs; i++) {
166 System.out.println("--------- Epoch " + i + 1 + ".");
167 timer.start();
168 svm.train(imin, imax, xtrain, ytrain);
169 timer.stop();
170 System.out.println("Total training time " + timer.duration() / 1000 + "secs.");
171 svm.test(imin, imax, xtrain, ytrain, "train:");
172 if (tmax >= tmin)
173 svm.test(tmin, tmax, xtest, ytest, "test: ");
174 }
175 }
176
177 private static int load_datafile(String file, List<SparseFloatArray> x, TDoubleArrayList y, int[] dims,
178 boolean normalize) throws IOException
179 {
180 return load_datafile(file, x, y, dims, normalize, -1);
181 }
182
183 private static int load_datafile(String file, List<SparseFloatArray> x, TDoubleArrayList y, int[] dims,
184 boolean normalize, int maxn) throws IOException
185 {
186 final Loader loader = new Loader(file);
187
188 final int[] maxdim = { 0 };
189 final int[] pcount = { 0 }, ncount = { 0 };
190 loader.load(x, y, normalize, maxn, maxdim, pcount, ncount);
191 if (pcount[0] + ncount[0] > 0)
192 System.out.println("# Read " + pcount + "+" + ncount
193 + "=" + pcount + ncount + " examples "
194 + "from \"" + file + "\".");
195 if (dims[0] < maxdim[0])
196 dims[0] = maxdim[0];
197 return pcount[0] + ncount[0];
198 }
199
200 }