1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package org.openimaj.workinprogress.optimisation;
31
32 import java.util.Random;
33
34 import org.openimaj.data.DataSource;
35 import org.openimaj.data.DoubleArrayBackedDataSource;
36 import org.openimaj.workinprogress.optimisation.params.Parameters;
37 import org.openimaj.workinprogress.optimisation.params.VectorParameters;
38
39 import scala.actors.threadpool.Arrays;
40
41 public class SGD<MODEL, DATATYPE, PTYPE extends Parameters<PTYPE>> {
42 public int maxEpochs = 100;
43 public int batchSize = 1;
44 public LearningRate<PTYPE> learningRate;
45 public MODEL model;
46 public DifferentiableObjectiveFunction<MODEL, DATATYPE, PTYPE> fcn;
47
48 public void train(DataSource<DATATYPE> data) {
49 final DATATYPE[] batch = data.createTemporaryArray(batchSize);
50
51 for (int e = 0; e < maxEpochs; e++) {
52 for (int i = 0; i < data.size(); i += batchSize) {
53 final int currentBatchStop = Math.min(data.size(), i + batchSize);
54 final int currentBatchSize = currentBatchStop - i;
55 data.getData(i, currentBatchStop, batch);
56
57 final PTYPE grads = fcn.derivative(model, batch[0]);
58 for (int j = 1; j < currentBatchSize; j++) {
59 grads.addInplace(fcn.derivative(model, batch[j]));
60 }
61 grads.multiplyInplace(learningRate.getRate(e, i, grads));
62 fcn.updateModel(model, grads);
63 }
64 }
65 }
66
67 public double value(MODEL model, DATATYPE data) {
68 return 0;
69 }
70
71 public static void main(String[] args) {
72 final double[][] data = new double[1000][2];
73 final Random rng = new Random();
74 for (int i = 0; i < data.length; i++) {
75 final double x = rng.nextDouble();
76 data[i][0] = x;
77 data[i][1] = 0.3 * x + 20 + (rng.nextGaussian() * 0.01);
78 }
79 final DoubleArrayBackedDataSource ds = new DoubleArrayBackedDataSource(data);
80
81 final double[] model = { 0, 0 };
82
83 final DifferentiableObjectiveFunction<double[], double[], VectorParameters> fcn = new DifferentiableObjectiveFunction<double[], double[], VectorParameters>()
84 {
85 @Override
86 public double value(double[] model, double[] data) {
87 final double diff = data[1] - (model[0] * data[0] + model[1]);
88 return diff * diff;
89 }
90
91 @Override
92 public VectorParameters derivative(double[] model, double[] data) {
93 final double[] der = {
94 2 * data[0] * (-data[1] + model[0] * data[0] + model[1]),
95 2 * (-data[1] + model[0] * data[0] + model[1])
96 };
97
98 return new VectorParameters(der);
99 }
100
101 @Override
102 public void updateModel(double[] model, VectorParameters weights) {
103 model[0] -= weights.vector[0];
104 model[1] -= weights.vector[1];
105 }
106 };
107
108 final SGD<double[], double[], VectorParameters> sgd = new SGD<double[], double[], VectorParameters>();
109 sgd.model = model;
110 sgd.fcn = fcn;
111 sgd.learningRate = new StaticLearningRate<VectorParameters>(0.01);
112 sgd.batchSize = 1;
113 sgd.maxEpochs = 10;
114
115 sgd.train(ds);
116
117 System.out.println(Arrays.toString(model));
118 }
119 }