001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.image.model;
031
032import java.io.DataInput;
033import java.io.DataOutput;
034import java.io.IOException;
035import java.util.ArrayList;
036import java.util.HashMap;
037import java.util.List;
038import java.util.Map;
039import java.util.Map.Entry;
040
041import org.openimaj.citation.annotation.Reference;
042import org.openimaj.citation.annotation.ReferenceType;
043import org.openimaj.data.dataset.GroupedDataset;
044import org.openimaj.data.dataset.ListDataset;
045import org.openimaj.feature.DoubleFV;
046import org.openimaj.feature.FeatureExtractor;
047import org.openimaj.image.FImage;
048import org.openimaj.image.feature.FImage2DoubleFV;
049import org.openimaj.io.ReadWriteableBinary;
050import org.openimaj.math.matrix.algorithm.LinearDiscriminantAnalysis;
051import org.openimaj.math.matrix.algorithm.pca.PrincipalComponentAnalysis;
052import org.openimaj.math.matrix.algorithm.pca.ThinSvdPrincipalComponentAnalysis;
053import org.openimaj.ml.training.BatchTrainer;
054import org.openimaj.util.array.ArrayUtils;
055import org.openimaj.util.pair.IndependentPair;
056
057import Jama.Matrix;
058
059/**
060 * Implementation of Fisher Images (aka "FisherFaces"). PCA is used to avoid the
061 * singular within-class scatter matrix.
062 * 
063 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
064 */
065@Reference(
066                type = ReferenceType.Article,
067                author = { "Belhumeur, Peter N.", "Hespanha, Jo\\~{a}o P.", "Kriegman, David J." },
068                title = "Eigenfaces vs. Fisherfaces: Recognition Using Class Specific Linear Projection",
069                year = "1997",
070                journal = "IEEE Trans. Pattern Anal. Mach. Intell.",
071                pages = { "711", "", "720" },
072                url = "http://dx.doi.org/10.1109/34.598228",
073                month = "July",
074                number = "7",
075                publisher = "IEEE Computer Society",
076                volume = "19",
077                customData = {
078                                "issn", "0162-8828",
079                                "numpages", "10",
080                                "doi", "10.1109/34.598228",
081                                "acmid", "261512",
082                                "address", "Washington, DC, USA",
083                                "keywords", "Appearance-based vision, face recognition, illumination invariance, Fisher's linear discriminant."
084                })
085public class FisherImages implements BatchTrainer<IndependentPair<?, FImage>>,
086                FeatureExtractor<DoubleFV, FImage>,
087                ReadWriteableBinary
088{
089        private int numComponents;
090        private int width;
091        private int height;
092        private Matrix basis;
093        private double[] mean;
094
095        /**
096         * Construct with the given number of components.
097         * 
098         * @param numComponents
099         *            the number of components
100         */
101        public FisherImages(int numComponents) {
102                this.numComponents = numComponents;
103        }
104
105        @Override
106        public void readBinary(DataInput in) throws IOException {
107                width = in.readInt();
108                height = in.readInt();
109                numComponents = in.readInt();
110        }
111
112        @Override
113        public byte[] binaryHeader() {
114                return "FisI".getBytes();
115        }
116
117        @Override
118        public void writeBinary(DataOutput out) throws IOException {
119                out.writeInt(width);
120                out.writeInt(height);
121                out.writeInt(numComponents);
122        }
123
124        /**
125         * Train on a map of data.
126         * 
127         * @param data
128         *            the data
129         */
130        public void train(Map<?, ? extends List<FImage>> data) {
131                final List<IndependentPair<?, FImage>> list = new ArrayList<IndependentPair<?, FImage>>();
132
133                for (final Entry<?, ? extends List<FImage>> e : data.entrySet()) {
134                        for (final FImage i : e.getValue()) {
135                                list.add(IndependentPair.pair(e.getKey(), i));
136                        }
137                }
138
139                train(list);
140        }
141
142        /**
143         * Train on a grouped dataset.
144         * 
145         * @param <KEY>
146         *            The group type
147         * @param data
148         *            the data
149         */
150        public <KEY> void train(GroupedDataset<KEY, ? extends ListDataset<FImage>, FImage> data) {
151                final List<IndependentPair<?, FImage>> list = new ArrayList<IndependentPair<?, FImage>>();
152
153                for (final KEY e : data.getGroups()) {
154                        for (final FImage i : data.getInstances(e)) {
155                                list.add(IndependentPair.pair(e, i));
156                        }
157                }
158
159                train(list);
160        }
161
162        @Override
163        public void train(List<? extends IndependentPair<?, FImage>> data) {
164                width = data.get(0).secondObject().width;
165                height = data.get(0).secondObject().height;
166
167                final Map<Object, List<double[]>> mapData = new HashMap<Object, List<double[]>>();
168                final List<double[]> listData = new ArrayList<double[]>();
169                for (final IndependentPair<?, FImage> item : data) {
170                        List<double[]> fvs = mapData.get(item.firstObject());
171                        if (fvs == null)
172                                mapData.put(item.firstObject(), fvs = new ArrayList<double[]>());
173
174                        final double[] fv = FImage2DoubleFV.INSTANCE.extractFeature(item.getSecondObject()).values;
175                        fvs.add(fv);
176                        listData.add(fv);
177                }
178
179                final PrincipalComponentAnalysis pca = new ThinSvdPrincipalComponentAnalysis(numComponents);
180                pca.learnBasis(listData);
181
182                final List<double[][]> ldaData = new ArrayList<double[][]>(mapData.size());
183                for (final Entry<?, List<double[]>> e : mapData.entrySet()) {
184                        final List<double[]> vecs = e.getValue();
185                        final double[][] classData = new double[vecs.size()][];
186
187                        for (int i = 0; i < classData.length; i++) {
188                                classData[i] = pca.project(vecs.get(i));
189                        }
190
191                        ldaData.add(classData);
192                }
193
194                final LinearDiscriminantAnalysis lda = new LinearDiscriminantAnalysis(numComponents);
195                lda.learnBasis(ldaData);
196
197                basis = pca.getBasis().times(lda.getBasis());
198                mean = pca.getMean();
199        }
200
201        private double[] project(double[] vector) {
202                final Matrix vec = new Matrix(1, vector.length);
203                final double[][] vecarr = vec.getArray();
204
205                for (int i = 0; i < vector.length; i++)
206                        vecarr[0][i] = vector[i] - mean[i];
207
208                return vec.times(basis).getColumnPackedCopy();
209        }
210
211        @Override
212        public DoubleFV extractFeature(FImage object) {
213                return new DoubleFV(project(FImage2DoubleFV.INSTANCE.extractFeature(object).values));
214        }
215
216        /**
217         * Get a specific basis vector as a double array. The returned array
218         * contains a copy of the data.
219         * 
220         * @param index
221         *            the index of the vector
222         * 
223         * @return the eigenvector
224         */
225        public double[] getBasisVector(int index) {
226                final double[] pc = new double[basis.getRowDimension()];
227                final double[][] data = basis.getArray();
228
229                for (int r = 0; r < pc.length; r++)
230                        pc[r] = data[r][index];
231
232                return pc;
233        }
234
235        /**
236         * Draw an eigenvector as an image
237         * 
238         * @param num
239         *            the index of the eigenvector to draw.
240         * @return an image showing the eigenvector.
241         */
242        public FImage visualise(int num) {
243                return new FImage(ArrayUtils.reshapeFloat(getBasisVector(num), width, height));
244        }
245}