001package org.openimaj.demos.sandbox.tldcpp.detector; 002 003import java.util.ArrayList; 004import java.util.List; 005 006import org.openimaj.image.FImage; 007import org.openimaj.image.analysis.algorithm.TemplateMatcher; 008import org.openimaj.math.geometry.shape.Rectangle; 009 010/** 011 * The third and most powerful, but equally most slow parts of the 012 * {@link DetectorCascade}. Holding a list of falsePositives and truePositives, 013 * a classification score can be ascribed to a new patch which can be used as a 014 * confidence that a given patch is positive. This is calculated using the 015 * correlation between the new patch and the false positive and falst negatives 016 * such that: 017 * 018 * confidence = dP / (dN + dP) 019 * 020 * and dP = max(corr(patch,truePositives)) dP = max(corr(patch,falsePositives)) 021 * 022 * if no true positives have been seen, classify will always return 0 if not 023 * false positives have been seen, classify will always return 1 024 * 025 * classify is used by filter such that if the confidence of a patch is larger 026 * than {@link #thetaTP} the patch is though to be a good patch for the object. 027 * 028 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 029 * 030 */ 031public class NNClassifier { 032 /** 033 * whether this stage is enabled 034 */ 035 public boolean enabled; 036 037 /** 038 * Used as the lower bound of a historysis threshold (i.e. if a detection is 039 * made with a confidence over {@link #thetaTP}, the next detection can be a 040 * little worse by matching this) 041 */ 042 public float thetaFP; 043 /** 044 * Used as the upper bound threshold 045 */ 046 public float thetaTP; 047 ScaleIndexRectangle[] windows; 048 DetectionResult detectionResult; 049 List<NormalizedPatch> falsePositives; 050 List<NormalizedPatch> truePositives; 051 052 /** 053 * Sets thetaFP as 0.5f and thetaTP as .65f 054 */ 055 public NNClassifier() { 056 thetaFP = .5f; 057 thetaTP = .65f; 058 059 truePositives = new ArrayList<NormalizedPatch>(); 060 falsePositives = new ArrayList<NormalizedPatch>(); 061 062 } 063 064 /** 065 * clear falst positives and true positives 066 */ 067 public void release() { 068 falsePositives.clear(); 069 truePositives.clear(); 070 } 071 072 /** 073 * 074 * @param f1 075 * @param f2 076 * @return correlation between two patches (assumed to be the same size) 077 * calculated using {@link TemplateMatcherMode} 078 */ 079 private float ncc(FImage f1, FImage f2) { 080 final float normcorr = TemplateMatcher.Mode.NORM_CORRELATION.computeMatchScore(f1.pixels, 0, 0, f2.pixels, 0, 0, 081 f1.width, f1.height); 082 return normcorr; 083 } 084 085 /** 086 * 087 * @param patch 088 * @return The confidence that a given patch is a postive 089 */ 090 public float classifyPatch(NormalizedPatch patch) { 091 092 if (truePositives.isEmpty()) { 093 return 0; 094 } 095 096 if (falsePositives.isEmpty()) { 097 return 1; 098 } 099 100 float ccorr_max_p = 0; 101 // Compare patch to positive patches 102 for (int i = 0; i < truePositives.size(); i++) { 103 final float ccorr = ncc(truePositives.get(i).normalisedPatch, patch.normalisedPatch); 104 if (ccorr > ccorr_max_p) { 105 ccorr_max_p = ccorr; 106 } 107 } 108 109 float ccorr_max_n = 0; 110 // Compare patch to positive patches 111 for (int i = 0; i < falsePositives.size(); i++) { 112 final float ccorr = ncc(falsePositives.get(i).normalisedPatch, patch.normalisedPatch); 113 if (ccorr > ccorr_max_n) { 114 ccorr_max_n = ccorr; 115 } 116 } 117 118 final float dN = 1 - ccorr_max_n; 119 final float dP = 1 - ccorr_max_p; 120 121 final float distance = dN / (dN + dP); 122 return distance; 123 } 124 125 /** 126 * @param img 127 * @param bb 128 * @return confidence of the bb in image 129 */ 130 public float classifyBB(FImage img, Rectangle bb) { 131 final NormalizedPatch patch = new NormalizedPatch(); 132 patch.source = img; 133 patch.window = bb; 134 patch.prepareNormalisedPatch(); 135 return classifyPatch(patch); 136 137 } 138 139 float classifyWindow(FImage img, int windowIdx) { 140 141 final ScaleIndexRectangle bbox = windows[windowIdx]; 142 final NormalizedPatch patch = new NormalizedPatch(); 143 patch.window = bbox; 144 patch.source = img; 145 // here we reuse the scales images as the patch of the right 146 // width/height and just write into it. 147 patch.normalisedPatch = patch.zoomAndNormaliseTo(NormalizedPatch.SLUT_WORKSPACE); 148 149 return classifyPatch(patch); 150 } 151 152 /** 153 * @param img 154 * @param windowIdx 155 * @return Filter a window by getting its confidence and returning true if 156 * confidence > thetaTP 157 */ 158 public boolean filter(FImage img, int windowIdx) { 159 if (!enabled) 160 return true; 161 162 final float conf = classifyWindow(img, windowIdx); 163 164 if (conf < thetaTP) { 165 return false; 166 } 167 168 return true; 169 } 170 171 /** 172 * Given a list of patches, classify each patch. If the patch is said to be 173 * positive and has a confidence lower than {@link #thetaTP} add the patch 174 * to the true positives If the patch is said to be negative and has a 175 * confidence higher than {@link #thetaFP} add the patch to the false 176 * positives 177 * 178 * @param patches 179 */ 180 public void learn(List<NormalizedPatch> patches) { 181 // TODO: Randomization might be a good idea here 182 for (int i = 0; i < patches.size(); i++) { 183 184 final NormalizedPatch patch = patches.get(i); 185 // if the patch is a negative one, the image has not been normalised 186 // etc yet! 187 // it uses the prepared windows, so a held scale patch can be used 188 if (!patch.positive) { 189 patch.normalisedPatch = patch.zoomAndNormaliseTo(NormalizedPatch.SLUT_WORKSPACE); 190 } 191 192 final float conf = classifyPatch(patch); 193 194 if (patch.positive && conf <= thetaTP) { 195 truePositives.add(patch); 196 } 197 198 if (!patch.positive && conf >= thetaFP) { 199 // We must handle the SLUT_WORKSPACE! 200 // If we're negative we are using the slut, if we're planning to 201 // keep this negative we must NOW clone the slut 202 patch.normalisedPatch = patch.normalisedPatch.clone(); 203 falsePositives.add(patch); 204 } 205 } 206 207 } 208 209 /** 210 * @return the positively classified patches 211 */ 212 public List<NormalizedPatch> getPositivePatches() { 213 return this.truePositives; 214 } 215 216}