001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.workinprogress.featlearn.cifarexps; 031 032import java.io.IOException; 033import java.util.List; 034 035import org.openimaj.image.DisplayUtilities; 036import org.openimaj.image.MBFImage; 037import org.openimaj.image.pixel.sampling.RectangleSampler; 038import org.openimaj.math.geometry.shape.Rectangle; 039import org.openimaj.math.matrix.algorithm.whitening.WhiteningTransform; 040import org.openimaj.math.matrix.algorithm.whitening.ZCAWhitening; 041import org.openimaj.math.statistics.normalisation.Normaliser; 042import org.openimaj.math.statistics.normalisation.PerExampleMeanCenterVar; 043import org.openimaj.ml.clustering.kmeans.SphericalKMeans; 044import org.openimaj.ml.clustering.kmeans.SphericalKMeans.IterationResult; 045import org.openimaj.ml.clustering.kmeans.SphericalKMeansResult; 046import org.openimaj.util.function.Operation; 047 048public class KMeansExp1 extends CIFARExperimentFramework { 049 Normaliser patchNorm = new PerExampleMeanCenterVar(10.0 / 255.0); 050 WhiteningTransform whitening = new ZCAWhitening(0.1, patchNorm); 051 int numCentroids = 1600; 052 int numIters = 10; 053 054 private double[][] dictionary; 055 final RectangleSampler rs = new RectangleSampler(new Rectangle(0, 0, 32, 32), 1, 1, patchSize, patchSize); 056 final List<Rectangle> rectangles = rs.allRectangles(); 057 058 @Override 059 protected void learnFeatures(double[][] patches) { 060 whitening.train(patches); 061 062 final double[][] whitenedFeaturePatches = whitening.whiten(patches); 063 final SphericalKMeans skm = new SphericalKMeans(numCentroids, numIters); 064 skm.addIterationListener(new Operation<SphericalKMeans.IterationResult>() { 065 @Override 066 public void perform(IterationResult object) { 067 System.out.println("KMeans iteration " + object.iteration + " / " + numIters); 068 DisplayUtilities.display(drawCentroids(object.result.centroids)); 069 } 070 }); 071 final SphericalKMeansResult res = skm.cluster(whitenedFeaturePatches); 072 this.dictionary = res.centroids; 073 074 DisplayUtilities.display(drawCentroids(this.dictionary)); 075 } 076 077 MBFImage drawCentroids(double[][] centroids) { 078 final int wh = (int) Math.sqrt(numCentroids); 079 final MBFImage tmp = new MBFImage(wh * (patchSize + 1) + 1, wh * (patchSize + 1) + 1); 080 final float mn = -1.0f; 081 final float mx = +1.0f; 082 tmp.fill(new Float[] { mx, mx, mx }); 083 084 for (int i = 0, y = 0; y < wh; y++) { 085 for (int x = 0; x < wh; x++, i++) { 086 final MBFImage p = new MBFImage(centroids[i], patchSize, patchSize, 3, false); 087 tmp.drawImage(p, x * (patchSize + 1) + 1, y * (patchSize + 1) + 1); 088 } 089 } 090 tmp.subtractInplace(mn); 091 tmp.divideInplace(mx - mn); 092 return tmp; 093 } 094 095 @Override 096 protected double[] extractFeatures(MBFImage image) { 097 double[][] patches = new double[rectangles.size()][]; 098 final MBFImage tmpImage = new MBFImage(this.patchSize, this.patchSize); 099 100 for (int i = 0; i < patches.length; i++) { 101 final Rectangle r = rectangles.get(i); 102 patches[i] = image.extractROI((int) r.x, (int) r.y, tmpImage).getDoublePixelVector(); 103 } 104 patches = whitening.whiten(patches); 105 patches = activation(patches); 106 107 // sum pooling 108 final double[] feature = pool(patches); 109 110 return feature; 111 } 112 113 private double[] pool(double[][] patches) { 114 final double[] feature = new double[dictionary.length * 4]; 115 final int sz = (int) Math.sqrt(patches.length); 116 final int hsz = sz / 2; 117 for (int j = 0; j < sz; j++) { 118 final int by = j < hsz ? 0 : 1; 119 for (int i = 0; i < sz; i++) { 120 final int bx = i < hsz ? 0 : 1; 121 122 final double[] p = patches[j * sz + i]; 123 for (int k = 0; k < p.length; k++) 124 feature[2 * dictionary.length * by + dictionary.length * bx + k] += p[k]; 125 } 126 } 127 return feature; 128 } 129 130 // private double[][] activation(double[][] p) { 131 // final double[][] c = this.dictionary; 132 // final double[][] result = new double[p.length][c.length]; 133 // 134 // final double[] z = new double[c.length]; 135 // for (int i = 0; i < p.length; i++) { 136 // final double[] x = p[i]; 137 // double mu = 0; 138 // for (int k = 0; k < z.length; k++) { 139 // z[k] = 0; 140 // for (int j = 0; j < x.length; j++) { 141 // final double d = x[j] - c[k][j]; 142 // z[k] += d * d; 143 // } 144 // z[k] = Math.sqrt(z[k]); 145 // mu += z[k]; 146 // } 147 // 148 // mu /= z.length; 149 // 150 // for (int k = 0; k < z.length; k++) { 151 // result[i][k] = Math.max(0, mu - z[k]); 152 // } 153 // } 154 // 155 // return result; 156 // } 157 158 private double[][] activation(double[][] p) { 159 final double[][] c = this.dictionary; 160 final double[][] result = new double[p.length][c.length]; 161 162 for (int i = 0; i < p.length; i++) { 163 final double[] x = p[i]; 164 165 for (int k = 0; k < c.length; k++) { 166 double dx = 0; 167 for (int j = 0; j < x.length; j++) { 168 dx += c[k][j] * x[j]; 169 } 170 result[i][k] = Math.max(0, Math.abs(dx) - 0.5); 171 } 172 } 173 174 return result; 175 } 176 177 public static void main(String[] args) throws IOException { 178 new KMeansExp1().run(); 179 } 180}