1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package org.openimaj.image.objectdetection.haar.training;
31
32 import java.util.List;
33
34 import org.openimaj.image.analysis.algorithm.SummedSqTiltAreaTable;
35 import org.openimaj.image.objectdetection.haar.HaarFeature;
36 import org.openimaj.ml.classification.LabelledDataProvider;
37 import org.openimaj.util.array.ArrayUtils;
38
39 public class BasicTrainingData implements LabelledDataProvider {
40 SummedSqTiltAreaTable[] sats;
41 boolean[] classes;
42 HaarFeature[] features;
43
44 public BasicTrainingData(List<SummedSqTiltAreaTable> positive, List<SummedSqTiltAreaTable> negative,
45 List<HaarFeature> features)
46 {
47 sats = new SummedSqTiltAreaTable[positive.size() + negative.size()];
48 classes = new boolean[sats.length];
49
50 int count = 0;
51 for (final SummedSqTiltAreaTable t : positive) {
52 sats[count] = t;
53 classes[count] = true;
54 ++count;
55 }
56
57 for (final SummedSqTiltAreaTable t : negative) {
58 sats[count] = t;
59 classes[count] = false;
60 ++count;
61 }
62
63 this.features = features.toArray(new HaarFeature[features.size()]);
64 }
65
66 @Override
67 public float[] getFeatureResponse(int dimension) {
68 final float[] response = new float[sats.length];
69
70 for (int i = 0; i < sats.length; i++) {
71 final float wvNorm = computeWindowVarianceNorm(sats[i]);
72
73 response[i] = features[dimension].computeResponse(sats[i], 0, 0) / wvNorm;
74 }
75
76 return response;
77 }
78
79 @Override
80 public boolean[] getClasses() {
81 return classes;
82 }
83
84 @Override
85 public int numInstances() {
86 return classes.length;
87 }
88
89 @Override
90 public int numDimensions() {
91 return features.length;
92 }
93
94 float computeWindowVarianceNorm(SummedSqTiltAreaTable sat) {
95 final int w = sat.sum.width - 1 - 2;
96 final int h = sat.sum.height - 1 - 2;
97
98 final int x = 1;
99 final int y = 1;
100
101 final float sum = sat.sum.pixels[y + h][x + w] + sat.sum.pixels[y][x] -
102 sat.sum.pixels[y + h][x] - sat.sum.pixels[y][x + w];
103 final float sqSum = sat.sqSum.pixels[y + w][x + w] + sat.sqSum.pixels[y][x] -
104 sat.sqSum.pixels[y + w][x] - sat.sqSum.pixels[y][x + w];
105
106 final float cachedInvArea = 1.0f / (w * h);
107 final float mean = sum * cachedInvArea;
108 float wvNorm = sqSum * cachedInvArea - mean * mean;
109 wvNorm = (float) ((wvNorm >= 0) ? Math.sqrt(wvNorm) : 1);
110
111 return wvNorm;
112 }
113
114 @Override
115 public float[] getInstanceFeature(int idx) {
116 final float[] feature = new float[features.length];
117 final SummedSqTiltAreaTable sat = sats[idx];
118
119 final float wvNorm = computeWindowVarianceNorm(sat);
120
121 for (int i = 0; i < features.length; i++) {
122 feature[i] = features[i].computeResponse(sat, 0, 0) / wvNorm;
123 }
124
125 return feature;
126 }
127
128 @Override
129 public int[] getSortedResponseIndices(int d) {
130 return ArrayUtils.indexSort(getFeatureResponse(d));
131 }
132
133 public HaarFeature getFeature(int dimension) {
134 return features[dimension];
135 }
136 }