/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.pgm.vb.lda.mle;

import java.util.HashMap;
import java.util.Map;

import org.apache.commons.math.special.Gamma;
import org.openimaj.math.util.MathUtils;
import org.openimaj.pgm.util.Corpus;
import org.openimaj.pgm.util.Document;
import org.openimaj.util.array.SparseIntArray.Entry;

/**
 * An implementation of LDA trained by variational EM whose learnt model state
 * can be saved and loaded.
 *
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 */
public class LDALearner {
	private int ntopics;
	private Map<LDAConfig, Object> config = new HashMap<LDAConfig, Object>();

	/**
	 * Configuration options for the learner, each with a default value.
	 */
	enum LDAConfig {
		MAX_ITERATIONS {
			@Override
			public Integer defaultValue() {
				return 10;
			}
		},
		ALPHA {
			@Override
			public Double defaultValue() {
				return 0.3d;
			}
		},
		VAR_MAX_ITERATIONS {
			@Override
			public Integer defaultValue() {
				return 10;
			}
		},
		INIT_STRATEGY {
			@Override
			public LDABetaInitStrategy defaultValue() {
				return new LDABetaInitStrategy.RandomBetaInit();
			}
		},
		EM_CONVERGED {
			@Override
			public Double defaultValue() {
				return 1e-5;
			}
		},
		VAR_EM_CONVERGED {
			@Override
			public Double defaultValue() {
				return 1e-5;
			}
		};
		public abstract Object defaultValue();
	}

	/**
	 * @param ntopics
	 *            the number of topics to learn
	 */
	public LDALearner(int ntopics) {
		this.ntopics = ntopics;
	}

	/**
	 * @param key
	 *            the configuration parameter
	 * @return the configuration parameter value, or its default if unset
	 */
	@SuppressWarnings("unchecked")
	public <T> T getConfig(LDAConfig key) {
		final T val = (T) this.config.get(key);
		if (val == null)
			return (T) key.defaultValue();
		return val;
	}
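
	/*
	 * A minimal usage sketch (illustrative only; it assumes a Corpus has
	 * already been built elsewhere, since corpus construction is outside the
	 * scope of this class):
	 *
	 *   final Corpus corpus = ...; // a bag-of-words corpus, built elsewhere
	 *   final LDALearner learner = new LDALearner(10); // learn 10 topics
	 *   learner.estimate(corpus); // run variational EM until convergence
	 */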
	/**
	 * Initiates the EM algorithm on the documents in the corpus.
	 *
	 * @param corpus
	 *            the corpus of documents to learn from
	 */
	public void estimate(Corpus corpus) {
		performEM(corpus);
	}

	private void performEM(Corpus corpus) {
		// some variables
		final double initialAlpha = (Double) this.getConfig(LDAConfig.ALPHA);
		final LDABetaInitStrategy initStrat = this.getConfig(LDAConfig.INIT_STRATEGY);

		// initialise the first state
		final LDAModel state = new LDAModel(this.ntopics);
		state.prepare(corpus);
		state.setAlpha(initialAlpha);
		initStrat.initModel(state, corpus);

		final LDAVariationlState vstate = new LDAVariationlState(state);
		while (!modelConverged(vstate.state)) {
			final LDAModel nextState = vstate.state.newInstance();
			nextState.setAlpha(initialAlpha);
			for (final Document doc : corpus.getDocuments()) {
				vstate.prepare(doc);
				// update the variational parameters given the current beta
				performE(doc, vstate);
				// update nextState given the variational parameters
				performM(doc, vstate, nextState);
				nextState.likelihood += vstate.likelihood;
			}
			// carry the previous iteration's likelihood forward so the
			// convergence test compares successive corpus likelihoods
			nextState.oldLikelihood = vstate.state.likelihood;
			nextState.iteration = vstate.state.iteration + 1;
			vstate.state = nextState;
		}
	}

	private LDAVariationlState performE(Document doc, LDAVariationlState vstate) {
		vstate.prepare(doc);
		while (!variationalStateConverged(vstate)) {
			int docWordIndex = 0;
			for (final Entry wordCount : doc.getVector().entries()) {
				double phiSum = 0;
				final int word = wordCount.index;
				final int count = wordCount.value;
				for (int topicIndex = 0; topicIndex < ntopics; topicIndex++) {
					vstate.oldphi[topicIndex] = vstate.phi[docWordIndex][topicIndex];
					// If this word has been seen in this topic before
					if (vstate.state.topicWord[topicIndex][word] > 0) {
						// Update phi. Remember this phi is actually the same
						// value for every instance of this particular word;
						// wherever phi is actually used there is likely to be
						// a multiplication by the number of times this
						// particular word appears in this document.
						// From eqn (16) in Blei et al. 2003; the sum over
						// gamma cancels when the exact phi for a given word
						// is calculated.
						final double logBeta =
								Math.log(vstate.state.topicWord[topicIndex][word]) -
								Math.log(vstate.state.topicTotal[topicIndex]);
						vstate.phi[docWordIndex][topicIndex] =
								logBeta +
								Gamma.digamma(vstate.varGamma[topicIndex]);
					} else {
						// if not, beta_wi = ETA (very small, ~10^-34), so
						// log(beta_wi) ~= -100
						vstate.phi[docWordIndex][topicIndex] = Gamma.digamma(vstate.varGamma[topicIndex]) - 100;
					}
					if (topicIndex == 0) {
						phiSum = vstate.phi[docWordIndex][topicIndex];
					} else {
						// we need phiSum = Sum_K_i{phi}, but what we hold is
						// log(phi), so we must calculate log(a + b) from
						// log(a) and log(b).
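						// The identity assumed to be implemented by
						// MathUtils.logSum is the standard log-sum-exp trick:
						//   log(a + b) = log(a) + log(1 + exp(log(b) - log(a)))
						// evaluated with the larger argument factored out for
						// numerical stability.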
						// The normaliser for eqn (16)
						phiSum = MathUtils.logSum(phiSum,
								vstate.phi[docWordIndex][topicIndex]);
					}
				}
				for (int topicIndex = 0; topicIndex < ntopics; topicIndex++) {
					// Replace log(phi) with the normalised phi (normalising a
					// given word's phi means summing over all i in eqn (16))
					vstate.phi[docWordIndex][topicIndex] = Math.exp(
							vstate.phi[docWordIndex][topicIndex] - phiSum
							);
					// update gamma incrementally (eqn (17), Blei et al. 2003):
					// - take away the old phi,
					// - add the new phi,
					// - do this N times for the number of times this
					//   particular word appears in this document
					vstate.varGamma[topicIndex] += count
							* (vstate.phi[docWordIndex][topicIndex] - vstate.oldphi[topicIndex]);
				}
				docWordIndex++;
			}
			vstate.oldLikelihood = vstate.likelihood;
			vstate.likelihood = computeLikelihood(doc, vstate);
			vstate.iteration++;
		}
		return vstate;
	}

	private boolean modelConverged(LDAModel model) {
		final double EM_CONVERGED = (Double) this.getConfig(LDAConfig.EM_CONVERGED);
		final int MAX_ITER = (Integer) this.getConfig(LDAConfig.MAX_ITERATIONS);
		if (model.iteration > MAX_ITER)
			return true; // iteration budget exhausted; stop regardless
		if (model.iteration <= 2)
			return false; // always perform at least two EM iterations
		// if likelihood ~= oldLikelihood then this value approaches 0
		final double change = (model.likelihood - model.oldLikelihood) / model.oldLikelihood;
		return Math.abs(change) < EM_CONVERGED;
	}

	private boolean variationalStateConverged(LDAVariationlState vstate) {
		final double VAR_CONVERGED = (Double) this.getConfig(LDAConfig.VAR_EM_CONVERGED);
		final int MAX_ITER = (Integer) this.getConfig(LDAConfig.VAR_MAX_ITERATIONS);
		if (vstate.iteration > MAX_ITER)
			return true; // iteration budget exhausted; stop regardless
		if (vstate.iteration <= 2)
			return false; // always perform at least two variational updates
		// if likelihood ~= oldLikelihood then this value approaches 0
		final double change = (vstate.likelihood - vstate.oldLikelihood) / vstate.oldLikelihood;
		return Math.abs(change) < VAR_CONVERGED;
	}

	/**
	 * Given the current state of the variational parameters, update the
	 * maximum likelihood beta parameter by updating its sufficient statistics.
	 *
	 * @param d
	 *            the current document
	 * @param vstate
	 *            the document's variational state
	 * @param nextState
	 *            the model accumulating statistics for the next EM iteration
	 */
	private void performM(Document d, LDAVariationlState vstate, LDAModel nextState) {
		int docWordIndex = 0;
		for (final Entry entry : d.getVector().entries()) {
			final int wordIndex = entry.index;
			final int count = entry.value;
			for (int topicIndex = 0; topicIndex < ntopics; topicIndex++) {
				// note that phi is indexed by position in the document's
				// unique-word list, not by vocabulary index
				final double phi = vstate.phi[docWordIndex][topicIndex];
				nextState.incTopicWord(topicIndex, wordIndex, count * phi);
				nextState.incTopicTotal(topicIndex, count * phi);
			}
			docWordIndex++;
		}
	}
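
	/*
	 * Note: from these sufficient statistics the (unsmoothed) maximum
	 * likelihood topic-word distribution can be recovered as
	 *
	 *   beta[topic][word] = topicWord[topic][word] / topicTotal[topic]
	 *
	 * which is exactly the ratio evaluated in log space as
	 * log(topicWord) - log(topicTotal) in performE and computeLikelihood.
	 */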
	/**
	 * Calculates a lower bound for the log likelihood of a document given the
	 * current parameters. When this is maximised it minimises the KL
	 * divergence between the variational posterior and the true posterior.
	 *
	 * The derivation can be seen in the appendix of Blei et al.'s 2003 LDA
	 * paper.
	 *
	 * @param doc
	 *            the document
	 * @param vstate
	 *            the document's variational state
	 * @return the lower bound on the log likelihood
	 */
	public double computeLikelihood(Document doc, LDAVariationlState vstate) {
		double likelihood = 0;

		// Prepare some variables we need
		double sumVarGamma = 0;
		double sumDiGamma = 0;
		for (int topicIndex = 0; topicIndex < ntopics; topicIndex++) {
			sumVarGamma += vstate.varGamma[topicIndex];
			vstate.digamma[topicIndex] = Gamma
					.digamma(vstate.varGamma[topicIndex]);
			sumDiGamma += vstate.digamma[topicIndex];
		}

		// First sum the terms that rely on neither an iteration over the
		// topics nor an iteration over the document's words
		likelihood += Gamma.logGamma(vstate.state.alpha * ntopics) - // eqn (15) line 1
				Gamma.logGamma(vstate.state.alpha) * ntopics - // eqn (15) line 1
				Gamma.logGamma(sumVarGamma); // eqn (15) line 4
		for (int topicIndex = 0; topicIndex < ntopics; topicIndex++) {
			// Now add the terms that just need an iteration over k
			final double topicGammaDiff = vstate.digamma[topicIndex] - sumDiGamma;
			likelihood += (vstate.state.alpha - 1) * topicGammaDiff + // eqn (15) line 1
					Gamma.logGamma(vstate.varGamma[topicIndex]) - // eqn (15) line 4
					(vstate.varGamma[topicIndex] - 1) * topicGammaDiff; // eqn (15) line 4
			int wordIndex = 0;
			for (final Entry wordCount : doc.getVector().entries()) {
				final int word = wordCount.index;
				final int count = wordCount.value;
				final double phi = vstate.phi[wordIndex][topicIndex];
				// guard against 0 * log(0) producing NaN
				if (phi > 0) {
					final double logBeta = Math.log(
							vstate.state.topicWord[topicIndex][word]) -
							Math.log(vstate.state.topicTotal[topicIndex]);
					// Multiply by count because these sums are over N, and
					// the counts of the document's unique words sum to N;
					// each line happens to multiply by the current word's phi
					likelihood += count * (phi * (
							// eqn (15) line 2
							topicGammaDiff +
							// eqn (15) line 3
							logBeta -
							// eqn (15) line 5
							Math.log(phi)));
				}
				wordIndex++;
			}
		}
		return likelihood;
	}
}