001package org.openimaj.demos.sandbox.tldcpp;
002
003import java.util.ArrayList;
004import java.util.Collections;
005import java.util.Comparator;
006import java.util.List;
007import java.util.Random;
008
009import org.openimaj.demos.sandbox.tldcpp.detector.Clustering;
010import org.openimaj.demos.sandbox.tldcpp.detector.DetectionResult;
011import org.openimaj.demos.sandbox.tldcpp.detector.DetectorCascade;
012import org.openimaj.demos.sandbox.tldcpp.detector.EnsembleClassifier;
013import org.openimaj.demos.sandbox.tldcpp.detector.NNClassifier;
014import org.openimaj.demos.sandbox.tldcpp.detector.NormalizedPatch;
015import org.openimaj.demos.sandbox.tldcpp.detector.VarianceFilter;
016import org.openimaj.demos.sandbox.tldcpp.tracker.MedianFlowTracker;
017import org.openimaj.demos.sandbox.tldcpp.videotld.TLDUtil;
018import org.openimaj.image.FImage;
019import org.openimaj.math.geometry.shape.Rectangle;
020import org.openimaj.util.pair.IndependentPair;
021
/**
 * An implementation of the TLD (Tracking-Learning-Detection) tracker by Zdenek
 * Kalal: http://info.ee.surrey.ac.uk/Personal/Z.Kalal/tld.html based on the C++
 * implementation by Georg Nebehay: http://gnebehay.github.com/OpenTLD/
 *
 * This class is the main controller class. TLD is instantiated with the frame
 * size, and an object is selected with an image and bounding box. Once the
 * detector classifiers are initialised, {@link TLD#processImage(FImage)} must
 * be called with successive frames, in which the object is:
 * <ul>
 * <li>tracked using the {@link MedianFlowTracker};</li>
 * <li>detected using the {@link DetectorCascade} (on every frame, or only when
 * the tracker has lost the object if {@link #alternating} is set);</li>
 * <li>if the fused result is valid, learnt by updating the classifiers of the
 * {@link DetectorCascade}.</li>
 * </ul>
 *
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 *
 */
