package org.openimaj.demos.sandbox.tldcpp;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Random;

import org.openimaj.demos.sandbox.tldcpp.detector.Clustering;
import org.openimaj.demos.sandbox.tldcpp.detector.DetectionResult;
import org.openimaj.demos.sandbox.tldcpp.detector.DetectorCascade;
import org.openimaj.demos.sandbox.tldcpp.detector.EnsembleClassifier;
import org.openimaj.demos.sandbox.tldcpp.detector.NNClassifier;
import org.openimaj.demos.sandbox.tldcpp.detector.NormalizedPatch;
import org.openimaj.demos.sandbox.tldcpp.detector.VarianceFilter;
import org.openimaj.demos.sandbox.tldcpp.tracker.MedianFlowTracker;
import org.openimaj.demos.sandbox.tldcpp.videotld.TLDUtil;
import org.openimaj.image.FImage;
import org.openimaj.math.geometry.shape.Rectangle;
import org.openimaj.util.pair.IndependentPair;

/**
 * An implementation of the TLD tracker by Zdenek Kalal:
 * http://info.ee.surrey.ac.uk/Personal/Z.Kalal/tld.html based on the C++
 * implementation by Georg Nebehay: http://gnebehay.github.com/OpenTLD/
 *
 * This class is the main controller class. TLD is instantiated on an image and
 * bounding box. Once the detector classifiers are initialised the
 * {@link TLD#processImage(FImage)} function must be called with successive
 * frames in which objects are: - Tracked using {@link MedianFlowTracker} -
 * ...and if not tracked correctly detected using the {@link DetectorCascade}. -
 * ... if tracked or detected correctly, but the object is different enough, it
 * is learnt using {@link DetectorCascade}!
 *
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 *
 */
public class TLD {
	/**
	 * Whether the {@link MedianFlowTracker} is enabled
	 */
	public boolean trackerEnabled;
	/**
	 * Whether the {@link DetectorCascade} is enabled
	 */
	public boolean detectorEnabled;
	/**
	 * Whether previously unseen frames are learnt from
	 */
	public boolean learningEnabled;
	/**
	 * Whether some frames are skipped
	 */
	public boolean alternating;
	/**
	 * The current bounding box
	 */
	public Rectangle currBB;
	/**
	 * the previous bounding box
	 */
	public Rectangle prevBB;

	/**
	 * The detector
	 */
	public DetectorCascade detectorCascade;

	/**
	 * The tracker
	 */
	public MedianFlowTracker medianFlowTracker;
	/**
	 * The previous frame, where #prevBB was detected
	 */
	public FImage prevImg;
	/**
	 * The current frame, where #currBB will be detected
	 */
	public FImage currImg;
	/**
	 * The confidence of the current bounding box. Calculated from the detector
	 * or the tracker. The {@link MedianFlowTracker#trackerBB} is extracted from
	 * the {@link #currImg} and confidence is gauged using {@link NNClassifier}
	 */
	public float currConf;

	/**
	 * The nearest neighbour classifier
	 */
	private NNClassifier nnClassifier;
	private boolean learning;
	private boolean valid;
	private boolean wasValid;
	private int imgWidth;
	private int imgHeight;

	/**
	 * Orders (window-index, overlap) pairs by DESCENDING overlap so that the
	 * best-overlapping windows come first. This mirrors tldSortByOverlapDesc in
	 * the original C++ implementation; an ascending sort here would cause the
	 * learning steps below to pick the *least* overlapping positive windows.
	 */
	private static final Comparator<IndependentPair<Integer, Float>> OVERLAP_DESCENDING = new Comparator<IndependentPair<Integer, Float>>() {
		@Override
		public int compare(IndependentPair<Integer, Float> o1, IndependentPair<Integer, Float> o2)
		{
			return o2.secondObject().compareTo(o1.secondObject());
		}
	};

	/**
	 * Initialises the TLD with a series of defaults. {@link #trackerEnabled} is
	 * true, {@link #detectorEnabled} is true, {@link #learningEnabled} is true,
	 * {@link #alternating} is false.
	 */
	private TLD() {
		trackerEnabled = true;
		detectorEnabled = true;
		learningEnabled = true;
		alternating = false;
		valid = false;
		wasValid = false;
		learning = false;
		currBB = null;

		detectorCascade = new DetectorCascade();
		nnClassifier = detectorCascade.getNNClassifier();

		medianFlowTracker = new MedianFlowTracker();
	}

