1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package org.openimaj.image.model;
31
32 import java.io.DataInput;
33 import java.io.DataOutput;
34 import java.io.IOException;
35 import java.util.ArrayList;
36 import java.util.HashMap;
37 import java.util.List;
38 import java.util.Map;
39 import java.util.Map.Entry;
40
41 import org.openimaj.citation.annotation.Reference;
42 import org.openimaj.citation.annotation.ReferenceType;
43 import org.openimaj.data.dataset.GroupedDataset;
44 import org.openimaj.data.dataset.ListDataset;
45 import org.openimaj.feature.DoubleFV;
46 import org.openimaj.feature.FeatureExtractor;
47 import org.openimaj.image.FImage;
48 import org.openimaj.image.feature.FImage2DoubleFV;
49 import org.openimaj.io.ReadWriteableBinary;
50 import org.openimaj.math.matrix.algorithm.LinearDiscriminantAnalysis;
51 import org.openimaj.math.matrix.algorithm.pca.PrincipalComponentAnalysis;
52 import org.openimaj.math.matrix.algorithm.pca.ThinSvdPrincipalComponentAnalysis;
53 import org.openimaj.ml.training.BatchTrainer;
54 import org.openimaj.util.array.ArrayUtils;
55 import org.openimaj.util.pair.IndependentPair;
56
57 import Jama.Matrix;
58
59
60
61
62
63
64
65 @Reference(
66 type = ReferenceType.Article,
67 author = { "Belhumeur, Peter N.", "Hespanha, Jo\\~{a}o P.", "Kriegman, David J." },
68 title = "Eigenfaces vs. Fisherfaces: Recognition Using Class Specific Linear Projection",
69 year = "1997",
70 journal = "IEEE Trans. Pattern Anal. Mach. Intell.",
71 pages = { "711", "", "720" },
72 url = "http://dx.doi.org/10.1109/34.598228",
73 month = "July",
74 number = "7",
75 publisher = "IEEE Computer Society",
76 volume = "19",
77 customData = {
78 "issn", "0162-8828",
79 "numpages", "10",
80 "doi", "10.1109/34.598228",
81 "acmid", "261512",
82 "address", "Washington, DC, USA",
83 "keywords", "Appearance-based vision, face recognition, illumination invariance, Fisher's linear discriminant."
84 })
85 public class FisherImages implements BatchTrainer<IndependentPair<?, FImage>>,
86 FeatureExtractor<DoubleFV, FImage>,
87 ReadWriteableBinary
88 {
89 private int numComponents;
90 private int width;
91 private int height;
92 private Matrix basis;
93 private double[] mean;
94
95
96
97
98
99
100
101 public FisherImages(int numComponents) {
102 this.numComponents = numComponents;
103 }
104
105 @Override
106 public void readBinary(DataInput in) throws IOException {
107 width = in.readInt();
108 height = in.readInt();
109 numComponents = in.readInt();
110 }
111
112 @Override
113 public byte[] binaryHeader() {
114 return "FisI".getBytes();
115 }
116
117 @Override
118 public void writeBinary(DataOutput out) throws IOException {
119 out.writeInt(width);
120 out.writeInt(height);
121 out.writeInt(numComponents);
122 }
123
124
125
126
127
128
129
130 public void train(Map<?, ? extends List<FImage>> data) {
131 final List<IndependentPair<?, FImage>> list = new ArrayList<IndependentPair<?, FImage>>();
132
133 for (final Entry<?, ? extends List<FImage>> e : data.entrySet()) {
134 for (final FImage i : e.getValue()) {
135 list.add(IndependentPair.pair(e.getKey(), i));
136 }
137 }
138
139 train(list);
140 }
141
142
143
144
145
146
147
148
149
150 public <KEY> void train(GroupedDataset<KEY, ? extends ListDataset<FImage>, FImage> data) {
151 final List<IndependentPair<?, FImage>> list = new ArrayList<IndependentPair<?, FImage>>();
152
153 for (final KEY e : data.getGroups()) {
154 for (final FImage i : data.getInstances(e)) {
155 list.add(IndependentPair.pair(e, i));
156 }
157 }
158
159 train(list);
160 }
161
162 @Override
163 public void train(List<? extends IndependentPair<?, FImage>> data) {
164 width = data.get(0).secondObject().width;
165 height = data.get(0).secondObject().height;
166
167 final Map<Object, List<double[]>> mapData = new HashMap<Object, List<double[]>>();
168 final List<double[]> listData = new ArrayList<double[]>();
169 for (final IndependentPair<?, FImage> item : data) {
170 List<double[]> fvs = mapData.get(item.firstObject());
171 if (fvs == null)
172 mapData.put(item.firstObject(), fvs = new ArrayList<double[]>());
173
174 final double[] fv = FImage2DoubleFV.INSTANCE.extractFeature(item.getSecondObject()).values;
175 fvs.add(fv);
176 listData.add(fv);
177 }
178
179 final PrincipalComponentAnalysis pca = new ThinSvdPrincipalComponentAnalysis(numComponents);
180 pca.learnBasis(listData);
181
182 final List<double[][]> ldaData = new ArrayList<double[][]>(mapData.size());
183 for (final Entry<?, List<double[]>> e : mapData.entrySet()) {
184 final List<double[]> vecs = e.getValue();
185 final double[][] classData = new double[vecs.size()][];
186
187 for (int i = 0; i < classData.length; i++) {
188 classData[i] = pca.project(vecs.get(i));
189 }
190
191 ldaData.add(classData);
192 }
193
194 final LinearDiscriminantAnalysis lda = new LinearDiscriminantAnalysis(numComponents);
195 lda.learnBasis(ldaData);
196
197 basis = pca.getBasis().times(lda.getBasis());
198 mean = pca.getMean();
199 }
200
201 private double[] project(double[] vector) {
202 final Matrix vec = new Matrix(1, vector.length);
203 final double[][] vecarr = vec.getArray();
204
205 for (int i = 0; i < vector.length; i++)
206 vecarr[0][i] = vector[i] - mean[i];
207
208 return vec.times(basis).getColumnPackedCopy();
209 }
210
211 @Override
212 public DoubleFV extractFeature(FImage object) {
213 return new DoubleFV(project(FImage2DoubleFV.INSTANCE.extractFeature(object).values));
214 }
215
216
217
218
219
220
221
222
223
224
225 public double[] getBasisVector(int index) {
226 final double[] pc = new double[basis.getRowDimension()];
227 final double[][] data = basis.getArray();
228
229 for (int r = 0; r < pc.length; r++)
230 pc[r] = data[r][index];
231
232 return pc;
233 }
234
235
236
237
238
239
240
241
242 public FImage visualise(int num) {
243 return new FImage(ArrayUtils.reshapeFloat(getBasisVector(num), width, height));
244 }
245 }