001package org.openimaj.demos.sandbox.tldcpp.detector;
002
003import java.util.ArrayList;
004import java.util.List;
005
006import org.openimaj.image.FImage;
007import org.openimaj.image.analysis.algorithm.TemplateMatcher;
008import org.openimaj.math.geometry.shape.Rectangle;
009
010/**
011 * The third and most powerful, but equally most slow parts of the
012 * {@link DetectorCascade}. Holding a list of falsePositives and truePositives,
013 * a classification score can be ascribed to a new patch which can be used as a
014 * confidence that a given patch is positive. This is calculated using the
015 * correlation between the new patch and the false positive and falst negatives
016 * such that:
017 * 
018 * confidence = dP / (dN + dP)
019 * 
020 * and dP = max(corr(patch,truePositives)) dP = max(corr(patch,falsePositives))
021 * 
022 * if no true positives have been seen, classify will always return 0 if not
023 * false positives have been seen, classify will always return 1
024 * 
025 * classify is used by filter such that if the confidence of a patch is larger
026 * than {@link #thetaTP} the patch is though to be a good patch for the object.
027 * 
028 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
029 * 
030 */
031public class NNClassifier {
032        /**
033         * whether this stage is enabled
034         */
035        public boolean enabled;
036
037        /**
038         * Used as the lower bound of a historysis threshold (i.e. if a detection is
039         * made with a confidence over {@link #thetaTP}, the next detection can be a
040         * little worse by matching this)
041         */
042        public float thetaFP;
043        /**
044         * Used as the upper bound threshold
045         */
046        public float thetaTP;
047        ScaleIndexRectangle[] windows;
048        DetectionResult detectionResult;
049        List<NormalizedPatch> falsePositives;
050        List<NormalizedPatch> truePositives;
051
052        /**
053         * Sets thetaFP as 0.5f and thetaTP as .65f
054         */
055        public NNClassifier() {
056                thetaFP = .5f;
057                thetaTP = .65f;
058
059                truePositives = new ArrayList<NormalizedPatch>();
060                falsePositives = new ArrayList<NormalizedPatch>();
061
062        }
063
064        /**
065         * clear falst positives and true positives
066         */
067        public void release() {
068                falsePositives.clear();
069                truePositives.clear();
070        }
071
072        /**
073         * 
074         * @param f1
075         * @param f2
076         * @return correlation between two patches (assumed to be the same size)
077         *         calculated using {@link TemplateMatcherMode}
078         */
079        private float ncc(FImage f1, FImage f2) {
080                final float normcorr = TemplateMatcher.Mode.NORM_CORRELATION.computeMatchScore(f1.pixels, 0, 0, f2.pixels, 0, 0,
081                                f1.width, f1.height);
082                return normcorr;
083        }
084
085        /**
086         * 
087         * @param patch
088         * @return The confidence that a given patch is a postive
089         */
090        public float classifyPatch(NormalizedPatch patch) {
091
092                if (truePositives.isEmpty()) {
093                        return 0;
094                }
095
096                if (falsePositives.isEmpty()) {
097                        return 1;
098                }
099
100                float ccorr_max_p = 0;
101                // Compare patch to positive patches
102                for (int i = 0; i < truePositives.size(); i++) {
103                        final float ccorr = ncc(truePositives.get(i).normalisedPatch, patch.normalisedPatch);
104                        if (ccorr > ccorr_max_p) {
105                                ccorr_max_p = ccorr;
106                        }
107                }
108
109                float ccorr_max_n = 0;
110                // Compare patch to positive patches
111                for (int i = 0; i < falsePositives.size(); i++) {
112                        final float ccorr = ncc(falsePositives.get(i).normalisedPatch, patch.normalisedPatch);
113                        if (ccorr > ccorr_max_n) {
114                                ccorr_max_n = ccorr;
115                        }
116                }
117
118                final float dN = 1 - ccorr_max_n;
119                final float dP = 1 - ccorr_max_p;
120
121                final float distance = dN / (dN + dP);
122                return distance;
123        }
124
125        /**
126         * @param img
127         * @param bb
128         * @return confidence of the bb in image
129         */
130        public float classifyBB(FImage img, Rectangle bb) {
131                final NormalizedPatch patch = new NormalizedPatch();
132                patch.source = img;
133                patch.window = bb;
134                patch.prepareNormalisedPatch();
135                return classifyPatch(patch);
136
137        }
138
139        float classifyWindow(FImage img, int windowIdx) {
140
141                final ScaleIndexRectangle bbox = windows[windowIdx];
142                final NormalizedPatch patch = new NormalizedPatch();
143                patch.window = bbox;
144                patch.source = img;
145                // here we reuse the scales images as the patch of the right
146                // width/height and just write into it.
147                patch.normalisedPatch = patch.zoomAndNormaliseTo(NormalizedPatch.SLUT_WORKSPACE);
148
149                return classifyPatch(patch);
150        }
151
152        /**
153         * @param img
154         * @param windowIdx
155         * @return Filter a window by getting its confidence and returning true if
156         *         confidence > thetaTP
157         */
158        public boolean filter(FImage img, int windowIdx) {
159                if (!enabled)
160                        return true;
161
162                final float conf = classifyWindow(img, windowIdx);
163
164                if (conf < thetaTP) {
165                        return false;
166                }
167
168                return true;
169        }
170
171        /**
172         * Given a list of patches, classify each patch. If the patch is said to be
173         * positive and has a confidence lower than {@link #thetaTP} add the patch
174         * to the true positives If the patch is said to be negative and has a
175         * confidence higher than {@link #thetaFP} add the patch to the false
176         * positives
177         * 
178         * @param patches
179         */
180        public void learn(List<NormalizedPatch> patches) {
181                // TODO: Randomization might be a good idea here
182                for (int i = 0; i < patches.size(); i++) {
183
184                        final NormalizedPatch patch = patches.get(i);
185                        // if the patch is a negative one, the image has not been normalised
186                        // etc yet!
187                        // it uses the prepared windows, so a held scale patch can be used
188                        if (!patch.positive) {
189                                patch.normalisedPatch = patch.zoomAndNormaliseTo(NormalizedPatch.SLUT_WORKSPACE);
190                        }
191
192                        final float conf = classifyPatch(patch);
193
194                        if (patch.positive && conf <= thetaTP) {
195                                truePositives.add(patch);
196                        }
197
198                        if (!patch.positive && conf >= thetaFP) {
199                                // We must handle the SLUT_WORKSPACE!
200                                // If we're negative we are using the slut, if we're planning to
201                                // keep this negative we must NOW clone the slut
202                                patch.normalisedPatch = patch.normalisedPatch.clone();
203                                falsePositives.add(patch);
204                        }
205                }
206
207        }
208
209        /**
210         * @return the positively classified patches
211         */
212        public List<NormalizedPatch> getPositivePatches() {
213                return this.truePositives;
214        }
215
216}