	/**
	 * Construct a TLD for frames of the given size.
	 *
	 * @param width
	 * @param height
	 */
	public TLD(int width, int height) {
		this();
		this.imgWidth = width;
		this.imgHeight = height;
	}

	/**
	 * Stop tracking whatever is currently being tracked
	 */
	public void release() {
		detectorCascade.release();
		medianFlowTracker.cleanPreviousData();
		currBB = null;
	}

	/**
	 * Roll the current frame/bounding-box into the "previous" slots and reset
	 * the per-frame state of the detector and tracker.
	 */
	private void storeCurrentData() {
		prevImg = currImg; // Store old image (if any)
		prevBB = currBB; // Store old bounding box (if any)

		detectorCascade.cleanPreviousData(); // Reset detector results
		medianFlowTracker.cleanPreviousData();

		wasValid = valid;
	}

	/**
	 * Set the current object being tracked. Initialise the detector cascade
	 * using {@link DetectorCascade#init()}. The {@link #initialLearning()} is
	 * called
	 *
	 * @param img
	 * @param bb
	 * @throws Exception
	 */
	public void selectObject(FImage img, Rectangle bb) throws Exception {
		// Delete old object
		detectorCascade.release();

		detectorCascade.setObjWidth((int) bb.width);
		detectorCascade.setObjHeight((int) bb.height);
		detectorCascade.setImgWidth(this.imgWidth);
		detectorCascade.setImgHeight(this.imgHeight);

		// Init detector cascade
		detectorCascade.init();

		currImg = img;
		currBB = bb;
		currConf = 1;
		valid = true;

		initialLearning();
	}

	/**
	 * An attempt is made to track the object from the previous frame. The
	 * {@link DetectorCascade} instance is used regardless to detect the object.
	 * The {@link #fuseHypotheses()} is then used to combine the estimate of the
	 * tracker and detector together. Finally, using the detectorCascade the
	 * learn function is called and the classifier is improved
	 *
	 * @param img
	 */
	public void processImage(FImage img) {
		storeCurrentData();
		// Store new image, right after storeCurrentData();
		final FImage grey_frame = img.clone();
		currImg = grey_frame;

		if (trackerEnabled) {
			medianFlowTracker.track(prevImg, currImg, prevBB);
		}

		// In alternating mode the (expensive) detector only runs when the
		// tracker has lost the object.
		if (detectorEnabled && (!alternating || medianFlowTracker.trackerBB == null)) {
			detectorCascade.detect(grey_frame);
		}

		fuseHypotheses();

		learn();
	}

	/**
	 * The bounding box is retrieved from the tracker and detector as well as
	 * the number of clusters detected by the detector {@link Clustering} step.
	 * If exactly one cluster exists in {@link Clustering} (i.e. the detector is
	 * very sure) the detector confidence is included. If the tracker was able
	 * to keep track of the bounding box (i.e. trackerBB is not null) then the
	 * tracker confidence is combined.
	 *
	 * if the detector is more confident than the tracker and their overlap is
	 * very small, the detectors BB is used. Otherwise the trackers BB and
	 * confidence is used. If the trackerBB is used the tracking is valid only
	 * if the tracking was invalid last time and the confidence is above
	 * {@link NNClassifier#thetaTP} or if the tracking was valid last time a
	 * smaller threshold of {@link NNClassifier#thetaFP} is used.
	 *
	 * TODO: Maybe a better combination of the two bounding boxes from the
	 * detector and tracker would be better?
	 */
	public void fuseHypotheses() {
		final Rectangle trackerBB = medianFlowTracker.trackerBB;
		final int numClusters = detectorCascade.getDetectionResult().numClusters;
		final Rectangle detectorBB = detectorCascade.getDetectionResult().detectorBB;

		currBB = null;
		currConf = 0;
		valid = false;

		float confDetector = 0;

		if (numClusters == 1) {
			confDetector = nnClassifier.classifyBB(currImg, detectorBB);
		}

		if (trackerBB != null) {
			final float confTracker = nnClassifier.classifyBB(currImg, trackerBB);

			if (numClusters == 1 && confDetector > confTracker
					&& TLDUtil.tldOverlapNorm(trackerBB, detectorBB) < 0.5)
			{
				// Detector is surer and points somewhere else: re-initialise
				// from the detection
				currBB = detectorBB.clone();
				currConf = confDetector;
			} else {
				currBB = trackerBB.clone();
				currConf = confTracker;
				if (confTracker > nnClassifier.thetaTP) {
					valid = true;
				} else if (wasValid && confTracker > nnClassifier.thetaFP) {
					valid = true;
				}
			}
		} else if (numClusters == 1) {
			currBB = detectorBB.clone();
			currConf = confDetector;
		}

		/*
		 * float var = CalculateVariance(patch.values,
		 * nn.patch_size*nn.patch_size);
		 *
		 * if(var < min_var) { //TODO: Think about incorporating this
		 * printf("%f, %f: Variance too low \n", var, classifier.min_var); valid
		 * = 0; }
		 */
	}

