001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.workinprogress.featlearn.cifarexps;
031
032import java.io.IOException;
033import java.util.List;
034
035import org.openimaj.image.DisplayUtilities;
036import org.openimaj.image.MBFImage;
037import org.openimaj.image.pixel.sampling.RectangleSampler;
038import org.openimaj.math.geometry.shape.Rectangle;
039import org.openimaj.math.matrix.algorithm.whitening.WhiteningTransform;
040import org.openimaj.math.matrix.algorithm.whitening.ZCAWhitening;
041import org.openimaj.math.statistics.normalisation.Normaliser;
042import org.openimaj.math.statistics.normalisation.PerExampleMeanCenterVar;
043import org.openimaj.ml.clustering.kmeans.SphericalKMeans;
044import org.openimaj.ml.clustering.kmeans.SphericalKMeans.IterationResult;
045import org.openimaj.ml.clustering.kmeans.SphericalKMeansResult;
046import org.openimaj.util.function.Operation;
047
048public class KMeansExp1 extends CIFARExperimentFramework {
049        Normaliser patchNorm = new PerExampleMeanCenterVar(10.0 / 255.0);
050        WhiteningTransform whitening = new ZCAWhitening(0.1, patchNorm);
051        int numCentroids = 1600;
052        int numIters = 10;
053
054        private double[][] dictionary;
055        final RectangleSampler rs = new RectangleSampler(new Rectangle(0, 0, 32, 32), 1, 1, patchSize, patchSize);
056        final List<Rectangle> rectangles = rs.allRectangles();
057
058        @Override
059        protected void learnFeatures(double[][] patches) {
060                whitening.train(patches);
061
062                final double[][] whitenedFeaturePatches = whitening.whiten(patches);
063                final SphericalKMeans skm = new SphericalKMeans(numCentroids, numIters);
064                skm.addIterationListener(new Operation<SphericalKMeans.IterationResult>() {
065                        @Override
066                        public void perform(IterationResult object) {
067                                System.out.println("KMeans iteration " + object.iteration + " / " + numIters);
068                                DisplayUtilities.display(drawCentroids(object.result.centroids));
069                        }
070                });
071                final SphericalKMeansResult res = skm.cluster(whitenedFeaturePatches);
072                this.dictionary = res.centroids;
073
074                DisplayUtilities.display(drawCentroids(this.dictionary));
075        }
076
077        MBFImage drawCentroids(double[][] centroids) {
078                final int wh = (int) Math.sqrt(numCentroids);
079                final MBFImage tmp = new MBFImage(wh * (patchSize + 1) + 1, wh * (patchSize + 1) + 1);
080                final float mn = -1.0f;
081                final float mx = +1.0f;
082                tmp.fill(new Float[] { mx, mx, mx });
083
084                for (int i = 0, y = 0; y < wh; y++) {
085                        for (int x = 0; x < wh; x++, i++) {
086                                final MBFImage p = new MBFImage(centroids[i], patchSize, patchSize, 3, false);
087                                tmp.drawImage(p, x * (patchSize + 1) + 1, y * (patchSize + 1) + 1);
088                        }
089                }
090                tmp.subtractInplace(mn);
091                tmp.divideInplace(mx - mn);
092                return tmp;
093        }
094
095        @Override
096        protected double[] extractFeatures(MBFImage image) {
097                double[][] patches = new double[rectangles.size()][];
098                final MBFImage tmpImage = new MBFImage(this.patchSize, this.patchSize);
099
100                for (int i = 0; i < patches.length; i++) {
101                        final Rectangle r = rectangles.get(i);
102                        patches[i] = image.extractROI((int) r.x, (int) r.y, tmpImage).getDoublePixelVector();
103                }
104                patches = whitening.whiten(patches);
105                patches = activation(patches);
106
107                // sum pooling
108                final double[] feature = pool(patches);
109
110                return feature;
111        }
112
113        private double[] pool(double[][] patches) {
114                final double[] feature = new double[dictionary.length * 4];
115                final int sz = (int) Math.sqrt(patches.length);
116                final int hsz = sz / 2;
117                for (int j = 0; j < sz; j++) {
118                        final int by = j < hsz ? 0 : 1;
119                        for (int i = 0; i < sz; i++) {
120                                final int bx = i < hsz ? 0 : 1;
121
122                                final double[] p = patches[j * sz + i];
123                                for (int k = 0; k < p.length; k++)
124                                        feature[2 * dictionary.length * by + dictionary.length * bx + k] += p[k];
125                        }
126                }
127                return feature;
128        }
129
130        // private double[][] activation(double[][] p) {
131        // final double[][] c = this.dictionary;
132        // final double[][] result = new double[p.length][c.length];
133        //
134        // final double[] z = new double[c.length];
135        // for (int i = 0; i < p.length; i++) {
136        // final double[] x = p[i];
137        // double mu = 0;
138        // for (int k = 0; k < z.length; k++) {
139        // z[k] = 0;
140        // for (int j = 0; j < x.length; j++) {
141        // final double d = x[j] - c[k][j];
142        // z[k] += d * d;
143        // }
144        // z[k] = Math.sqrt(z[k]);
145        // mu += z[k];
146        // }
147        //
148        // mu /= z.length;
149        //
150        // for (int k = 0; k < z.length; k++) {
151        // result[i][k] = Math.max(0, mu - z[k]);
152        // }
153        // }
154        //
155        // return result;
156        // }
157
158        private double[][] activation(double[][] p) {
159                final double[][] c = this.dictionary;
160                final double[][] result = new double[p.length][c.length];
161
162                for (int i = 0; i < p.length; i++) {
163                        final double[] x = p[i];
164
165                        for (int k = 0; k < c.length; k++) {
166                                double dx = 0;
167                                for (int j = 0; j < x.length; j++) {
168                                        dx += c[k][j] * x[j];
169                                }
170                                result[i][k] = Math.max(0, Math.abs(dx) - 0.5);
171                        }
172                }
173
174                return result;
175        }
176
177        public static void main(String[] args) throws IOException {
178                new KMeansExp1().run();
179        }
180}