001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.demos;
031
032import java.io.File;
033import java.io.FileNotFoundException;
034import java.io.IOException;
035import java.util.ArrayList;
036import java.util.List;
037import java.util.Random;
038import java.util.Scanner;
039
040import org.openimaj.feature.FloatFV;
041import org.openimaj.io.IOUtils;
042import org.openimaj.ml.linear.projection.LargeMarginDimensionalityReduction;
043
044import Jama.Matrix;
045
046import com.jmatio.io.MatFileReader;
047import com.jmatio.types.MLSingle;
048
049public class FVFWExperiment {
050        // private static final String FOLDER =
051        // "lfw-centre-affine-pdsift-pca64-augm-fv512/";
052        // private static final String FOLDER = "lfw-centre-affine-matlab-fisher/";
053        private static final String FOLDER = "matlab-fvs/";
054
055        static class FacePair {
056                boolean same;
057                File firstFV;
058                File secondFV;
059
060                public FacePair(File first, File second, boolean same) {
061                        this.firstFV = first;
062                        this.secondFV = second;
063                        this.same = same;
064                }
065
066                FloatFV loadFirst() throws IOException {
067                        return IOUtils.read(firstFV, FloatFV.class);
068                }
069
070                FloatFV loadSecond() throws IOException {
071                        return IOUtils.read(secondFV, FloatFV.class);
072                }
073        }
074
075        static class Subset {
076                List<FacePair> testPairs = new ArrayList<FacePair>();
077                List<FacePair> trainingPairs = new ArrayList<FacePair>();
078        }
079
080        static List<Subset> loadSubsets() throws IOException {
081                final List<Subset> subsets = new ArrayList<Subset>();
082
083                for (int i = 0; i < 10; i++)
084                        subsets.add(new Subset());
085
086                loadPairs(new File("/Users/jon/Data/lfw/pairs.txt"), subsets);
087                loadPeople(new File("/Users/jon/Data/lfw/people.txt"), subsets);
088
089                return subsets;
090        }
091
092        private static void loadPairs(File file, List<Subset> subsets) throws FileNotFoundException {
093                final Scanner sc = new Scanner(file);
094
095                final int nsets = sc.nextInt();
096                final int nhpairs = sc.nextInt();
097
098                if (nsets != 10 || nhpairs != 300) {
099                        sc.close();
100                        throw new RuntimeException();
101                }
102
103                for (int s = 0; s < 10; s++) {
104                        for (int i = 0; i < 300; i++) {
105                                final String name = sc.next();
106                                final int firstIdx = sc.nextInt();
107                                final int secondIdx = sc.nextInt();
108
109                                final File first = new File(file.getParentFile(), FOLDER + name
110                                                + "/" + name + String.format("_%04d.bin", firstIdx));
111                                final File second = new File(file.getParentFile(), FOLDER + name
112                                                + "/" + name + String.format("_%04d.bin", secondIdx));
113
114                                subsets.get(s).testPairs.add(new FacePair(first, second, true));
115                        }
116
117                        for (int i = 0; i < 300; i++) {
118                                final String firstName = sc.next();
119                                final int firstIdx = sc.nextInt();
120                                final String secondName = sc.next();
121                                final int secondIdx = sc.nextInt();
122
123                                final File first = new File(file.getParentFile(), FOLDER
124                                                + firstName
125                                                + "/" + firstName + String.format("_%04d.bin", firstIdx));
126                                final File second = new File(file.getParentFile(), FOLDER
127                                                + secondName
128                                                + "/" + secondName + String.format("_%04d.bin", secondIdx));
129
130                                subsets.get(s).testPairs.add(new FacePair(first, second, false));
131                        }
132                }
133
134                sc.close();
135        }
136
137        private static void loadPeople(File file, List<Subset> subsets) throws FileNotFoundException {
138                final Scanner sc = new Scanner(file);
139
140                final int nsets = sc.nextInt();
141
142                if (nsets != 10) {
143                        sc.close();
144                        throw new RuntimeException();
145                }
146
147                for (int s = 0; s < 10; s++) {
148                        final int nnames = sc.nextInt();
149                        final List<File> files = new ArrayList<File>(nnames);
150                        for (int i = 0; i < nnames; i++) {
151                                final String name = sc.next();
152                                final int numPeople = sc.nextInt();
153                                for (int j = 1; j <= numPeople; j++) {
154                                        final File f = new File(file.getParentFile(), FOLDER + name
155                                                        + "/" + name + String.format("_%04d.bin", j));
156
157                                        files.add(f);
158                                }
159                        }
160
161                        for (int i = 0; i < files.size(); i++) {
162                                final File first = files.get(i);
163                                for (int j = i + 1; j < files.size(); j++) {
164                                        final File second = files.get(j);
165
166                                        final boolean same = first.getName().substring(0, first.getName().lastIndexOf("_"))
167                                                        .equals(second.getName().substring(0, second.getName().lastIndexOf("_")));
168
169                                        subsets.get(s).trainingPairs.add(new FacePair(first, second, same));
170                                        subsets.get(s).trainingPairs.add(new FacePair(second, first, same));
171                                }
172                        }
173                }
174
175                sc.close();
176        }
177
178        static Subset createExperimentalFold(List<Subset> subsets, int foldIdx) {
179                final Subset subset = new Subset();
180                // testing data is from the indexed fold
181                subset.testPairs = subsets.get(foldIdx).testPairs;
182
183                // training data is from the other folds
184                final List<FacePair> training = new ArrayList<FacePair>();
185                for (int i = 0; i < foldIdx; i++)
186                        training.addAll(subsets.get(i).trainingPairs);
187                for (int i = foldIdx + 1; i < subsets.size(); i++)
188                        training.addAll(subsets.get(i).trainingPairs);
189
190                subset.trainingPairs = reorder(training);
191
192                return subset;
193        }
194
195        private static List<FacePair> reorder(List<FacePair> training) {
196                final List<FacePair> trainingTrue = new ArrayList<FacePair>();
197                final List<FacePair> trainingFalse = new ArrayList<FacePair>();
198
199                for (final FacePair fp : training) {
200                        if (fp.same)
201                                trainingTrue.add(fp);
202                        else
203                                trainingFalse.add(fp);
204                }
205
206                resample(trainingTrue, 4000000);
207                resample(trainingFalse, 4000000);
208
209                final List<FacePair> trainingResorted = new ArrayList<FacePair>();
210                for (int i = 0; i < trainingTrue.size(); i++) {
211                        trainingResorted.add(trainingTrue.get(i));
212                        trainingResorted.add(trainingFalse.get(i));
213                }
214
215                return trainingResorted;
216        }
217
218        private static void resample(List<FacePair> pairs, int sz) {
219                final List<FacePair> oldPairs = new ArrayList<FVFWExperiment.FacePair>(sz);
220                oldPairs.addAll(pairs);
221                pairs.clear();
222
223                final Random r = new Random();
224
225                for (int i = 0; i < sz; i++) {
226                        pairs.add(oldPairs.get(r.nextInt(oldPairs.size())));
227                }
228        }
229
230        public static void main(String[] args) throws IOException {
231                final List<Subset> subsets = loadSubsets();
232                final Subset fold = createExperimentalFold(subsets, 1);
233
234                // // final LargeMarginDimensionalityReduction lmdr = new
235                // // LargeMarginDimensionalityReduction(128);
236                // final LargeMarginDimensionalityReduction lmdr = loadMatlabPCAW();
237                //
238                // final double[][] fInit = new double[1000][];
239                // final double[][] sInit = new double[1000][];
240                // final boolean[] same = new boolean[1000];
241                // for (int i = 0; i < 1000; i++) {
242                // final FacePair p =
243                // fold.trainingPairs.get(i);
244                // fInit[i] = p.loadFirst().asDoubleVector();
245                // sInit[i] = p.loadSecond().asDoubleVector();
246                // same[i] = p.same;
247                //
248                // for (int j = 0; j < fInit[i].length; j++) {
249                // if (Double.isInfinite(fInit[i][j]) || Double.isNaN(fInit[i][j]))
250                // throw new RuntimeException("" + fold.trainingPairs.get(i).firstFV);
251                // if (Double.isInfinite(sInit[i][j]) || Double.isNaN(sInit[i][j]))
252                // throw new RuntimeException("" + fold.trainingPairs.get(i).secondFV);
253                // }
254                // }
255                //
256                // System.out.println("LMDR Init");
257                // lmdr.recomputeBias(fInit, sInit, same);
258                // // lmdr.initialise(fInit, sInit, same);
259                // IOUtils.writeToFile(lmdr, new
260                // File("/Users/jon/Data/lfw/lmdr-matlabfvs-pcaw-init.bin"));
261                // // final LargeMarginDimensionalityReduction lmdr = IOUtils
262                // // .readFromFile(new File("/Users/jon/Data/lfw/lmdr-init.bin"));
263                //
264                // for (int i = 0; i < 1e6; i++) {
265                // if (i % 100 == 0)
266                // System.out.println("Iter " + i);
267                // final FacePair p = fold.trainingPairs.get(i);
268                // lmdr.step(p.loadFirst().asDoubleVector(),
269                // p.loadSecond().asDoubleVector(), p.same);
270                // }
271                // IOUtils.writeToFile(lmdr, new
272                // File("/Users/jon/Data/lfw/lmdr-matlabfvs-pcaw.bin"));
273
274                final LargeMarginDimensionalityReduction lmdr =
275                                IOUtils.readFromFile(new
276                                                File("/Users/jon/Data/lfw/lmdr-matlabfvs-pcaw.bin"));
277                // final LargeMarginDimensionalityReduction lmdr = loadMatlabLMDR();
278                // final LargeMarginDimensionalityReduction lmdr = loadMatlabPCAW();
279
280                final double[][] first = new double[fold.testPairs.size()][];
281                final double[][] second = new double[fold.testPairs.size()][];
282                final boolean[] same = new boolean[fold.testPairs.size()];
283                for (int j = 0; j < same.length; j++) {
284                        final FacePair p = fold.testPairs.get(j);
285                        first[j] = p.loadFirst().asDoubleVector();
286                        second[j] = p.loadSecond().asDoubleVector();
287                        same[j] = p.same;
288                }
289                // System.out.println("Current bias: " + lmdr.getBias());
290                // lmdr.recomputeBias(first, second, same);
291                // System.out.println("Best bias: " + lmdr.getBias());
292
293                double correct = 0;
294                double count = 0;
295                for (int j = 0; j < same.length; j++) {
296                        final boolean pred = lmdr.classify(first[j],
297                                        second[j]);
298
299                        if (pred == same[j])
300                                correct++;
301                        count++;
302                }
303                System.out.println(lmdr.getBias() + " " + (correct / count));
304        }
305
306        // private static double[] reorder(double[] in) {
307        // final double[] out = new double[in.length];
308        // final int D = 64;
309        // final int K = 512;
310        // for (int k = 0; k < K; k++) {
311        // for (int j = 0; j < D; j++) {
312        // out[k * D + j] = in[k * 2 * D + j];
313        // out[k * D + j + D * K] = in[k * 2 * D + j + D];
314        // }
315        // }
316        // return out;
317        // }
318
319        private static LargeMarginDimensionalityReduction loadMatlabLMDR() throws IOException {
320                final LargeMarginDimensionalityReduction lmdr = new LargeMarginDimensionalityReduction(128);
321
322                final MatFileReader reader = new MatFileReader(new File("/Users/jon/lmdr.mat"));
323                final MLSingle W = (MLSingle) reader.getContent().get("W");
324                final MLSingle b = (MLSingle) reader.getContent().get("b");
325
326                lmdr.setBias(b.get(0, 0));
327
328                final Matrix proj = new Matrix(W.getM(), W.getN());
329                for (int j = 0; j < W.getN(); j++) {
330                        for (int i = 0; i < W.getM(); i++) {
331                                proj.set(i, j, W.get(i, j));
332                        }
333                }
334
335                lmdr.setTransform(proj);
336
337                return lmdr;
338        }
339
340        private static LargeMarginDimensionalityReduction loadMatlabPCAW() throws IOException {
341                final LargeMarginDimensionalityReduction lmdr = new LargeMarginDimensionalityReduction(128);
342
343                final MatFileReader reader = new MatFileReader(new File("/Users/jon/pcaw.mat"));
344                final MLSingle W = (MLSingle) reader.getContent().get("proj");
345
346                lmdr.setBias(169.6264190673828);
347
348                final Matrix proj = new Matrix(W.getM(), W.getN());
349                for (int j = 0; j < W.getN(); j++) {
350                        for (int i = 0; i < W.getM(); i++) {
351                                proj.set(i, j, W.get(i, j));
352                        }
353                }
354
355                lmdr.setTransform(proj);
356
357                return lmdr;
358        }
359}