	/**
	 * The initial learning is done using the input bounding box.
	 *
	 * Firstly, the {@link VarianceFilter} is told its
	 * {@link VarianceFilter#minVar} by finding the variance of the selected
	 * patch.
	 *
	 * Next all patches in {@link DetectorCascade} with a large overlap (over
	 * 0.6f) with the selected box are used as positive examples while all
	 * windows with an overlap of less than 0.2f and a variance greater than the
	 * minimum variance (i.e. they pass the variance check but yet do not
	 * overlap) are used as negative examples. The {@link EnsembleClassifier} is
	 * trained on the positive examples for {@link EnsembleClassifier}.
	 *
	 * Finally, the negative and positive examples are all fed to the
	 * {@link NNClassifier} using the {@link NNClassifier}.
	 *
	 * The usage of these 3 classifiers is explained in more detail in
	 * {@link DetectorCascade#detect(FImage)}. The {@link NNClassifier} is also
	 * used to calculate confidences in {@link #fuseHypotheses()}
	 */
	public void initialLearning() {
		final int numWindows = detectorCascade.getNumWindows();
		learning = true; // This is just for display purposes

		final DetectionResult detectionResult = detectorCascade.getDetectionResult();

		detectorCascade.detect(currImg);

		// This is the positive patch
		final NormalizedPatch patch = new NormalizedPatch();
		patch.source = currImg;
		patch.window = currBB;
		patch.positive = true;

		final float initVar = patch.calculateVariance();
		detectorCascade.getVarianceFilter().minVar = initVar / 2;

		final float[] overlap = new float[numWindows];
		detectorCascade.windowOverlap(currBB, overlap);

		// Add all bounding boxes with high overlap

		final List<IndependentPair<Integer, Float>> positiveIndices = new ArrayList<IndependentPair<Integer, Float>>();
		final List<Integer> negativeIndices = new ArrayList<Integer>();

		// First: Find overlapping positive and negative patches
		for (int i = 0; i < numWindows; i++) {
			if (overlap[i] > 0.6) {
				positiveIndices.add(IndependentPair.pair(i, overlap[i]));
			}

			if (overlap[i] < 0.2) {
				final float variance = detectionResult.variances[i];

				// TODO: This check is unnecessary if minVar would be set
				// before calling detect.
				if (!detectorCascade.getVarianceFilter().enabled
						|| variance > detectorCascade.getVarianceFilter().minVar)
				{
					negativeIndices.add(i);
				}
			}
		}
		System.out.println("Number of positive features on init: " + positiveIndices.size());

		// Sort descending by overlap so the BEST-overlapping windows are
		// learnt below. (The previous ascending sort selected the worst of
		// the positive candidates.)
		Collections.sort(positiveIndices, OVERLAP_DESCENDING);

		final List<NormalizedPatch> patches = new ArrayList<NormalizedPatch>();

		patches.add(patch); // Add first patch to patch list

		// Take at most 10 bounding boxes (sorted by overlap)
		final int numIterations = Math.min(positiveIndices.size(), 10);
		for (int i = 0; i < numIterations; i++) {
			final int idx = positiveIndices.get(i).firstObject();
			// Learn this bounding box
			// TODO: Somewhere here image warping might be possible
			detectorCascade.getEnsembleClassifier().learn(
					currImg, true,
					detectionResult.featureVectors, detectorCascade.numTrees * idx
					);
		}

		// NOTE: seeded Random, so the "random" negative selection is
		// deterministic across runs. Maybe this doesn't matter.
		final Random r = new Random(1);

		// random_shuffle(negativeIndices.begin(), negativeIndices.end());
		Collections.shuffle(negativeIndices, r);

		// Choose 100 random patches for negative examples
		for (int i = 0; i < Math.min(100, negativeIndices.size()); i++) {
			final int idx = negativeIndices.get(i);

			final NormalizedPatch negPatch = new NormalizedPatch();
			negPatch.source = currImg;
			negPatch.window = detectorCascade.getWindow(idx);
			// This creates and sets the public valueImg which holds the
			// normalised zoomed window
			negPatch.prepareNormalisedPatch();
			negPatch.positive = false;
			patches.add(negPatch);
		}

		detectorCascade.getNNClassifier().learn(patches);
	}

