1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package org.openimaj.workinprogress.featlearn;
31
32 import java.io.IOException;
33 import java.util.List;
34
35 import org.openimaj.image.DisplayUtilities;
36 import org.openimaj.image.MBFImage;
37 import org.openimaj.image.annotation.evaluation.datasets.CIFAR10Dataset;
38 import org.openimaj.image.colour.RGBColour;
39 import org.openimaj.math.matrix.algorithm.whitening.ZCAWhitening;
40 import org.openimaj.math.statistics.normalisation.PerExampleMeanCenter;
41 import org.openimaj.ml.clustering.kmeans.SphericalKMeans;
42 import org.openimaj.ml.clustering.kmeans.SphericalKMeansResult;
43
44 public class Test2 {
45 public static void main(String[] args) throws IOException {
46 System.out.println("start");
47 final RandomPatchSampler<MBFImage> sampler = new RandomPatchSampler<MBFImage>(
48 CIFAR10Dataset.getTrainingImages(CIFAR10Dataset.MBFIMAGE_READER),
49 8, 8, 400000);
50 final List<MBFImage> patches = sampler.getPatches();
51 System.out.println("stop");
52
53 final double[][] data = new double[patches.size()][];
54 for (int i = 0; i < data.length; i++)
55 data[i] = patches.get(i).getDoublePixelVector();
56
57
58 final ZCAWhitening whitening = new ZCAWhitening(0.1, new PerExampleMeanCenter());
59 whitening.train(data);
60 final double[][] wd = whitening.whiten(data);
61
62 final SphericalKMeans skm = new SphericalKMeans(1600, 10);
63 final SphericalKMeansResult res = skm.cluster(wd);
64 final MBFImage tmp = new MBFImage(40 * (8 + 1) + 1, 40 * (8 + 1) + 1);
65 tmp.fill(RGBColour.WHITE);
66 for (int i = 0; i < 40; i++) {
67 for (int j = 0; j < 40; j++) {
68 final MBFImage patch = new MBFImage(res.centroids[i * 40 + j], 8, 8, 3, false);
69 tmp.drawImage(patch, i * (8 + 1) + 1, j * (8 + 1) + 1);
70 }
71 }
72 tmp.subtractInplace(-1.5f);
73 tmp.divideInplace(3f);
74 DisplayUtilities.display(tmp);
75 }
76 }