1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package org.openimaj.workinprogress.featlearn;
31
32 import java.io.File;
33 import java.io.IOException;
34 import java.util.Random;
35
36 import org.openimaj.feature.DoubleFV;
37 import org.openimaj.feature.FeatureExtractor;
38 import org.openimaj.image.DisplayUtilities;
39 import org.openimaj.image.FImage;
40 import org.openimaj.image.ImageUtilities;
41 import org.openimaj.image.processing.resize.ResizeProcessor;
42 import org.openimaj.math.matrix.algorithm.whitening.ZCAWhitening;
43 import org.openimaj.math.statistics.normalisation.PerExampleMeanCenter;
44 import org.openimaj.ml.clustering.kmeans.SphericalKMeans;
45 import org.openimaj.ml.clustering.kmeans.SphericalKMeansResult;
46 import org.openimaj.util.array.ArrayUtils;
47
48 public class TestImageClass implements FeatureExtractor<DoubleFV, FImage> {
49 final Random rng = new Random(0);
50 double[][] featurePatches;
51 FImage[] urbanPatches;
52 FImage[] ruralPatches;
53 int patchSize;
54 int bigPatchSize;
55
56 ZCAWhitening whitening = new ZCAWhitening(0.1, new PerExampleMeanCenter());
57 double[][] dictionary;
58 private double[][] whitenedFeaturePatches;
59
60 void extractFeaturePatches(FImage image, int npatches, int sz) {
61 patchSize = sz;
62 featurePatches = new double[npatches][];
63 for (int i = 0; i < npatches; i++) {
64 final int x = rng.nextInt(image.width - sz - 1);
65 final int y = rng.nextInt(image.height - sz - 1);
66
67 final double[] ip = image.extractROI(x, y, sz, sz).getDoublePixelVector();
68 featurePatches[i] = ip;
69 }
70 }
71
72 void extractClassifierTrainingPatches(FImage image, FImage labels, int npatchesPerClass, int sz) {
73 bigPatchSize = sz;
74 urbanPatches = new FImage[npatchesPerClass];
75 ruralPatches = new FImage[npatchesPerClass];
76
77 int u = 0;
78 int r = 0;
79
80 while (u < npatchesPerClass || r < npatchesPerClass) {
81 final int x = rng.nextInt(image.width - sz - 1);
82 final int y = rng.nextInt(image.height - sz - 1);
83
84 final FImage ip = image.extractROI(x, y, sz, sz);
85 final float[] lp = labels.extractROI(x, y, sz, sz).getFloatPixelVector();
86
87 boolean same = true;
88 for (int i = 0; i < sz * sz; i++) {
89 if (lp[i] != lp[0]) {
90 same = false;
91 break;
92 }
93 }
94
95 if (same) {
96 if (lp[0] == 0 && r < npatchesPerClass) {
97 ruralPatches[r] = ip;
98 r++;
99 } else if (lp[0] == 1 && u < npatchesPerClass) {
100
101
102 urbanPatches[u] = ip;
103 u++;
104 }
105 }
106 }
107 }
108
109 void learnDictionary(int dictSize) {
110 whitening.train(featurePatches);
111 whitenedFeaturePatches = whitening.whiten(featurePatches);
112
113 final SphericalKMeans skm = new SphericalKMeans(dictSize, 40);
114 final SphericalKMeansResult res = skm.cluster(whitenedFeaturePatches);
115 this.dictionary = res.centroids;
116 }
117
118 double[] representPatch(double[] patch) {
119 final double[] wp = whitening.whiten(patch);
120
121 final double[] z = new double[dictionary.length];
122 for (int i = 0; i < z.length; i++) {
123 double accum = 0;
124 for (int j = 0; j < patch.length; j++) {
125 accum += wp[j] * dictionary[i][j];
126 }
127
128 z[i] = Math.max(0, Math.abs(accum) - 0.5);
129 }
130 return z;
131 }
132
133 @Override
134 public DoubleFV extractFeature(FImage bigpatch) {
135 final double[][][] pfeatures = new double[3][3][dictionary.length];
136 final int[][] pcount = new int[3][3];
137
138 final FImage tmp = new FImage(patchSize, patchSize);
139 for (int y = 0; y < bigPatchSize - patchSize; y++) {
140 final int yp = (int) ((y / (double) (bigPatchSize - patchSize)) * 3);
141
142 for (int x = 0; x < bigPatchSize - patchSize; x++) {
143 final int xp = (int) ((x / (double) (bigPatchSize - patchSize)) * 3);
144
145 final double[] p = bigpatch.extractROI(x, y, tmp).getDoublePixelVector();
146 ArrayUtils.sum(pfeatures[yp][xp], representPatch(p));
147 pcount[yp][xp]++;
148
149 }
150 }
151
152 final double[] vector = new double[3 * 3 * dictionary.length];
153
154 for (int y = 0; y < 3; y++)
155 for (int x = 0; x < 3; x++)
156 for (int i = 0; i < dictionary.length; i++)
157 if (pfeatures[y][x][i] > 0)
158 vector[3 * x + y * 3 * 3 + i] = pfeatures[y][x][i] / pcount[y][x];
159
160 return new DoubleFV(vector);
161 }
162
163 public static void main(String[] args) throws IOException {
164 final TestImageClass tic = new TestImageClass();
165
166 final FImage trainPhoto = ResizeProcessor.halfSize(ResizeProcessor.halfSize(ImageUtilities.readF(new File(
167 "/Users/jon/Desktop/images50cm4band/sp7034.jpeg"))));
168 final FImage trainClass = ResizeProcessor.halfSize(ResizeProcessor.halfSize(ImageUtilities.readF(new File(
169 "/Users/jon/Desktop/images50cm4band/sp7034-classes.PNG"))));
170
171 tic.extractFeaturePatches(trainPhoto, 20000, 8);
172 tic.extractClassifierTrainingPatches(trainPhoto, trainClass, 1000, 32);
173 tic.learnDictionary(100);
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189 final FImage test = ResizeProcessor.halfSize(ResizeProcessor.halfSize(ImageUtilities.readF(new File(
190 "/Users/jon/Desktop/images50cm4band/test.jpeg")))).normalise();
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206 final FImage tmp = new FImage(8 * 10, 8 * 10);
207 for (int i = 0, y = 0; y < 10; y++) {
208 for (int x = 0; x < 10; x++, i++) {
209 final FImage p = new FImage(tic.dictionary[i], 8, 8);
210 p.divideInplace(2 * Math.max(p.min(), p.max()));
211 p.addInplace(0.5f);
212 tmp.drawImage(p, x * 8, y * 8);
213 }
214 }
215 DisplayUtilities.display(tmp);
216 }
217 }