	/**
	 * If the detection results are good and {@link #fuseHypotheses()} believes
	 * that the area was tracked to, but was not detected well then there is
	 * potential that the classifiers should be updated with the bounding box.
	 *
	 * The bounding box is used to extract highly overlapping windows as
	 * positive examples, and two kinds of negative examples are collected if
	 * they overlap less than 0.2f. For the ensemble classifier, negative
	 * examples are collected if the results of the {@link DetectorCascade}
	 */
	public void learn() {
		if (!learningEnabled || !valid || !detectorEnabled) {
			learning = false;
			return;
		}
		final int numWindows = detectorCascade.getNumWindows();
		learning = true;

		final DetectionResult detectionResult = detectorCascade.getDetectionResult();

		if (!detectionResult.containsValidData) {
			detectorCascade.detect(currImg);
		}

		// This is the positive patch
		final NormalizedPatch patch = new NormalizedPatch();
		patch.source = currImg;
		patch.window = currBB;
		patch.prepareNormalisedPatch();

		final float[] overlap = new float[numWindows];
		this.detectorCascade.windowOverlap(currBB, overlap);

		// Add all bounding boxes with high overlap
		final List<IndependentPair<Integer, Float>> positiveIndices = new ArrayList<IndependentPair<Integer, Float>>();
		final List<Integer> negativeIndices = new ArrayList<Integer>();
		final List<Integer> negativeIndicesForNN = new ArrayList<Integer>();

		// First: Find overlapping positive and negative patches
		for (int i = 0; i < numWindows; i++) {
			if (overlap[i] > 0.6) {
				positiveIndices.add(IndependentPair.pair(i, overlap[i]));
			}

			if (overlap[i] < 0.2) {
				// TODO: Shouldn't this read as 0.5?
				if (!detectorCascade.getEnsembleClassifier().enabled || detectionResult.posteriors[i] > 0.1) {
					negativeIndices.add(i);
				}

				if (!detectorCascade.getEnsembleClassifier().enabled || detectionResult.posteriors[i] > 0.5) {
					negativeIndicesForNN.add(i);
				}
			}
		}

		// Sort descending by overlap (see OVERLAP_DESCENDING) so the
		// best-overlapping positives are learnt below.
		Collections.sort(positiveIndices, OVERLAP_DESCENDING);

		final List<NormalizedPatch> patches = new ArrayList<NormalizedPatch>();

		patch.positive = true;
		patches.add(patch);
		// TODO: Flip

		// Take at most 10 bounding boxes (sorted by overlap)
		final int numIterations = Math.min(positiveIndices.size(), 10);

		for (int i = 0; i < negativeIndices.size(); i++) {
			final int idx = negativeIndices.get(i);
			// TODO: Somewhere here image warping might be possible
			detectorCascade.getEnsembleClassifier().learn(currImg, false, detectionResult.featureVectors,
					detectorCascade.numTrees * idx);
		}

		// TODO: Randomization might be a good idea
		for (int i = 0; i < numIterations; i++) {
			final int idx = positiveIndices.get(i).firstObject();
			// TODO: Somewhere here image warping might be possible
			detectorCascade.getEnsembleClassifier().learn(currImg, true, detectionResult.featureVectors,
					detectorCascade.numTrees * idx);
		}

		for (int i = 0; i < negativeIndicesForNN.size(); i++) {
			final int idx = negativeIndicesForNN.get(i);

			final NormalizedPatch negPatch = new NormalizedPatch();
			negPatch.source = currImg;
			negPatch.window = detectorCascade.getWindow(idx);
			negPatch.positive = false;
			patches.add(negPatch);
		}

		detectorCascade.getNNClassifier().learn(patches);
	}

	/**
	 * @return whether the tracker is learning from the previous frame
	 */
	public boolean isLearning() {
		return this.learning;
	}
}