View Javadoc

1   /**
2    * Copyright (c) 2011, The University of Southampton and the individual contributors.
3    * All rights reserved.
4    *
5    * Redistribution and use in source and binary forms, with or without modification,
6    * are permitted provided that the following conditions are met:
7    *
8    *   * 	Redistributions of source code must retain the above copyright notice,
9    * 	this list of conditions and the following disclaimer.
10   *
11   *   *	Redistributions in binary form must reproduce the above copyright notice,
12   * 	this list of conditions and the following disclaimer in the documentation
13   * 	and/or other materials provided with the distribution.
14   *
15   *   *	Neither the name of the University of Southampton nor the names of its
16   * 	contributors may be used to endorse or promote products derived from this
17   * 	software without specific prior written permission.
18   *
19   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21   * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22   * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
23   * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24   * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26   * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27   * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28   * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29   */
30  package org.openimaj.workinprogress.featlearn;
31  
32  import java.io.File;
33  import java.io.IOException;
34  import java.util.Random;
35  
36  import org.openimaj.feature.DoubleFV;
37  import org.openimaj.feature.FeatureExtractor;
38  import org.openimaj.image.DisplayUtilities;
39  import org.openimaj.image.FImage;
40  import org.openimaj.image.ImageUtilities;
41  import org.openimaj.image.processing.resize.ResizeProcessor;
42  import org.openimaj.math.matrix.algorithm.whitening.ZCAWhitening;
43  import org.openimaj.math.statistics.normalisation.PerExampleMeanCenter;
44  import org.openimaj.ml.clustering.kmeans.SphericalKMeans;
45  import org.openimaj.ml.clustering.kmeans.SphericalKMeansResult;
46  import org.openimaj.util.array.ArrayUtils;
47  
48  public class TestImageClass implements FeatureExtractor<DoubleFV, FImage> {
49  	final Random rng = new Random(0);
50  	double[][] featurePatches;
51  	FImage[] urbanPatches;
52  	FImage[] ruralPatches;
53  	int patchSize;
54  	int bigPatchSize;
55  
56  	ZCAWhitening whitening = new ZCAWhitening(0.1, new PerExampleMeanCenter());
57  	double[][] dictionary;
58  	private double[][] whitenedFeaturePatches;
59  
60  	void extractFeaturePatches(FImage image, int npatches, int sz) {
61  		patchSize = sz;
62  		featurePatches = new double[npatches][];
63  		for (int i = 0; i < npatches; i++) {
64  			final int x = rng.nextInt(image.width - sz - 1);
65  			final int y = rng.nextInt(image.height - sz - 1);
66  
67  			final double[] ip = image.extractROI(x, y, sz, sz).getDoublePixelVector();
68  			featurePatches[i] = ip;
69  		}
70  	}
71  
72  	void extractClassifierTrainingPatches(FImage image, FImage labels, int npatchesPerClass, int sz) {
73  		bigPatchSize = sz;
74  		urbanPatches = new FImage[npatchesPerClass];
75  		ruralPatches = new FImage[npatchesPerClass];
76  
77  		int u = 0;
78  		int r = 0;
79  
80  		while (u < npatchesPerClass || r < npatchesPerClass) {
81  			final int x = rng.nextInt(image.width - sz - 1);
82  			final int y = rng.nextInt(image.height - sz - 1);
83  
84  			final FImage ip = image.extractROI(x, y, sz, sz);
85  			final float[] lp = labels.extractROI(x, y, sz, sz).getFloatPixelVector();
86  
87  			boolean same = true;
88  			for (int i = 0; i < sz * sz; i++) {
89  				if (lp[i] != lp[0]) {
90  					same = false;
91  					break;
92  				}
93  			}
94  
95  			if (same) {
96  				if (lp[0] == 0 && r < npatchesPerClass) {
97  					ruralPatches[r] = ip;
98  					r++;
99  				} else if (lp[0] == 1 && u < npatchesPerClass) {
100 					// DisplayUtilities.display(ResizeProcessor.resample(ip,
101 					// 128, 128).normalise());
102 					urbanPatches[u] = ip;
103 					u++;
104 				}
105 			}
106 		}
107 	}
108 
109 	void learnDictionary(int dictSize) {
110 		whitening.train(featurePatches);
111 		whitenedFeaturePatches = whitening.whiten(featurePatches);
112 
113 		final SphericalKMeans skm = new SphericalKMeans(dictSize, 40);
114 		final SphericalKMeansResult res = skm.cluster(whitenedFeaturePatches);
115 		this.dictionary = res.centroids;
116 	}
117 
118 	double[] representPatch(double[] patch) {
119 		final double[] wp = whitening.whiten(patch);
120 
121 		final double[] z = new double[dictionary.length];
122 		for (int i = 0; i < z.length; i++) {
123 			double accum = 0;
124 			for (int j = 0; j < patch.length; j++) {
125 				accum += wp[j] * dictionary[i][j];
126 			}
127 
128 			z[i] = Math.max(0, Math.abs(accum) - 0.5);
129 		}
130 		return z;
131 	}
132 
133 	@Override
134 	public DoubleFV extractFeature(FImage bigpatch) {
135 		final double[][][] pfeatures = new double[3][3][dictionary.length];
136 		final int[][] pcount = new int[3][3];
137 
138 		final FImage tmp = new FImage(patchSize, patchSize);
139 		for (int y = 0; y < bigPatchSize - patchSize; y++) {
140 			final int yp = (int) ((y / (double) (bigPatchSize - patchSize)) * 3);
141 
142 			for (int x = 0; x < bigPatchSize - patchSize; x++) {
143 				final int xp = (int) ((x / (double) (bigPatchSize - patchSize)) * 3);
144 
145 				final double[] p = bigpatch.extractROI(x, y, tmp).getDoublePixelVector();
146 				ArrayUtils.sum(pfeatures[yp][xp], representPatch(p));
147 				pcount[yp][xp]++;
148 
149 			}
150 		}
151 
152 		final double[] vector = new double[3 * 3 * dictionary.length];
153 
154 		for (int y = 0; y < 3; y++)
155 			for (int x = 0; x < 3; x++)
156 				for (int i = 0; i < dictionary.length; i++)
157 					if (pfeatures[y][x][i] > 0)
158 						vector[3 * x + y * 3 * 3 + i] = pfeatures[y][x][i] / pcount[y][x];
159 
160 		return new DoubleFV(vector);
161 	}
162 
163 	public static void main(String[] args) throws IOException {
164 		final TestImageClass tic = new TestImageClass();
165 
166 		final FImage trainPhoto = ResizeProcessor.halfSize(ResizeProcessor.halfSize(ImageUtilities.readF(new File(
167 				"/Users/jon/Desktop/images50cm4band/sp7034.jpeg"))));
168 		final FImage trainClass = ResizeProcessor.halfSize(ResizeProcessor.halfSize(ImageUtilities.readF(new File(
169 				"/Users/jon/Desktop/images50cm4band/sp7034-classes.PNG"))));
170 
171 		tic.extractFeaturePatches(trainPhoto, 20000, 8);
172 		tic.extractClassifierTrainingPatches(trainPhoto, trainClass, 1000, 32);
173 		tic.learnDictionary(100);
174 
175 		// Note: should really use sparse version!!
176 		/*
177 		 * final LiblinearAnnotator<FImage, Boolean> ann = new
178 		 * LiblinearAnnotator<FImage, Boolean>(tic, Mode.MULTICLASS,
179 		 * SolverType.L2R_L2LOSS_SVC, 1, 0.0001);
180 		 * 
181 		 * final MapBackedDataset<Boolean, ListBackedDataset<FImage>, FImage>
182 		 * data = new MapBackedDataset<Boolean, ListBackedDataset<FImage>,
183 		 * FImage>(); data.add(true, new
184 		 * ListBackedDataset<FImage>(Arrays.asList(tic.ruralPatches)));
185 		 * data.add(false, new
186 		 * ListBackedDataset<FImage>(Arrays.asList(tic.urbanPatches)));
187 		 * ann.train(data);
188 		 */
189 		final FImage test = ResizeProcessor.halfSize(ResizeProcessor.halfSize(ImageUtilities.readF(new File(
190 				"/Users/jon/Desktop/images50cm4band/test.jpeg")))).normalise();
191 
192 		/*
193 		 * final FImage result = test.extractCenter(test.width - 32, test.height
194 		 * - 32); final FImage tmp = new FImage(32, 32); for (int y = 0; y <
195 		 * test.height - 32; y++) { for (int x = 0; x < test.width - 32; x++) {
196 		 * test.extractROI(x, y, tmp);
197 		 * 
198 		 * final ClassificationResult<Boolean> r = ann.classify(tmp); final
199 		 * Boolean clz = r.getPredictedClasses().iterator().next();
200 		 * 
201 		 * if (clz) result.pixels[y][x] = 1;
202 		 * 
203 		 * DisplayUtilities.displayName(result, "result"); } }
204 		 */
205 
206 		final FImage tmp = new FImage(8 * 10, 8 * 10);
207 		for (int i = 0, y = 0; y < 10; y++) {
208 			for (int x = 0; x < 10; x++, i++) {
209 				final FImage p = new FImage(tic.dictionary[i], 8, 8);
210 				p.divideInplace(2 * Math.max(p.min(), p.max()));
211 				p.addInplace(0.5f);
212 				tmp.drawImage(p, x * 8, y * 8);
213 			}
214 		}
215 		DisplayUtilities.display(tmp);
216 	}
217 }