038public class TLD {
039        /**
040         * Whether the {@link MedianFlowTracker} is enabled
041         */
042        public boolean trackerEnabled;
043        /**
044         * Whether the {@link DetectorCascade} is enabled
045         */
046        public boolean detectorEnabled;
047        /**
048         * Whether previously unseen frames are learnt from
049         */
050        public boolean learningEnabled;
051        /**
052         * Whether some frames are skipped
053         */
054        public boolean alternating;
055        /**
056         * The current bounding box
057         */
058        public Rectangle currBB;
059        /**
060         * the previous bounding box
061         */
062        public Rectangle prevBB;
063
064        /**
065         * The detector
066         */
067        public DetectorCascade detectorCascade;
068
069        /**
070         * The tracker
071         */
072        public MedianFlowTracker medianFlowTracker;
073        /**
074         * The previous frame, where #prevBB was detected
075         */
076        public FImage prevImg;
077        /**
078         * The current frame, where #currBB will be detected
079         */
080        public FImage currImg;
081        /**
082         * The confidence of the current bounding box. Calculated from the detector
083         * or the tracker. The {@link MedianFlowTracker#trackerBB} is extracted from
084         * the {@link #currImg} and confidence is gauged using {@link NNClassifier}
085         */
086        public float currConf;
087
088        /**
089         * The nearest neighbour classifier
090         */
091        private NNClassifier nnClassifier;
092        private boolean learning;
093        private boolean valid;
094        private boolean wasValid;
095        private int imgWidth;
096        private int imgHeight;
097
098        /**
099         * Initialises the TLD with a series of defaults. {@link #trackerEnabled} is
100         * true {@link #detectorEnabled} is true {@link #learningEnabled} is true
101         * {@link #alternating} is false {@link #alternating} is false valid
102         */
103        private TLD() {
104                trackerEnabled = true;
105                detectorEnabled = true;
106                learningEnabled = true;
107                alternating = false;
108                valid = false;
109                wasValid = false;
110                learning = false;
111                currBB = null;
112
113                detectorCascade = new DetectorCascade();
114                nnClassifier = detectorCascade.getNNClassifier();
115
116                medianFlowTracker = new MedianFlowTracker();
117        }
118
119        /**
120         * @param width
121         * @param height
122         */
123        public TLD(int width, int height) {
124                this();
125                this.imgWidth = width;
126                this.imgHeight = height;
127        }
128
129        /**
130         * Stop tracking whatever is currently being tracked
131         */
132        public void release() {
133                detectorCascade.release();
134                medianFlowTracker.cleanPreviousData();
135                currBB = null;
136        }
137
138        //
139        private void storeCurrentData() {
140                prevImg = null;
141                prevImg = currImg; // Store old image (if any)
142                prevBB = currBB; // Store old bounding box (if any)
143
144                detectorCascade.cleanPreviousData(); // Reset detector results
145                medianFlowTracker.cleanPreviousData();
146
147                wasValid = valid;
148        }
149
150        /**
151         * Set the current object being tracked. Initialilise the detector casecade
152         * using {@link DetectorCascade#init()}. The {@link #initialLearning()} is
153         * called
154         *
155         * @param img
156         * @param bb
157         * @throws Exception
158         */
159        public void selectObject(FImage img, Rectangle bb) throws Exception {
160                // Delete old object
161                detectorCascade.release();
162
163                detectorCascade.setObjWidth((int) bb.width);
164                detectorCascade.setObjHeight((int) bb.height);
165                detectorCascade.setImgWidth(this.imgWidth);
166                detectorCascade.setImgHeight(this.imgHeight);
167
168                // Init detector cascade
169                detectorCascade.init();
170
171                currImg = img;
172                currBB = bb;
173                currConf = 1;
174                valid = true;
175
176                initialLearning();
177
178        }
179
180        /**
181         * An attempt is made to track the object from the previous frame. The
182         * {@link DetectorCascade} instance is used regardless to detect the object.
183         * The {@link #fuseHypotheses()} is then used to combine the estimate of the
184         * tracker and detector together. Finally, using the detectorCascade the
185         * learn function is called and the classifier is improved
186         * 
187         * @param img
188         */
189        public void processImage(FImage img) {
190                storeCurrentData();
191                final FImage grey_frame = img.clone(); // Store new image , right after
192                                                                                                // storeCurrentData();
193                currImg = grey_frame;
194                //
195                if (trackerEnabled) {
196                        medianFlowTracker.track(prevImg, currImg, prevBB);
197                }
198                //
199                if (detectorEnabled && (!alternating || medianFlowTracker.trackerBB == null)) {
200                        detectorCascade.detect(grey_frame);
201                }
202                //
203                fuseHypotheses();
204                //
205                learn();
206
207                // if(!valid){
208                // currBB = null;
209                // currConf = 0;
210                // }
211                //
212        }
213
214        /**
215         * The bounding box is retrieved from the tracker and detector as well as
216         * the number of clusters detected by the detector {@link Clustering} step.
217         * If exactly one cluster exists in {@link Clustering} (i.e. the detector is
218         * very sure) the detector confidence is included. If the tracker was able
219         * to keep track of the bounding box (i.e. trackerBB is not null) then the
220         * tracker confidence is combined.
221         *
222         * if the detector is more confident than the tracker and their overlap is
223         * very small, the detectors BB is used. Otherwise the trackers BB and
224         * confidence is used. If the trackerBB is used the tracking is valid only
225         * if the tracking was invalid last time and the confidence is above
226         * {@link NNClassifier#thetaTP} or if the tracking was valid last time a
227         * smaller threshold of {@link NNClassifier#thetaFP} is used.
228         *
229         * TODO: Maybe a better combination of the two bounding boxes from the
230         * detector and tracker would be better?
231         */
232        public void fuseHypotheses() {
233                final Rectangle trackerBB = medianFlowTracker.trackerBB;
234                final int numClusters = detectorCascade.getDetectionResult().numClusters;
235                final Rectangle detectorBB = detectorCascade.getDetectionResult().detectorBB;
236
237                currBB = null;
238                currConf = 0;
239                valid = false;
240
241                float confDetector = 0;
242
243                if (numClusters == 1) {
244                        confDetector = nnClassifier.classifyBB(currImg, detectorBB);
245                }
246
247                if (trackerBB != null) {
248                        final float confTracker = nnClassifier.classifyBB(currImg, trackerBB);
249
250                        if (numClusters == 1 && confDetector > confTracker
251                                        && TLDUtil.tldOverlapNorm(trackerBB, detectorBB) < 0.5)
252                        {
253
254                                currBB = detectorBB.clone();
255                                currConf = confDetector;
256                        } else {
257                                currBB = trackerBB.clone();
258                                currConf = confTracker;
259                                if (confTracker > nnClassifier.thetaTP) {
260                                        valid = true;
261                                } else if (wasValid && confTracker > nnClassifier.thetaFP) {
262                                        valid = true;
263                                }
264                        }
265                } else if (numClusters == 1) {
266                        currBB = detectorBB.clone();
267                        currConf = confDetector;
268                }
269
270                /*
271                 * float var = CalculateVariance(patch.values,
272                 * nn.patch_size*nn.patch_size);
273                 *
274                 * if(var < min_var) { //TODO: Think about incorporating this
275                 * printf("%f, %f: Variance too low \n", var, classifier.min_var); valid
276                 * = 0; }
277                 */
278        }
279
280        /**
281         * The initial learning is done using the input bounding box.
282         *
283         * Firstly, the {@link VarianceFilter} is told its
284         * {@link VarianceFilter#minVar} by finding the variance of the selected
285         * patch.
286         *
287         * Next all patches in {@link DetectorCascade} with a large offset (over
288         * 0.6f) with the selected box are used as positive examples while all
289         * windows with an overlap of less tha 0.2f and a variance greater than the
290         * minimum variance (i.e. they pass the variance check but yet do not
291         * overlap) are used as negative examples. The {@link EnsembleClassifier} is
292         * trained on the positive examples for {@link EnsembleClassifier}.
293         *
294         * Finally, the negative and positive examples are all fed to the
295         * {@link NNClassifier} using the {@link NNClassifier}.
296         *
297         * The usage of these 3 classifiers is explained in more detail in
298         * {@link DetectorCascade#detect(FImage)}. The {@link NNClassifier} is also
299         * used to calculate confidences in {@link #fuseHypotheses()}
300         */
301        public void initialLearning() {
302                final int numWindows = detectorCascade.getNumWindows();
303                learning = true; // This is just for display purposes
304
305                final DetectionResult detectionResult = detectorCascade.getDetectionResult();
306
307                detectorCascade.detect(currImg);
308
309                // This is the positive patch
310                final NormalizedPatch patch = new NormalizedPatch();
311                patch.source = currImg;
312                patch.window = currBB;
313                patch.positive = true;
314
315                final float initVar = patch.calculateVariance();
316                detectorCascade.getVarianceFilter().minVar = initVar / 2;
317
318                final float[] overlap = new float[numWindows];
319                detectorCascade.windowOverlap(currBB, overlap);
320
321                // Add all bounding boxes with high overlap
322
323                final List<IndependentPair<Integer, Float>> positiveIndices = new ArrayList<IndependentPair<Integer, Float>>();
324                final List<Integer> negativeIndices = new ArrayList<Integer>();
325
326                // First: Find overlapping positive and negative patches
327                for (int i = 0; i < numWindows; i++) {
328
329                        if (overlap[i] > 0.6) {
330                                positiveIndices.add(IndependentPair.pair(i, overlap[i]));
331                        }
332
333                        if (overlap[i] < 0.2) {
334                                final float variance = detectionResult.variances[i];
335
336                                if (!detectorCascade.getVarianceFilter().enabled
337                                                || variance > detectorCascade.getVarianceFilter().minVar)
338                                { // TODO:
339                                        // This
340                                        // check
341                                        // is
342                                        // unnecessary
343                                        // if
344                                        // minVar
345                                        // would
346                                        // be
347                                        // set
348                                        // before
349                                        // calling
350                                        // detect.
351                                        negativeIndices.add(i);
352                                }
353                        }
354                }
355                System.out.println("Number of positive features on init: " + positiveIndices.size());
356
357                // This might be absolutely and horribly SLOW. figure it out.
358                Collections.sort(positiveIndices,
359                                new Comparator<IndependentPair<Integer, Float>>() {
360                        @Override
361                        public int compare(IndependentPair<Integer, Float> o1,
362                                        IndependentPair<Integer, Float> o2)
363                                        {
364                                return o1.secondObject().compareTo(o2.secondObject());
365                        }
366
367                });
368
369                final List<NormalizedPatch> patches = new ArrayList<NormalizedPatch>();
370
371                patches.add(patch); // Add first patch to patch list
372
373                final int numIterations = Math.min(positiveIndices.size(), 10); // Take
374                                                                                                                                                // at
375                // most 10
376                // bounding
377                // boxes
378                // (sorted
379                // by
380                // overlap)
381                for (int i = 0; i < numIterations; i++) {
382                        final int idx = positiveIndices.get(i).firstObject();
383                        // Learn this bounding box
384                        // TODO: Somewhere here image warping might be possible
385                        detectorCascade.getEnsembleClassifier().learn(
386                                        currImg, true,
387                                        detectionResult.featureVectors, detectorCascade.numTrees * idx
388                                        );
389                }
390
391                // be WARY. the random indecies are not actually random. maybe this
392                // doesn't matter.
393                final Random r = new Random(1); // TODO: This is not guaranteed to
394                                                                                // affect
395                // random_shuffle
396
397                // random_shuffle(negativeIndices.begin(), negativeIndices.end());
398                Collections.shuffle(negativeIndices, r);
399
400                // Choose 100 random patches for negative examples
401                for (int i = 0; i < Math.min(100, negativeIndices.size()); i++) {
402                        final int idx = negativeIndices.get(i);
403
404                        final NormalizedPatch negPatch = new NormalizedPatch();
405                        negPatch.source = currImg;
406                        negPatch.window = detectorCascade.getWindow(idx);
407                        negPatch.prepareNormalisedPatch(); // This creates and sets the
408                                                                                                // public valueImg which holds
409                                                                                                // the normalised zoomed window
410                        negPatch.positive = false;
411                        patches.add(negPatch);
412                }
413
414                detectorCascade.getNNClassifier().learn(patches);
415
416        }
417
418        /**
419         * If the detection results are good and {@link #fuseHypotheses()} believes
420         * that the area was tracked to, but was not detected well then there is
421         * potential that the classifiers should be updated with the bounding box.
422         *
423         * The bounding box is used to extract highly overlapping windows as
424         * positive examples, and two kinds of negative examples are collected if
425         * they overlap less than 0.2f. For the ensemble classifier, negative
426         * examples are collected if the results of the {@link DetectorCascade}
427         */
428        public void learn() {
429                if (!learningEnabled || !valid || !detectorEnabled) {
430                        learning = false;
431                        return;
432                }
433                final int numWindows = detectorCascade.getNumWindows();
434                learning = true;
435                //
436                final DetectionResult detectionResult = detectorCascade.getDetectionResult();
437                //
438                if (!detectionResult.containsValidData) {
439                        detectorCascade.detect(currImg);
440                }
441                //
442                // This is the positive patch
443                NormalizedPatch patch = new NormalizedPatch();
444                patch.source = currImg;
445                patch.window = currBB;
446                patch.prepareNormalisedPatch();
447                //
448                final float[] overlap = new float[numWindows];
449                this.detectorCascade.windowOverlap(currBB, overlap);
450                //
451                // //Add all bounding boxes with high overlap
452                //
453                final List<IndependentPair<Integer, Float>> positiveIndices = new ArrayList<IndependentPair<Integer, Float>>();
454                final List<Integer> negativeIndices = new ArrayList<Integer>();
455                final List<Integer> negativeIndicesForNN = new ArrayList<Integer>();
456                // vector<pair<int,float> > positiveIndices;
457                // vector<int> negativeIndices;
458                // vector<int> negativeIndicesForNN;
459                //
460                // //First: Find overlapping positive and negative patches
461                //
462                for (int i = 0; i < numWindows; i++) {
463                        //
464                        if (overlap[i] > 0.6) {
465                                positiveIndices.add(IndependentPair.pair(i, overlap[i]));
466                        }
467                        //
468                        if (overlap[i] < 0.2) {
469                                if (!detectorCascade.getEnsembleClassifier().enabled || detectionResult.posteriors[i] > 0.1) { // TODO:
470                                                                                                                                                                                                                                // Shouldn't
471                                                                                                                                                                                                                                // this
472                                                                                                                                                                                                                                // read
473                                                                                                                                                                                                                                // as
474                                                                                                                                                                                                                                // 0.5?
475                                        negativeIndices.add(i);
476                                }
477
478                                if (!detectorCascade.getEnsembleClassifier().enabled || detectionResult.posteriors[i] > 0.5) {
479                                        negativeIndicesForNN.add(i);
480                                }
481
482                        }
483                }
484
485                Collections.sort(positiveIndices,
486                                new Comparator<IndependentPair<Integer, Float>>() {
487                        @Override
488                        public int compare(IndependentPair<Integer, Float> o1,
489                                        IndependentPair<Integer, Float> o2)
490                                        {
491                                return o1.secondObject().compareTo(o2.secondObject());
492                        }
493
494                });
495                //
496                final List<NormalizedPatch> patches = new ArrayList<NormalizedPatch>();
497                //
498                patch.positive = true;
499                patches.add(patch);
500                // //TODO: Flip
501                //
502                //
503                final int numIterations = Math.min(positiveIndices.size(), 10); // Take
504                                                                                                                                                // at
505                                                                                                                                                // most
506                                                                                                                                                // 10
507                                                                                                                                                // bounding
508                                                                                                                                                // boxes
509                                                                                                                                                // (sorted
510                                                                                                                                                // by
511                                                                                                                                                // overlap)
512                //
513                for (int i = 0; i < negativeIndices.size(); i++) {
514                        final int idx = negativeIndices.get(i);
515                        // TODO: Somewhere here image warping might be possible
516                        detectorCascade.getEnsembleClassifier().learn(currImg, false, detectionResult.featureVectors,
517                                        detectorCascade.numTrees * idx);
518                        // detectorCascade.ensembleClassifier.learn(currImg,
519                        // detectorCascade.windows[idx], false,
520                        // detectionResult.featureVectors[detectorCascade.numTrees*idx]);
521                }
522                //
523                // //TODO: Randomization might be a good idea
524                for (int i = 0; i < numIterations; i++) {
525                        final int idx = positiveIndices.get(i).firstObject();
526                        // //TODO: Somewhere here image warping might be possible
527                        // detectorCascade.ensembleClassifier.learn(currImg,
528                        // &detectorCascade.windows[TLD_WINDOW_SIZE*idx], true,
529                        // &detectionResult.featureVectors[detectorCascade.numTrees*idx]);
530                        detectorCascade.getEnsembleClassifier().learn(currImg, true, detectionResult.featureVectors,
531                                        detectorCascade.numTrees * idx);
532                }
533                //
534                for (int i = 0; i < negativeIndicesForNN.size(); i++) {
535                        final int idx = negativeIndicesForNN.get(i);
536                        //
537                        patch = new NormalizedPatch();
538                        patch.source = currImg;
539                        patch.window = detectorCascade.getWindow(idx);
540                        patch.positive = false;
541                        patches.add(patch);
542                }
543                //
544                detectorCascade.getNNClassifier().learn(patches);
545                //
546                // //cout << "NN has now " <<
547                // detectorCascade.nnClassifier.truePositives.size() <<
548                // " positives and " <<
549                // detectorCascade.nnClassifier.falsePositives.size() <<
550                // " negatives.\n";
551                //
552                // delete[] overlap;
553        }
554
555        /**
556         * @return whether the tracker is learning from the previous frame
557         */
558        public boolean isLearning() {
559                return this.learning;
560        }
561}