/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.math.model;

import gov.sandia.cognition.learning.algorithm.bayes.VectorNaiveBayesCategorizer;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.DataHistogram;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian.PDF;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.openimaj.util.pair.IndependentPair;

/**
 * An implementation of an {@link EstimatableModel} that uses a
 * {@link VectorNaiveBayesCategorizer} to associate a univariate value (a
 * {@link Double}) with a category.
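 * <p>
 * A minimal usage sketch, mirroring the {@code main} method at the bottom of
 * this class; the {@code Boolean} category type and the sample values are
 * purely illustrative:
 *
 * <pre>
 * {@code
 * UnivariateGaussianNaiveBayesModel<Boolean> model = new UnivariateGaussianNaiveBayesModel<Boolean>();
 *
 * List<IndependentPair<Double, Boolean>> data = new ArrayList<IndependentPair<Double, Boolean>>();
 * data.add(IndependentPair.pair(0.0, true));
 * data.add(IndependentPair.pair(10.0, false));
 *
 * model.estimate(data); // fits a univariate Gaussian per category
 * Boolean category = model.predict(5.1); // classify a previously unseen value
 * }
 * </pre>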
 *
 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
 *
 * @param <T>
 *            The type of class/category predicted by the model
 */
public class UnivariateGaussianNaiveBayesModel<T> implements EstimatableModel<Double, T> {
	private VectorNaiveBayesCategorizer<T, PDF> model;

	/**
	 * Default constructor.
	 */
	public UnivariateGaussianNaiveBayesModel() {

	}

	/**
	 * Construct with a pre-trained model.
	 *
	 * @param model
	 *            the pre-trained model.
	 */
	public UnivariateGaussianNaiveBayesModel(VectorNaiveBayesCategorizer<T, PDF> model) {
		this.model = model;
	}

	@Override
	public boolean estimate(List<? extends IndependentPair<Double, T>> data) {
		final VectorNaiveBayesCategorizer.BatchGaussianLearner<T> learner =
				new VectorNaiveBayesCategorizer.BatchGaussianLearner<T>();
		final List<InputOutputPair<Vector, T>> cfdata = new ArrayList<InputOutputPair<Vector, T>>();

		// Wrap each (value, category) pair as a 1-dimensional vector/category
		// pair in the form expected by the Cognitive Foundry learner
		for (final IndependentPair<Double, T> d : data) {
			final InputOutputPair<Vector, T> iop = new DefaultInputOutputPair<Vector, T>(
					VectorFactory.getDefault().createVector1D(d.firstObject()), d.secondObject());
			cfdata.add(iop);
		}

		model = learner.learn(cfdata);

		return true;
	}

	@Override
	public T predict(Double data) {
		return model.evaluate(VectorFactory.getDefault().createVector1D(data));
	}

	@Override
	public int numItemsToEstimate() {
		return 0;
	}

	@Override
	@SuppressWarnings("unchecked")
	public UnivariateGaussianNaiveBayesModel<T> clone() {
		try {
			return (UnivariateGaussianNaiveBayesModel<T>) super.clone();
		} catch (final CloneNotSupportedException e) {
			throw new RuntimeException(e);
		}
	}

	/**
	 * Get the class distribution for the given class.
	 *
	 * @param clz
	 *            the class
	 * @return the univariate Gaussian distribution.
	 */
	public UnivariateGaussian getClassDistribution(T clz) {
		return model.getConditionals().get(clz).get(0);
	}

	/**
	 * Get the class distributions for all classes.
	 *
	 * @return a map of classes to distributions
	 */
	public Map<T, UnivariateGaussian> getClassDistribution() {
		final Map<T, UnivariateGaussian> clzs = new HashMap<T, UnivariateGaussian>();

		for (final T c : model.getCategories()) {
			clzs.put(c, model.getConditionals().get(c).get(0));
		}

		return clzs;
	}

	/**
	 * @return The priors for each class
	 */
	public DataHistogram<T> getClassPriors() {
		return model.getPriors();
	}

	/**
	 * Simple test of training and prediction.
	 *
	 * @param args
	 *            ignored
	 */
	public static void main(String[] args) {
		final UnivariateGaussianNaiveBayesModel<Boolean> model = new UnivariateGaussianNaiveBayesModel<Boolean>();

		final List<IndependentPair<Double, Boolean>> data = new ArrayList<IndependentPair<Double, Boolean>>();

		data.add(IndependentPair.pair(0.0, true));
		data.add(IndependentPair.pair(0.1, true));
		data.add(IndependentPair.pair(-0.1, true));

		data.add(IndependentPair.pair(9.9, false));
		data.add(IndependentPair.pair(10.0, false));
		data.add(IndependentPair.pair(10.1, false));

		model.estimate(data);

		System.out.println(model.predict(5.1));

		System.out.println(model.model.getConditionals().get(true));
		System.out.println(model.model.getConditionals().get(false));

		System.out.println(model.model.getConditionals().get(true).get(0).getMean());
		System.out.println(model.model.getConditionals().get(true).get(0).getVariance());
		System.out.println(model.model.getConditionals().get(false).get(0).getMean());
		System.out.println(model.model.getConditionals().get(false).get(0).getVariance());

		System.out.println(model.model.getPriors());
	}
}