001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.math.model.fit;
031
032import java.util.ArrayList;
033import java.util.List;
034
035import org.openimaj.math.model.EstimatableModel;
036import org.openimaj.math.model.fit.residuals.ResidualCalculator;
037import org.openimaj.math.util.distance.DistanceCheck;
038import org.openimaj.math.util.distance.ThresholdDistanceCheck;
039import org.openimaj.util.CollectionSampler;
040import org.openimaj.util.UniformSampler;
041import org.openimaj.util.pair.IndependentPair;
042
043/**
044 * The RANSAC Algorithm (RANdom SAmple Consensus)
045 * <p>
046 * For fitting noisy data consisting of inliers and outliers to a model.
047 * </p>
048 * <p>
 * Assume: M data items are required to estimate parameter x, and there are N
 * data items in total.
050 * </p>
051 * <p>
 * 1.) select M data items at random <br>
 * 2.) estimate parameter x <br>
 * 3.) find how many of the N data items fit x within tolerance tol (i.e. have
 * an error less than a threshold or pass some check); call this K <br>
 * 4.) if K is large enough (bigger than numItems) accept x and exit with
 * success <br>
 * 5.) repeat 1..4 nIter times <br>
 * 6.) fail - no good x fit of data
065 * </p>
 * <p>
 * In this implementation, the conditions that control the iterations are
 * configurable. In addition, the best matching model is always stored, even if
 * the fitData() method returns false.
 * </p>
070 *
071 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
072 *
073 * @param <I>
074 *            type of independent data
075 * @param <D>
076 *            type of dependent data
077 * @param <M>
078 *            concrete type of model learned
079 */
080public class RANSAC<I, D, M extends EstimatableModel<I, D>> implements RobustModelFitting<I, D, M> {
081        /**
082         * Interface for classes that can control RANSAC iterations
083         */
084        public static interface StoppingCondition {
085                /**
086                 * Initialise the stopping condition if necessary. Return false if the
087                 * initialisation cannot be performed and RANSAC should fail
088                 *
089                 * @param data
090                 *            The data being fitted
091                 * @param model
092                 *            The model to fit
093                 * @return true if initialisation is successful, false otherwise.
094                 */
095                public abstract boolean init(final List<?> data, EstimatableModel<?, ?> model);
096
097                /**
098                 * Should we stop iterating and return the model?
099                 *
100                 * @param numInliers
101                 *            number of inliers in this iteration
102                 * @return true if the model is good and iterations should stop
103                 */
104                public abstract boolean shouldStopIterations(int numInliers);
105
106                /**
107                 * Should the model be considered to fit after the final iteration has
108                 * passed?
109                 *
110                 * @param numInliers
111                 *            number of inliers in the final model
112                 * @return true if the model fits, false otherwise
113                 */
114                public abstract boolean finalFitCondition(int numInliers);
115        }
116
117        /**
118         * Stopping condition that tests the number of matches against a threshold.
119         * If the number exceeds the threshold, then the model is considered to fit.
120         */
121        public static class NumberInliersStoppingCondition implements StoppingCondition {
122                int limit;
123
124                /**
125                 * Construct the stopping condition with the given threshold on the
126                 * number of data points which must match for a model to be considered a
127                 * fit.
128                 *
129                 * @param limit
130                 *            the threshold
131                 */
132                public NumberInliersStoppingCondition(int limit) {
133                        this.limit = limit;
134                }
135
136                @Override
137                public boolean init(List<?> data, EstimatableModel<?, ?> model) {
138                        if (limit < model.numItemsToEstimate()) {
139                                limit = model.numItemsToEstimate();
140                        }
141
142                        if (data.size() < limit)
143                                return false;
144                        return true;
145                }
146
147                @Override
148                public boolean shouldStopIterations(int numInliers) {
149                        return numInliers >= limit; // stop if there are more inliers than
150                        // our limit
151                }
152
153                @Override
154                public boolean finalFitCondition(int numInliers) {
155                        return numInliers >= limit;
156                }
157        }
158
159        /**
160         * Stopping condition that tests the number of matches against a percentage
161         * threshold of the whole data. If the number exceeds the threshold, then
162         * the model is considered to fit.
163         */
164        public static class PercentageInliersStoppingCondition extends NumberInliersStoppingCondition {
165                double percentageLimit;
166
167                /**
168                 * Construct the stopping condition with the given percentage threshold
169                 * on the number of data points which must match for a model to be
170                 * considered a fit.
171                 *
172                 * @param percentageLimit
173                 *            the percentage threshold
174                 */
175                public PercentageInliersStoppingCondition(double percentageLimit) {
176                        super(0);
177                        this.percentageLimit = percentageLimit;
178                }
179
180                @Override
181                public boolean init(List<?> data, EstimatableModel<?, ?> model) {
182                        this.limit = (int) Math.rint(percentageLimit * data.size());
183                        return super.init(data, model);
184                }
185        }
186
187        /**
188         * Stopping condition that tests the number of matches against a percentage
189         * threshold of the whole data. If the number exceeds the threshold, then
190         * the model is considered to fit.
191         */
192        public static class ProbabilisticMinInliersStoppingCondition implements StoppingCondition {
193                private static final double DEFAULT_INLIER_IS_BAD_PROBABILITY = 0.1;
194                private static final double DEFAULT_PERCENTAGE_INLIERS = 0.25;
195                private double inlierIsBadProbability;
196                private double desiredErrorProbability;
197                private double percentageInliers;
198
199                private int numItemsToEstimate;
200                private int iteration = 0;
201                private int limit;
202                private int maxInliers = 0;
203                private double currentProb;
204                private int numDataItems;
205
206                /**
207                 * Default constructor.
208                 *
209                 * @param desiredErrorProbability
210                 *            The desired error rate
211                 * @param inlierIsBadProbability
212                 *            The probability an inlier is bad
213                 * @param percentageInliers
214                 *            The percentage of inliers in the data
215                 */
216                public ProbabilisticMinInliersStoppingCondition(double desiredErrorProbability, double inlierIsBadProbability,
217                                double percentageInliers)
218                {
219                        this.desiredErrorProbability = desiredErrorProbability;
220                        this.inlierIsBadProbability = inlierIsBadProbability;
221                        this.percentageInliers = percentageInliers;
222                }
223
224                /**
225                 * Constructor with defaults for bad inlier probability and percentage
226                 * inliers.
227                 *
228                 * @param desiredErrorProbability
229                 *            The desired error rate
230                 */
231                public ProbabilisticMinInliersStoppingCondition(double desiredErrorProbability) {
232                        this(desiredErrorProbability, DEFAULT_INLIER_IS_BAD_PROBABILITY, DEFAULT_PERCENTAGE_INLIERS);
233                }
234
235                @Override
236                public boolean init(List<?> data, EstimatableModel<?, ?> model) {
237                        numItemsToEstimate = model.numItemsToEstimate();
238                        numDataItems = data.size();
239                        this.limit = calculateMinInliers();
240                        this.iteration = 0;
241                        this.currentProb = 1.0;
242                        this.maxInliers = 0;
243
244                        return true;
245                }
246
247                @Override
248                public boolean finalFitCondition(int numInliers) {
249                        return numInliers >= limit;
250                }
251
252                private int calculateMinInliers() {
253                        double pi, sum;
254                        int i, j;
255
256                        for (j = numItemsToEstimate + 1; j <= numDataItems; j++) {
257                                sum = 0;
258                                for (i = j; i <= numDataItems; i++) {
259                                        pi = (i - numItemsToEstimate) * Math.log(inlierIsBadProbability)
260                                                        + (numDataItems - i + numItemsToEstimate) * Math.log(1.0 - inlierIsBadProbability) +
261                                                        log_factorial(numDataItems - numItemsToEstimate) - log_factorial(i - numItemsToEstimate)
262                                                        - log_factorial(numDataItems - i);
263                                        /*
264                                         * Last three terms above are equivalent to log( n-m choose
265                                         * i-m )
266                                         */
267                                        sum += Math.exp(pi);
268                                }
269                                if (sum < desiredErrorProbability)
270                                        break;
271                        }
272                        return j;
273                }
274
275                private double log_factorial(int n) {
276                        double f = 0;
277                        int i;
278
279                        for (i = 1; i <= n; i++)
280                                f += Math.log(i);
281
282                        return f;
283                }
284
285                @Override
286                public boolean shouldStopIterations(int numInliers) {
287
288                        if (numInliers > maxInliers) {
289                                maxInliers = numInliers;
290                                percentageInliers = (double) maxInliers / numDataItems;
291
292                                // System.err.format("Updated maxInliers: %d\n", maxInliers);
293                        }
294                        currentProb = Math.pow(1.0 - Math.pow(percentageInliers, numItemsToEstimate), ++iteration);
295                        return currentProb <= this.desiredErrorProbability;
296                }
297        }
298
299        /**
300         * Stopping condition that allows the RANSAC algorithm to run until all the
301         * iterations have been exhausted. The fitData method will return true if
302         * there are at least as many inliers as datapoints required to estimate the
303         * model, and the model will be the one from the iteration that had the most
304         * inliers.
305         */
306        public static class BestFitStoppingCondition implements StoppingCondition {
307                int required;
308
309                @Override
310                public boolean init(List<?> data, EstimatableModel<?, ?> model) {
311                        required = model.numItemsToEstimate();
312                        return true;
313                }
314
315                @Override
316                public boolean shouldStopIterations(int numInliers) {
317                        return false; // just iterate until the end
318                }
319
320                @Override
321                public boolean finalFitCondition(int numInliers) {
322                        return numInliers > required; // accept the best result as a good
323                        // fit if there are enough inliers
324                }
325        }
326
327        protected M model;
328        protected ResidualCalculator<I, D, M> errorModel;
329        protected DistanceCheck dc;
330
331        protected int nIter;
332        protected boolean improveEstimate;
333        protected List<IndependentPair<I, D>> inliers;
334        protected List<IndependentPair<I, D>> outliers;
335        protected List<IndependentPair<I, D>> bestModelInliers;
336        protected List<IndependentPair<I, D>> bestModelOutliers;
337        protected StoppingCondition stoppingCondition;
338        protected List<? extends IndependentPair<I, D>> modelConstructionData;
339        protected CollectionSampler<IndependentPair<I, D>> sampler;
340
341        /**
342         * Create a RANSAC object with uniform random sampling for creating the
343         * subsets
344         *
345         * @param model
346         *            Model object with which to fit data
347         * @param errorModel
348         *            object to compute the error of the model
349         * @param errorThreshold
350         *            the threshold below which error is deemed acceptable for a fit
351         * @param nIterations
352         *            Maximum number of allowed iterations (L)
353         * @param stoppingCondition
354         *            the stopping condition
355         * @param impEst
356         *            True if we want to perform a final fitting of the model with
357         *            all inliers, false otherwise
358         */
359        public RANSAC(M model, ResidualCalculator<I, D, M> errorModel,
360                        double errorThreshold, int nIterations,
361                        StoppingCondition stoppingCondition, boolean impEst)
362        {
363                this(model, errorModel, new ThresholdDistanceCheck(errorThreshold), nIterations, stoppingCondition, impEst);
364        }
365
366        /**
367         * Create a RANSAC object with uniform random sampling for creating the
368         * subsets
369         *
370         * @param model
371         *            Model object with which to fit data
372         * @param errorModel
373         *            object to compute the error of the model
374         * @param dc
375         *            the distance check that tests whether a point with given error
376         *            from the error model should be considered an inlier
377         * @param nIterations
378         *            Maximum number of allowed iterations (L)
379         * @param stoppingCondition
380         *            the stopping condition
381         * @param impEst
382         *            True if we want to perform a final fitting of the model with
383         *            all inliers, false otherwise
384         */
385        public RANSAC(M model, ResidualCalculator<I, D, M> errorModel,
386                        DistanceCheck dc, int nIterations,
387                        StoppingCondition stoppingCondition, boolean impEst)
388        {
389                this(model, errorModel, dc, nIterations, stoppingCondition, impEst, new UniformSampler<IndependentPair<I, D>>());
390        }
391
392        /**
393         * Create a RANSAC object
394         *
395         * @param model
396         *            Model object with which to fit data
397         * @param errorModel
398         *            object to compute the error of the model
399         * @param errorThreshold
400         *            the threshold below which error is deemed acceptable for a fit
401         * @param nIterations
402         *            Maximum number of allowed iterations (L)
403         * @param stoppingCondition
404         *            the stopping condition
405         * @param impEst
406         *            True if we want to perform a final fitting of the model with
407         *            all inliers, false otherwise
408         * @param sampler
409         *            the sampling algorithm for selecting random subsets
410         */
411        public RANSAC(M model, ResidualCalculator<I, D, M> errorModel,
412                        double errorThreshold, int nIterations,
413                        StoppingCondition stoppingCondition, boolean impEst, CollectionSampler<IndependentPair<I, D>> sampler)
414        {
415                this(model, errorModel, new ThresholdDistanceCheck(errorThreshold), nIterations, stoppingCondition, impEst,
416                                sampler);
417        }
418
419        /**
420         * Create a RANSAC object
421         *
422         * @param model
423         *            Model object with which to fit data
424         * @param errorModel
425         *            object to compute the error of the model
426         * @param dc
427         *            the distance check that tests whether a point with given error
428         *            from the error model should be considered an inlier
429         * @param nIterations
430         *            Maximum number of allowed iterations (L)
431         * @param stoppingCondition
432         *            the stopping condition
433         * @param impEst
434         *            True if we want to perform a final fitting of the model with
435         *            all inliers, false otherwise
436         * @param sampler
437         *            the sampling algorithm for selecting random subsets
438         */
439        public RANSAC(M model, ResidualCalculator<I, D, M> errorModel,
440                        DistanceCheck dc, int nIterations,
441                        StoppingCondition stoppingCondition, boolean impEst, CollectionSampler<IndependentPair<I, D>> sampler)
442        {
443                this.stoppingCondition = stoppingCondition;
444                this.model = model;
445                this.errorModel = errorModel;
446                this.dc = dc;
447                nIter = nIterations;
448                improveEstimate = impEst;
449
450                inliers = new ArrayList<IndependentPair<I, D>>();
451                outliers = new ArrayList<IndependentPair<I, D>>();
452                this.sampler = sampler;
453        }
454
455        @Override
456        public boolean fitData(final List<? extends IndependentPair<I, D>> data) {
457                int l;
458                final int M = model.numItemsToEstimate();
459
460                bestModelInliers = null;
461                bestModelOutliers = null;
462
463                if (data.size() < M || !stoppingCondition.init(data, model)) {
464                        return false; // there are not enough points to create a model, or
465                        // init failed
466                }
467
468                sampler.setCollection(data);
469
470                for (l = 0; l < nIter; l++) {
471                        // 1
472                        final List<? extends IndependentPair<I, D>> rnd = sampler.sample(M);
473                        this.setModelConstructionData(rnd);
474
475                        // 2
476                        if (!model.estimate(rnd))
477                                continue; // bad estimate
478
479                        errorModel.setModel(model);
480
481                        // 3
482                        int K = 0;
483                        inliers.clear();
484                        outliers.clear();
485                        for (final IndependentPair<I, D> dp : data) {
486                                if (dc.check(errorModel.computeResidual(dp))) {
487                                        K++;
488                                        inliers.add(dp);
489                                } else {
490                                        outliers.add(dp);
491                                }
492                        }
493
494                        if (bestModelInliers == null || inliers.size() >= bestModelInliers.size()) {
495                                // copy
496                                bestModelInliers = new ArrayList<IndependentPair<I, D>>(inliers);
497                                bestModelOutliers = new ArrayList<IndependentPair<I, D>>(outliers);
498                        }
499
500                        // 4
501                        if (stoppingCondition.shouldStopIterations(K)) {
502                                // generate "best" fit from all the iterations
503                                inliers = bestModelInliers;
504                                outliers = bestModelOutliers;
505
506                                if (improveEstimate) {
507                                        if (inliers.size() >= model.numItemsToEstimate())
508                                                if (!model.estimate(inliers))
509                                                        return false;
510                                }
511                                final boolean stopping = stoppingCondition.finalFitCondition(inliers.size());
512                                // System.err.format("done: %b\n",stopping);
513                                return stopping;
514                        }
515                        // 5
516                        // ...repeat...
517                }
518
519                // generate "best" fit from all the iterations
520                if (bestModelInliers == null) {
521                        bestModelInliers = new ArrayList<IndependentPair<I, D>>();
522                        bestModelOutliers = new ArrayList<IndependentPair<I, D>>();
523                }
524
525                inliers = bestModelInliers;
526                outliers = bestModelOutliers;
527
528                if (bestModelInliers.size() >= M)
529                        if (!model.estimate(bestModelInliers))
530                                return false;
531
532                // 6 - fail
533                return stoppingCondition.finalFitCondition(inliers.size());
534        }
535
536        @Override
537        public List<? extends IndependentPair<I, D>> getInliers() {
538                return inliers;
539        }
540
541        @Override
542        public List<? extends IndependentPair<I, D>> getOutliers() {
543                return outliers;
544        }
545
546        /**
547         * @return maximum number of allowed iterations
548         */
549        public int getMaxIterations() {
550                return nIter;
551        }
552
553        /**
554         * Set the maximum number of allowed iterations
555         *
556         * @param nIter
557         *            maximum number of allowed iterations
558         */
559        public void setMaxIterations(int nIter) {
560                this.nIter = nIter;
561        }
562
563        @Override
564        public M getModel() {
565                return model;
566        }
567
568        /**
569         * Set the underlying model being fitted
570         *
571         * @param model
572         *            the model
573         */
574        public void setModel(M model) {
575                this.model = model;
576        }
577
578        /**
579         * @return whether RANSAC should attempt to improve the model using all
580         *         inliers as data
581         */
582        public boolean isImproveEstimate() {
583                return improveEstimate;
584        }
585
586        /**
587         * Set whether RANSAC should attempt to improve the model using all inliers
588         * as data
589         *
590         * @param improveEstimate
591         *            should RANSAC attempt to improve the model using all inliers
592         *            as data
593         */
594        public void setImproveEstimate(boolean improveEstimate) {
595                this.improveEstimate = improveEstimate;
596        }
597
598        /**
599         * Set the data used to construct the model
600         *
601         * @param modelConstructionData
602         */
603        public void setModelConstructionData(List<? extends IndependentPair<I, D>> modelConstructionData) {
604                this.modelConstructionData = modelConstructionData;
605        }
606
607        /**
608         * @return The data used to construct the model.
609         */
610        public List<? extends IndependentPair<I, D>> getModelConstructionData() {
611                return modelConstructionData;
612        }
613
614        @Override
615        public int numItemsToEstimate() {
616                return model.numItemsToEstimate();
617        }
618}