001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.workinprogress.featlearn;
031
032import java.io.File;
033import java.io.IOException;
034import java.util.Random;
035
036import org.openimaj.feature.DoubleFV;
037import org.openimaj.feature.FeatureExtractor;
038import org.openimaj.image.DisplayUtilities;
039import org.openimaj.image.FImage;
040import org.openimaj.image.ImageUtilities;
041import org.openimaj.image.processing.resize.ResizeProcessor;
042import org.openimaj.math.matrix.algorithm.whitening.ZCAWhitening;
043import org.openimaj.math.statistics.normalisation.PerExampleMeanCenter;
044import org.openimaj.ml.clustering.kmeans.SphericalKMeans;
045import org.openimaj.ml.clustering.kmeans.SphericalKMeansResult;
046import org.openimaj.util.array.ArrayUtils;
047
048public class TestImageClass implements FeatureExtractor<DoubleFV, FImage> {
049        final Random rng = new Random(0);
050        double[][] featurePatches;
051        FImage[] urbanPatches;
052        FImage[] ruralPatches;
053        int patchSize;
054        int bigPatchSize;
055
056        ZCAWhitening whitening = new ZCAWhitening(0.1, new PerExampleMeanCenter());
057        double[][] dictionary;
058        private double[][] whitenedFeaturePatches;
059
060        void extractFeaturePatches(FImage image, int npatches, int sz) {
061                patchSize = sz;
062                featurePatches = new double[npatches][];
063                for (int i = 0; i < npatches; i++) {
064                        final int x = rng.nextInt(image.width - sz - 1);
065                        final int y = rng.nextInt(image.height - sz - 1);
066
067                        final double[] ip = image.extractROI(x, y, sz, sz).getDoublePixelVector();
068                        featurePatches[i] = ip;
069                }
070        }
071
072        void extractClassifierTrainingPatches(FImage image, FImage labels, int npatchesPerClass, int sz) {
073                bigPatchSize = sz;
074                urbanPatches = new FImage[npatchesPerClass];
075                ruralPatches = new FImage[npatchesPerClass];
076
077                int u = 0;
078                int r = 0;
079
080                while (u < npatchesPerClass || r < npatchesPerClass) {
081                        final int x = rng.nextInt(image.width - sz - 1);
082                        final int y = rng.nextInt(image.height - sz - 1);
083
084                        final FImage ip = image.extractROI(x, y, sz, sz);
085                        final float[] lp = labels.extractROI(x, y, sz, sz).getFloatPixelVector();
086
087                        boolean same = true;
088                        for (int i = 0; i < sz * sz; i++) {
089                                if (lp[i] != lp[0]) {
090                                        same = false;
091                                        break;
092                                }
093                        }
094
095                        if (same) {
096                                if (lp[0] == 0 && r < npatchesPerClass) {
097                                        ruralPatches[r] = ip;
098                                        r++;
099                                } else if (lp[0] == 1 && u < npatchesPerClass) {
100                                        // DisplayUtilities.display(ResizeProcessor.resample(ip,
101                                        // 128, 128).normalise());
102                                        urbanPatches[u] = ip;
103                                        u++;
104                                }
105                        }
106                }
107        }
108
109        void learnDictionary(int dictSize) {
110                whitening.train(featurePatches);
111                whitenedFeaturePatches = whitening.whiten(featurePatches);
112
113                final SphericalKMeans skm = new SphericalKMeans(dictSize, 40);
114                final SphericalKMeansResult res = skm.cluster(whitenedFeaturePatches);
115                this.dictionary = res.centroids;
116        }
117
118        double[] representPatch(double[] patch) {
119                final double[] wp = whitening.whiten(patch);
120
121                final double[] z = new double[dictionary.length];
122                for (int i = 0; i < z.length; i++) {
123                        double accum = 0;
124                        for (int j = 0; j < patch.length; j++) {
125                                accum += wp[j] * dictionary[i][j];
126                        }
127
128                        z[i] = Math.max(0, Math.abs(accum) - 0.5);
129                }
130                return z;
131        }
132
133        @Override
134        public DoubleFV extractFeature(FImage bigpatch) {
135                final double[][][] pfeatures = new double[3][3][dictionary.length];
136                final int[][] pcount = new int[3][3];
137
138                final FImage tmp = new FImage(patchSize, patchSize);
139                for (int y = 0; y < bigPatchSize - patchSize; y++) {
140                        final int yp = (int) ((y / (double) (bigPatchSize - patchSize)) * 3);
141
142                        for (int x = 0; x < bigPatchSize - patchSize; x++) {
143                                final int xp = (int) ((x / (double) (bigPatchSize - patchSize)) * 3);
144
145                                final double[] p = bigpatch.extractROI(x, y, tmp).getDoublePixelVector();
146                                ArrayUtils.sum(pfeatures[yp][xp], representPatch(p));
147                                pcount[yp][xp]++;
148
149                        }
150                }
151
152                final double[] vector = new double[3 * 3 * dictionary.length];
153
154                for (int y = 0; y < 3; y++)
155                        for (int x = 0; x < 3; x++)
156                                for (int i = 0; i < dictionary.length; i++)
157                                        if (pfeatures[y][x][i] > 0)
158                                                vector[3 * x + y * 3 * 3 + i] = pfeatures[y][x][i] / pcount[y][x];
159
160                return new DoubleFV(vector);
161        }
162
163        public static void main(String[] args) throws IOException {
164                final TestImageClass tic = new TestImageClass();
165
166                final FImage trainPhoto = ResizeProcessor.halfSize(ResizeProcessor.halfSize(ImageUtilities.readF(new File(
167                                "/Users/jon/Desktop/images50cm4band/sp7034.jpeg"))));
168                final FImage trainClass = ResizeProcessor.halfSize(ResizeProcessor.halfSize(ImageUtilities.readF(new File(
169                                "/Users/jon/Desktop/images50cm4band/sp7034-classes.PNG"))));
170
171                tic.extractFeaturePatches(trainPhoto, 20000, 8);
172                tic.extractClassifierTrainingPatches(trainPhoto, trainClass, 1000, 32);
173                tic.learnDictionary(100);
174
175                // Note: should really use sparse version!!
176                /*
177                 * final LiblinearAnnotator<FImage, Boolean> ann = new
178                 * LiblinearAnnotator<FImage, Boolean>(tic, Mode.MULTICLASS,
179                 * SolverType.L2R_L2LOSS_SVC, 1, 0.0001);
180                 * 
181                 * final MapBackedDataset<Boolean, ListBackedDataset<FImage>, FImage>
182                 * data = new MapBackedDataset<Boolean, ListBackedDataset<FImage>,
183                 * FImage>(); data.add(true, new
184                 * ListBackedDataset<FImage>(Arrays.asList(tic.ruralPatches)));
185                 * data.add(false, new
186                 * ListBackedDataset<FImage>(Arrays.asList(tic.urbanPatches)));
187                 * ann.train(data);
188                 */
189                final FImage test = ResizeProcessor.halfSize(ResizeProcessor.halfSize(ImageUtilities.readF(new File(
190                                "/Users/jon/Desktop/images50cm4band/test.jpeg")))).normalise();
191
192                /*
193                 * final FImage result = test.extractCenter(test.width - 32, test.height
194                 * - 32); final FImage tmp = new FImage(32, 32); for (int y = 0; y <
195                 * test.height - 32; y++) { for (int x = 0; x < test.width - 32; x++) {
196                 * test.extractROI(x, y, tmp);
197                 * 
198                 * final ClassificationResult<Boolean> r = ann.classify(tmp); final
199                 * Boolean clz = r.getPredictedClasses().iterator().next();
200                 * 
201                 * if (clz) result.pixels[y][x] = 1;
202                 * 
203                 * DisplayUtilities.displayName(result, "result"); } }
204                 */
205
206                final FImage tmp = new FImage(8 * 10, 8 * 10);
207                for (int i = 0, y = 0; y < 10; y++) {
208                        for (int x = 0; x < 10; x++, i++) {
209                                final FImage p = new FImage(tic.dictionary[i], 8, 8);
210                                p.divideInplace(2 * Math.max(p.min(), p.max()));
211                                p.addInplace(0.5f);
212                                tmp.drawImage(p, x * 8, y * 8);
213                        }
214                }
215                DisplayUtilities.display(tmp);
216        }
217}