001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.math.model;
031
032import gov.sandia.cognition.learning.algorithm.bayes.VectorNaiveBayesCategorizer;
033import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
034import gov.sandia.cognition.learning.data.InputOutputPair;
035import gov.sandia.cognition.math.matrix.Vector;
036import gov.sandia.cognition.math.matrix.VectorFactory;
037import gov.sandia.cognition.statistics.DataHistogram;
038import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
039import gov.sandia.cognition.statistics.distribution.UnivariateGaussian.PDF;
040
041import java.util.ArrayList;
042import java.util.HashMap;
043import java.util.List;
044import java.util.Map;
045
046import org.openimaj.util.pair.IndependentPair;
047
048/**
049 * An implementation of a {@link EstimatableModel} that uses a
050 * {@link VectorNaiveBayesCategorizer} to associate a univariate (a
051 * {@link Double}) with a category.
052 * 
053 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
054 * 
055 * @param <T>
056 *            The type of class/category predicted by the model
057 */
058
059public class UnivariateGaussianNaiveBayesModel<T> implements EstimatableModel<Double, T> {
060        private VectorNaiveBayesCategorizer<T, PDF> model;
061
062        /**
063         * Default constructor.
064         */
065        public UnivariateGaussianNaiveBayesModel() {
066
067        }
068
069        /**
070         * Construct with a pre-trained model.
071         * 
072         * @param model
073         *            the pre-trained model.
074         */
075        public UnivariateGaussianNaiveBayesModel(VectorNaiveBayesCategorizer<T, PDF> model) {
076                this.model = model;
077        }
078
079        @Override
080        public boolean estimate(List<? extends IndependentPair<Double, T>> data) {
081                final VectorNaiveBayesCategorizer.BatchGaussianLearner<T> learner = new VectorNaiveBayesCategorizer.BatchGaussianLearner<T>();
082                final List<InputOutputPair<Vector, T>> cfdata = new ArrayList<InputOutputPair<Vector, T>>();
083
084                for (final IndependentPair<Double, T> d : data) {
085                        final InputOutputPair<Vector, T> iop = new DefaultInputOutputPair<Vector, T>(VectorFactory.getDefault()
086                                        .createVector1D(d.firstObject()), d.secondObject());
087                        cfdata.add(iop);
088                }
089
090                model = learner.learn(cfdata);
091
092                return true;
093        }
094
095        @Override
096        public T predict(Double data) {
097                return model.evaluate(VectorFactory.getDefault().createVector1D(data));
098        }
099
100        @Override
101        public int numItemsToEstimate() {
102                return 0;
103        }
104
105        @Override
106        @SuppressWarnings("unchecked")
107        public UnivariateGaussianNaiveBayesModel<T> clone() {
108                try {
109                        return (UnivariateGaussianNaiveBayesModel<T>) super.clone();
110                } catch (final CloneNotSupportedException e) {
111                        throw new RuntimeException(e);
112                }
113        }
114
115        /**
116         * Get the class distribution for the given class.
117         * 
118         * @param clz
119         *            the class
120         * @return the univariate gaussian distribution.
121         */
122        public UnivariateGaussian getClassDistribution(T clz) {
123                return model.getConditionals().get(clz).get(0);
124        }
125
126        /**
127         * Get the class distribution for all classes.
128         * 
129         * @return a map of classes to distributions
130         */
131        public Map<T, UnivariateGaussian> getClassDistribution() {
132                final Map<T, UnivariateGaussian> clzs = new HashMap<T, UnivariateGaussian>();
133
134                for (final T c : model.getCategories()) {
135                        clzs.put(c, model.getConditionals().get(c).get(0));
136                }
137
138                return clzs;
139        }
140
141        /**
142         * @return The priors for each class
143         */
144        public DataHistogram<T> getClassPriors() {
145                return model.getPriors();
146        }
147
148        /**
149         * Testing
150         * 
151         * @param args
152         */
153        public static void main(String[] args) {
154                final UnivariateGaussianNaiveBayesModel<Boolean> model = new UnivariateGaussianNaiveBayesModel<Boolean>();
155
156                final List<IndependentPair<Double, Boolean>> data = new ArrayList<IndependentPair<Double, Boolean>>();
157
158                data.add(IndependentPair.pair(0.0, true));
159                data.add(IndependentPair.pair(0.1, true));
160                data.add(IndependentPair.pair(-0.1, true));
161
162                data.add(IndependentPair.pair(9.9, false));
163                data.add(IndependentPair.pair(10.0, false));
164                data.add(IndependentPair.pair(10.1, false));
165
166                model.estimate(data);
167
168                System.out.println(model.predict(5.1));
169
170                System.out.println(model.model.getConditionals().get(true));
171                System.out.println(model.model.getConditionals().get(false));
172
173                System.out.println(model.model.getConditionals().get(true).get(0).getMean());
174                System.out.println(model.model.getConditionals().get(true).get(0).getVariance());
175                System.out.println(model.model.getConditionals().get(false).get(0).getMean());
176                System.out.println(model.model.getConditionals().get(false).get(0).getVariance());
177
178                System.out.println(model.model.getPriors());
179        }
180}