1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package org.openimaj.workinprogress.featlearn.cifarexps;
31
32 import java.io.IOException;
33 import java.util.List;
34
35 import org.openimaj.image.DisplayUtilities;
36 import org.openimaj.image.MBFImage;
37 import org.openimaj.image.pixel.sampling.RectangleSampler;
38 import org.openimaj.math.geometry.shape.Rectangle;
39 import org.openimaj.math.matrix.algorithm.whitening.WhiteningTransform;
40 import org.openimaj.math.matrix.algorithm.whitening.ZCAWhitening;
41 import org.openimaj.math.statistics.normalisation.Normaliser;
42 import org.openimaj.math.statistics.normalisation.PerExampleMeanCenterVar;
43 import org.openimaj.ml.clustering.kmeans.SphericalKMeans;
44 import org.openimaj.ml.clustering.kmeans.SphericalKMeans.IterationResult;
45 import org.openimaj.ml.clustering.kmeans.SphericalKMeansResult;
46 import org.openimaj.util.function.Operation;
47
48 public class KMeansExp1 extends CIFARExperimentFramework {
49 Normaliser patchNorm = new PerExampleMeanCenterVar(10.0 / 255.0);
50 WhiteningTransform whitening = new ZCAWhitening(0.1, patchNorm);
51 int numCentroids = 1600;
52 int numIters = 10;
53
54 private double[][] dictionary;
55 final RectangleSampler rs = new RectangleSampler(new Rectangle(0, 0, 32, 32), 1, 1, patchSize, patchSize);
56 final List<Rectangle> rectangles = rs.allRectangles();
57
58 @Override
59 protected void learnFeatures(double[][] patches) {
60 whitening.train(patches);
61
62 final double[][] whitenedFeaturePatches = whitening.whiten(patches);
63 final SphericalKMeans skm = new SphericalKMeans(numCentroids, numIters);
64 skm.addIterationListener(new Operation<SphericalKMeans.IterationResult>() {
65 @Override
66 public void perform(IterationResult object) {
67 System.out.println("KMeans iteration " + object.iteration + " / " + numIters);
68 DisplayUtilities.display(drawCentroids(object.result.centroids));
69 }
70 });
71 final SphericalKMeansResult res = skm.cluster(whitenedFeaturePatches);
72 this.dictionary = res.centroids;
73
74 DisplayUtilities.display(drawCentroids(this.dictionary));
75 }
76
77 MBFImage drawCentroids(double[][] centroids) {
78 final int wh = (int) Math.sqrt(numCentroids);
79 final MBFImage tmp = new MBFImage(wh * (patchSize + 1) + 1, wh * (patchSize + 1) + 1);
80 final float mn = -1.0f;
81 final float mx = +1.0f;
82 tmp.fill(new Float[] { mx, mx, mx });
83
84 for (int i = 0, y = 0; y < wh; y++) {
85 for (int x = 0; x < wh; x++, i++) {
86 final MBFImage p = new MBFImage(centroids[i], patchSize, patchSize, 3, false);
87 tmp.drawImage(p, x * (patchSize + 1) + 1, y * (patchSize + 1) + 1);
88 }
89 }
90 tmp.subtractInplace(mn);
91 tmp.divideInplace(mx - mn);
92 return tmp;
93 }
94
95 @Override
96 protected double[] extractFeatures(MBFImage image) {
97 double[][] patches = new double[rectangles.size()][];
98 final MBFImage tmpImage = new MBFImage(this.patchSize, this.patchSize);
99
100 for (int i = 0; i < patches.length; i++) {
101 final Rectangle r = rectangles.get(i);
102 patches[i] = image.extractROI((int) r.x, (int) r.y, tmpImage).getDoublePixelVector();
103 }
104 patches = whitening.whiten(patches);
105 patches = activation(patches);
106
107
108 final double[] feature = pool(patches);
109
110 return feature;
111 }
112
113 private double[] pool(double[][] patches) {
114 final double[] feature = new double[dictionary.length * 4];
115 final int sz = (int) Math.sqrt(patches.length);
116 final int hsz = sz / 2;
117 for (int j = 0; j < sz; j++) {
118 final int by = j < hsz ? 0 : 1;
119 for (int i = 0; i < sz; i++) {
120 final int bx = i < hsz ? 0 : 1;
121
122 final double[] p = patches[j * sz + i];
123 for (int k = 0; k < p.length; k++)
124 feature[2 * dictionary.length * by + dictionary.length * bx + k] += p[k];
125 }
126 }
127 return feature;
128 }
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158 private double[][] activation(double[][] p) {
159 final double[][] c = this.dictionary;
160 final double[][] result = new double[p.length][c.length];
161
162 for (int i = 0; i < p.length; i++) {
163 final double[] x = p[i];
164
165 for (int k = 0; k < c.length; k++) {
166 double dx = 0;
167 for (int j = 0; j < x.length; j++) {
168 dx += c[k][j] * x[j];
169 }
170 result[i][k] = Math.max(0, Math.abs(dx) - 0.5);
171 }
172 }
173
174 return result;
175 }
176
177 public static void main(String[] args) throws IOException {
178 new KMeansExp1().run();
179 }
180 }