1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package org.openimaj.image.objectdetection.haar.training;
31
32 import java.util.List;
33
34 import org.openimaj.image.analysis.algorithm.SummedSqTiltAreaTable;
35 import org.openimaj.image.objectdetection.haar.HaarFeature;
36 import org.openimaj.ml.classification.LabelledDataProvider;
37 import org.openimaj.util.array.ArrayUtils;
38 import org.openimaj.util.function.Operation;
39 import org.openimaj.util.parallel.Parallel;
40
41 public class CachedTrainingData implements LabelledDataProvider {
42 float[][] responses;
43 boolean[] classes;
44 int[][] sortedIndices;
45 List<HaarFeature> features;
46 int width, height;
47
48 float computeWindowVarianceNorm(SummedSqTiltAreaTable sat) {
49 final int w = width - 2;
50 final int h = height - 2;
51
52 final int x = 1;
53 final int y = 1;
54
55 final float sum = sat.sum.pixels[y + h][x + w] + sat.sum.pixels[y][x] -
56 sat.sum.pixels[y + h][x] - sat.sum.pixels[y][x + w];
57 final float sqSum = sat.sqSum.pixels[y + w][x + w] + sat.sqSum.pixels[y][x] -
58 sat.sqSum.pixels[y + w][x] - sat.sqSum.pixels[y][x + w];
59
60 final float cachedInvArea = 1.0f / (w * h);
61 final float mean = sum * cachedInvArea;
62 float wvNorm = sqSum * cachedInvArea - mean * mean;
63 wvNorm = (float) ((wvNorm > 0) ? Math.sqrt(wvNorm) : 1);
64
65 return wvNorm;
66 }
67
68 public CachedTrainingData(final List<SummedSqTiltAreaTable> positive, final List<SummedSqTiltAreaTable> negative,
69 final List<HaarFeature> features)
70 {
71 this.width = positive.get(0).sum.width - 1;
72 this.height = positive.get(0).sum.height - 1;
73
74 this.features = features;
75 final int nfeatures = features.size();
76
77 classes = new boolean[positive.size() + negative.size()];
78 responses = new float[nfeatures][classes.length];
79 sortedIndices = new int[nfeatures][];
80
81
82 Parallel.forIndex(0, nfeatures, 1, new Operation<Integer>() {
83
84 @Override
85 public void perform(Integer f) {
86 final HaarFeature feature = features.get(f);
87 int count = 0;
88
89 for (final SummedSqTiltAreaTable t : positive) {
90 final float wvNorm = computeWindowVarianceNorm(t);
91 responses[f][count] = feature.computeResponse(t, 0, 0) / wvNorm;
92 classes[count] = true;
93 ++count;
94 }
95
96 for (final SummedSqTiltAreaTable t : negative) {
97 final float wvNorm = computeWindowVarianceNorm(t);
98 responses[f][count] = feature.computeResponse(t, 0, 0) / wvNorm;
99 classes[count] = false;
100 ++count;
101 }
102
103 sortedIndices[f] = ArrayUtils.indexSort(responses[f]);
104 }
105 });
106 }
107
108 @Override
109 public float[] getFeatureResponse(int dimension) {
110 return responses[dimension];
111 }
112
113 @Override
114 public boolean[] getClasses() {
115 return classes;
116 }
117
118 @Override
119 public int numInstances() {
120 return classes.length;
121 }
122
123 @Override
124 public int numDimensions() {
125 return responses.length;
126 }
127
128 @Override
129 public float[] getInstanceFeature(int idx) {
130 final float[] feature = new float[responses.length];
131
132 for (int i = 0; i < feature.length; i++) {
133 feature[i] = responses[i][idx];
134 }
135
136 return feature;
137 }
138
139 @Override
140 public int[] getSortedResponseIndices(int d) {
141 return sortedIndices[d];
142 }
143
144 public HaarFeature getFeature(int dimension) {
145 return features.get(dimension);
146 }
147 }