View Javadoc

1   /**
2    * Copyright (c) 2011, The University of Southampton and the individual contributors.
3    * All rights reserved.
4    *
5    * Redistribution and use in source and binary forms, with or without modification,
6    * are permitted provided that the following conditions are met:
7    *
8    *   * 	Redistributions of source code must retain the above copyright notice,
9    * 	this list of conditions and the following disclaimer.
10   *
11   *   *	Redistributions in binary form must reproduce the above copyright notice,
12   * 	this list of conditions and the following disclaimer in the documentation
13   * 	and/or other materials provided with the distribution.
14   *
15   *   *	Neither the name of the University of Southampton nor the names of its
16   * 	contributors may be used to endorse or promote products derived from this
17   * 	software without specific prior written permission.
18   *
19   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21   * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22   * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
23   * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24   * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26   * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27   * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28   * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29   */
30  package org.openimaj.image.model;
31  
32  import java.io.DataInput;
33  import java.io.DataOutput;
34  import java.io.IOException;
35  import java.util.ArrayList;
36  import java.util.HashMap;
37  import java.util.List;
38  import java.util.Map;
39  import java.util.Map.Entry;
40  
41  import org.openimaj.citation.annotation.Reference;
42  import org.openimaj.citation.annotation.ReferenceType;
43  import org.openimaj.data.dataset.GroupedDataset;
44  import org.openimaj.data.dataset.ListDataset;
45  import org.openimaj.feature.DoubleFV;
46  import org.openimaj.feature.FeatureExtractor;
47  import org.openimaj.image.FImage;
48  import org.openimaj.image.feature.FImage2DoubleFV;
49  import org.openimaj.io.ReadWriteableBinary;
50  import org.openimaj.math.matrix.algorithm.LinearDiscriminantAnalysis;
51  import org.openimaj.math.matrix.algorithm.pca.PrincipalComponentAnalysis;
52  import org.openimaj.math.matrix.algorithm.pca.ThinSvdPrincipalComponentAnalysis;
53  import org.openimaj.ml.training.BatchTrainer;
54  import org.openimaj.util.array.ArrayUtils;
55  import org.openimaj.util.pair.IndependentPair;
56  
57  import Jama.Matrix;
58  
59  /**
60   * Implementation of Fisher Images (aka "FisherFaces"). PCA is used to avoid the
61   * singular within-class scatter matrix.
62   * 
63   * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
64   */
65  @Reference(
66  		type = ReferenceType.Article,
67  		author = { "Belhumeur, Peter N.", "Hespanha, Jo\\~{a}o P.", "Kriegman, David J." },
68  		title = "Eigenfaces vs. Fisherfaces: Recognition Using Class Specific Linear Projection",
69  		year = "1997",
70  		journal = "IEEE Trans. Pattern Anal. Mach. Intell.",
71  		pages = { "711", "", "720" },
72  		url = "http://dx.doi.org/10.1109/34.598228",
73  		month = "July",
74  		number = "7",
75  		publisher = "IEEE Computer Society",
76  		volume = "19",
77  		customData = {
78  				"issn", "0162-8828",
79  				"numpages", "10",
80  				"doi", "10.1109/34.598228",
81  				"acmid", "261512",
82  				"address", "Washington, DC, USA",
83  				"keywords", "Appearance-based vision, face recognition, illumination invariance, Fisher's linear discriminant."
84  		})
85  public class FisherImages implements BatchTrainer<IndependentPair<?, FImage>>,
86  		FeatureExtractor<DoubleFV, FImage>,
87  		ReadWriteableBinary
88  {
89  	private int numComponents;
90  	private int width;
91  	private int height;
92  	private Matrix basis;
93  	private double[] mean;
94  
95  	/**
96  	 * Construct with the given number of components.
97  	 * 
98  	 * @param numComponents
99  	 *            the number of components
100 	 */
101 	public FisherImages(int numComponents) {
102 		this.numComponents = numComponents;
103 	}
104 
105 	@Override
106 	public void readBinary(DataInput in) throws IOException {
107 		width = in.readInt();
108 		height = in.readInt();
109 		numComponents = in.readInt();
110 	}
111 
112 	@Override
113 	public byte[] binaryHeader() {
114 		return "FisI".getBytes();
115 	}
116 
117 	@Override
118 	public void writeBinary(DataOutput out) throws IOException {
119 		out.writeInt(width);
120 		out.writeInt(height);
121 		out.writeInt(numComponents);
122 	}
123 
124 	/**
125 	 * Train on a map of data.
126 	 * 
127 	 * @param data
128 	 *            the data
129 	 */
130 	public void train(Map<?, ? extends List<FImage>> data) {
131 		final List<IndependentPair<?, FImage>> list = new ArrayList<IndependentPair<?, FImage>>();
132 
133 		for (final Entry<?, ? extends List<FImage>> e : data.entrySet()) {
134 			for (final FImage i : e.getValue()) {
135 				list.add(IndependentPair.pair(e.getKey(), i));
136 			}
137 		}
138 
139 		train(list);
140 	}
141 
142 	/**
143 	 * Train on a grouped dataset.
144 	 * 
145 	 * @param <KEY>
146 	 *            The group type
147 	 * @param data
148 	 *            the data
149 	 */
150 	public <KEY> void train(GroupedDataset<KEY, ? extends ListDataset<FImage>, FImage> data) {
151 		final List<IndependentPair<?, FImage>> list = new ArrayList<IndependentPair<?, FImage>>();
152 
153 		for (final KEY e : data.getGroups()) {
154 			for (final FImage i : data.getInstances(e)) {
155 				list.add(IndependentPair.pair(e, i));
156 			}
157 		}
158 
159 		train(list);
160 	}
161 
162 	@Override
163 	public void train(List<? extends IndependentPair<?, FImage>> data) {
164 		width = data.get(0).secondObject().width;
165 		height = data.get(0).secondObject().height;
166 
167 		final Map<Object, List<double[]>> mapData = new HashMap<Object, List<double[]>>();
168 		final List<double[]> listData = new ArrayList<double[]>();
169 		for (final IndependentPair<?, FImage> item : data) {
170 			List<double[]> fvs = mapData.get(item.firstObject());
171 			if (fvs == null)
172 				mapData.put(item.firstObject(), fvs = new ArrayList<double[]>());
173 
174 			final double[] fv = FImage2DoubleFV.INSTANCE.extractFeature(item.getSecondObject()).values;
175 			fvs.add(fv);
176 			listData.add(fv);
177 		}
178 
179 		final PrincipalComponentAnalysis pca = new ThinSvdPrincipalComponentAnalysis(numComponents);
180 		pca.learnBasis(listData);
181 
182 		final List<double[][]> ldaData = new ArrayList<double[][]>(mapData.size());
183 		for (final Entry<?, List<double[]>> e : mapData.entrySet()) {
184 			final List<double[]> vecs = e.getValue();
185 			final double[][] classData = new double[vecs.size()][];
186 
187 			for (int i = 0; i < classData.length; i++) {
188 				classData[i] = pca.project(vecs.get(i));
189 			}
190 
191 			ldaData.add(classData);
192 		}
193 
194 		final LinearDiscriminantAnalysis lda = new LinearDiscriminantAnalysis(numComponents);
195 		lda.learnBasis(ldaData);
196 
197 		basis = pca.getBasis().times(lda.getBasis());
198 		mean = pca.getMean();
199 	}
200 
201 	private double[] project(double[] vector) {
202 		final Matrix vec = new Matrix(1, vector.length);
203 		final double[][] vecarr = vec.getArray();
204 
205 		for (int i = 0; i < vector.length; i++)
206 			vecarr[0][i] = vector[i] - mean[i];
207 
208 		return vec.times(basis).getColumnPackedCopy();
209 	}
210 
211 	@Override
212 	public DoubleFV extractFeature(FImage object) {
213 		return new DoubleFV(project(FImage2DoubleFV.INSTANCE.extractFeature(object).values));
214 	}
215 
216 	/**
217 	 * Get a specific basis vector as a double array. The returned array
218 	 * contains a copy of the data.
219 	 * 
220 	 * @param index
221 	 *            the index of the vector
222 	 * 
223 	 * @return the eigenvector
224 	 */
225 	public double[] getBasisVector(int index) {
226 		final double[] pc = new double[basis.getRowDimension()];
227 		final double[][] data = basis.getArray();
228 
229 		for (int r = 0; r < pc.length; r++)
230 			pc[r] = data[r][index];
231 
232 		return pc;
233 	}
234 
235 	/**
236 	 * Draw an eigenvector as an image
237 	 * 
238 	 * @param num
239 	 *            the index of the eigenvector to draw.
240 	 * @return an image showing the eigenvector.
241 	 */
242 	public FImage visualise(int num) {
243 		return new FImage(ArrayUtils.reshapeFloat(getBasisVector(num), width, height));
244 	}
245 }