/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.pgm.vb.lda.mle;

import java.util.HashMap;
import java.util.Map;

import org.apache.commons.math.special.Gamma;
import org.openimaj.math.util.MathUtils;
import org.openimaj.pgm.util.Corpus;
import org.openimaj.pgm.util.Document;
import org.openimaj.util.array.SparseIntArray.Entry;

/**
 * An implementation of LDA learned through variational inference. The
 * resulting model can be saved and loaded.
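 * <p>
 * A minimal usage sketch (assuming a {@link Corpus} has already been built;
 * the {@code corpus} variable here is illustrative):
 *
 * <pre>
 * {@code
 * Corpus corpus = ...; // a bag-of-words corpus
 * LDALearner learner = new LDALearner(20); // learn 20 topics
 * learner.estimate(corpus); // fit the model with variational EM
 * }
 * </pre>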
 *
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 */
public class LDALearner {
	private int ntopics;
	private Map<LDAConfig, Object> config = new HashMap<LDAConfig, Object>();

	/**
	 * Configuration parameters for the learner; each key provides its own
	 * default value.
	 */
	enum LDAConfig {
		MAX_ITERATIONS {
			@Override
			public Integer defaultValue() {
				return 10;
			}
		},
		ALPHA {
			@Override
			public Double defaultValue() {
				return 0.3d;
			}
		},
		VAR_MAX_ITERATIONS {
			@Override
			public Integer defaultValue() {
				return 10;
			}
		},
		INIT_STRATEGY {
			@Override
			public LDABetaInitStrategy defaultValue() {
				return new LDABetaInitStrategy.RandomBetaInit();
			}
		},
		EM_CONVERGED {
			@Override
			public Double defaultValue() {
				return 1e-5;
			}
		},
		VAR_EM_CONVERGED {
			@Override
			public Double defaultValue() {
				return 1e-5;
			}
		};

		public abstract Object defaultValue();
	}

	/**
	 * Construct with the given number of topics.
	 *
	 * @param ntopics
	 *            the number of topics to learn
	 */
	public LDALearner(int ntopics) {
		this.ntopics = ntopics;
	}

	/**
	 * Looks up a configuration parameter, falling back to the key's default
	 * value if no override has been set.
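	 * <p>
	 * For example (a sketch; {@code learner} is an existing instance with no
	 * override set, so the key's default of 10 is returned):
	 *
	 * <pre>
	 * {@code
	 * Integer maxIter = learner.getConfig(LDAConfig.MAX_ITERATIONS);
	 * }
	 * </pre>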
	 *
	 * @param key
	 *            the configuration parameter
	 * @return the configured value, or the key's default if unset
	 */
	@SuppressWarnings("unchecked")
	public <T> T getConfig(LDAConfig key) {
		final T val = (T) this.config.get(key);
		if (val == null)
			return (T) key.defaultValue();
		return val;
	}

	/**
	 * Initiates the EM algorithm on the documents of the given corpus.
	 *
	 * @param corpus
	 *            the corpus to train on
	 */
	public void estimate(Corpus corpus) {
		performEM(corpus);
	}

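	/*
	 * Overview of the EM loop below (after Blei et al. 2003, sec. 5.3): the
	 * E-step fits the per-document variational parameters (gamma and phi)
	 * against the current model's beta, and the M-step re-estimates the
	 * maximum-likelihood beta from the statistics gathered in the E-step;
	 * the two alternate until the likelihood bound settles.
	 */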
	private void performEM(Corpus corpus) {
		// configuration
		final double initialAlpha = (Double) this.getConfig(LDAConfig.ALPHA);
		final LDABetaInitStrategy initStrat = this.getConfig(LDAConfig.INIT_STRATEGY);

		// initialise the first state
		final LDAModel state = new LDAModel(this.ntopics);
		state.prepare(corpus);
		state.setAlpha(initialAlpha);
		initStrat.initModel(state, corpus);

		final LDAVariationlState vstate = new LDAVariationlState(state);
		while (!modelConverged(vstate.state)) {
			final LDAModel nextState = vstate.state.newInstance();
			nextState.setAlpha(initialAlpha);
			for (final Document doc : corpus.getDocuments()) {
				vstate.prepare(doc);
				// E-step: update the variational parameters given the current beta
				performE(doc, vstate);
				// M-step: update nextState's sufficient statistics given the
				// variational parameters
				performM(doc, vstate, nextState);
				nextState.likelihood += vstate.likelihood;
			}
			nextState.iteration++;
			vstate.state = nextState;
		}
	}

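	/*
	 * The E-step below iterates the coordinate-ascent updates of eqns (16)
	 * and (17) of Blei et al. 2003, written here in plain notation:
	 *
	 *   phi_{ni} proportional to beta_{i,w_n} * exp(digamma(gamma_i))
	 *   gamma_i = alpha_i + sum_n phi_{ni}
	 *
	 * where n ranges over the words of the document and i over the topics.
	 * The phi updates are performed in log space for numerical stability.
	 */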
	private LDAVariationlState performE(Document doc, LDAVariationlState vstate) {
		vstate.prepare(doc);
		while (!variationalStateConverged(vstate)) {
			int docWordIndex = 0;
			for (final Entry wordCount : doc.getVector().entries()) {
				double phiSum = 0;
				final int word = wordCount.index;
				final int count = wordCount.value;
				for (int topicIndex = 0; topicIndex < vstate.phi.length; topicIndex++) {
					vstate.oldphi[topicIndex] = vstate.phi[docWordIndex][topicIndex];
					// If this word has been seen in this topic before
					if (vstate.state.topicWord[topicIndex][word] > 0) {
						// Update phi. Remember this phi is actually the same
						// value for every instance of this particular word;
						// wherever phi is used it is likely to be multiplied by
						// the number of times this word appears in the document.
						// From eqn (16) in Blei et al. 2003. The digamma of the
						// gamma sum cancels when the exact phi for a given word
						// is calculated, so it is omitted here.
						final double logBeta =
								Math.log(vstate.state.topicWord[topicIndex][word]) -
										Math.log(vstate.state.topicTotal[topicIndex]);
						vstate.phi[docWordIndex][topicIndex] =
								logBeta +
										Gamma.digamma(vstate.varGamma[topicIndex]);
					} else {
						// If not, beta_wi is the tiny smoothing constant ETA
						// (~1e-34), so log(beta_wi) is approximated by -100
						vstate.phi[docWordIndex][topicIndex] = Gamma.digamma(vstate.varGamma[topicIndex]) - 100;
					}
					if (topicIndex == 0) {
						phiSum = vstate.phi[docWordIndex][topicIndex];
					} else {
						// We need log phiSum = log Sum_i{phi_i} over the K
						// topics, but what we hold is log phi, so we must
						// accumulate log(a + b) from log(a) and log(b).
						// This is the normaliser for eqn (16).
						phiSum = MathUtils.logSum(phiSum,
								vstate.phi[docWordIndex][topicIndex]);
					}
				}
				for (int topicIndex = 0; topicIndex < vstate.phi.length; topicIndex++) {
					// Replace log phi with the normalised phi
					// (normalising a given word's phi over all topics in eqn (16))
					vstate.phi[docWordIndex][topicIndex] = Math.exp(
							vstate.phi[docWordIndex][topicIndex] - phiSum);
					// Update gamma incrementally (eqn (17) in Blei et al. 2003):
					// take away the old phi, add the new phi, and do this count
					// times for the number of occurrences of this word in the
					// document
					vstate.varGamma[topicIndex] += count
							* (vstate.phi[docWordIndex][topicIndex] - vstate.oldphi[topicIndex]);
				}
				docWordIndex++;
			}
			vstate.oldLikelihood = vstate.likelihood;
			vstate.likelihood = computeLikelihood(doc, vstate);
			vstate.iteration++;
		}
		return vstate;
	}

	private boolean modelConverged(LDAModel model) {
		final double EM_CONVERGED = (Double) this.getConfig(LDAConfig.EM_CONVERGED);
		final int MAX_ITER = (Integer) this.getConfig(LDAConfig.MAX_ITERATIONS);
		// If likelihood ~= oldLikelihood this ratio approaches 0. At least a
		// few iterations are always performed before convergence is declared.
		final double converged = (model.likelihood - model.oldLikelihood) / model.oldLikelihood;
		final boolean likelihoodSettled = (Math.abs(converged) < EM_CONVERGED) && (model.iteration > 2);
		final boolean maxIterExceeded = model.iteration > MAX_ITER;

		return likelihoodSettled || maxIterExceeded;
	}

	private boolean variationalStateConverged(LDAVariationlState vstate) {
		final double EM_CONVERGED = (Double) this.getConfig(LDAConfig.VAR_EM_CONVERGED);
		final int MAX_ITER = (Integer) this.getConfig(LDAConfig.VAR_MAX_ITERATIONS);
		// If likelihood ~= oldLikelihood this ratio approaches 0. At least a
		// few iterations are always performed before convergence is declared.
		final double converged = (vstate.likelihood - vstate.oldLikelihood) / vstate.oldLikelihood;
		final boolean likelihoodSettled = (Math.abs(converged) < EM_CONVERGED) && (vstate.iteration > 2);
		final boolean maxIterExceeded = vstate.iteration > MAX_ITER;

		return likelihoodSettled || maxIterExceeded;
	}

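	/*
	 * The M-step below accumulates the sufficient statistics of the
	 * maximum-likelihood beta update from Blei et al. 2003:
	 *
	 *   beta_{ij} proportional to sum_d sum_n phi_{dni} * w_{dn}^j
	 *
	 * i.e. each occurrence of word j adds its phi for topic i to the
	 * topic/word cell, and the per-topic totals act as the normaliser.
	 */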
	/**
	 * Given the current state of the variational parameters, update the
	 * maximum-likelihood beta parameter by updating its sufficient
	 * statistics.
	 *
	 * @param d
	 *            the current document
	 * @param vstate
	 *            the variational state holding this document's phi
	 * @param nextState
	 *            the model accumulating statistics for the next EM iteration
	 */
	private void performM(Document d, LDAVariationlState vstate, LDAModel nextState) {
		int docWordIndex = 0;
		for (final Entry entry : d.getVector().entries()) {
			final int wordIndex = entry.index;
			final int count = entry.value;
			for (int topicIndex = 0; topicIndex < ntopics; topicIndex++) {
				// phi is indexed by the word's position in the document's
				// entry list (as in performE), not by its vocabulary index
				nextState.incTopicWord(topicIndex, wordIndex, count * vstate.phi[docWordIndex][topicIndex]);
				// the per-topic total accumulates the same mass, so that it
				// correctly normalises beta
				nextState.incTopicTotal(topicIndex, count * vstate.phi[docWordIndex][topicIndex]);
			}
			docWordIndex++;
		}
	}

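	/*
	 * For reference, the five "lines" of eqn (15) in Blei et al. 2003 that the
	 * comments below point at are the five expectations making up the bound:
	 *
	 *   L = E[log p(theta | alpha)]   (line 1)
	 *     + E[log p(z | theta)]       (line 2)
	 *     + E[log p(w | z, beta)]     (line 3)
	 *     - E[log q(theta | gamma)]   (line 4)
	 *     - E[log q(z | phi)]         (line 5)
	 *
	 * with all expectations taken under the variational distribution q.
	 */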
	/**
	 * Calculates a lower bound for the log likelihood of a document given the
	 * current parameters. When this is maximised it minimises the KL
	 * divergence between the variational posterior and the true posterior.
	 *
	 * The derivation can be seen in the appendix of Blei et al.'s 2003 LDA
	 * paper.
	 *
	 * @param doc
	 *            the document
	 * @param vstate
	 *            the current variational state
	 * @return the likelihood lower bound
	 */
	public double computeLikelihood(Document doc, LDAVariationlState vstate) {
		double likelihood = 0;

		// Prepare some values we need
		double sumVarGamma = 0;
		for (int topicIndex = 0; topicIndex < ntopics; topicIndex++) {
			sumVarGamma += vstate.varGamma[topicIndex];
			vstate.digamma[topicIndex] = Gamma.digamma(vstate.varGamma[topicIndex]);
		}
		// Under q, E[log theta_i] = digamma(gamma_i) - digamma(sum_j gamma_j)
		final double digammaSumVarGamma = Gamma.digamma(sumVarGamma);

		// First add the terms which rely on iteration through neither the
		// topics nor the document's words: eqn (15) lines 1 and 4
		likelihood += Gamma.logGamma(vstate.state.alpha * ntopics)
				- Gamma.logGamma(vstate.state.alpha) * ntopics
				- Gamma.logGamma(sumVarGamma);
		for (int topicIndex = 0; topicIndex < ntopics; topicIndex++) {
			// Now add the terms that just need an iteration over the topics:
			// the remainder of eqn (15) lines 1 and 4
			final double topicGammaDiff = vstate.digamma[topicIndex] - digammaSumVarGamma;
			likelihood += (vstate.state.alpha - 1) * topicGammaDiff
					+ Gamma.logGamma(vstate.varGamma[topicIndex])
					- (vstate.varGamma[topicIndex] - 1) * topicGammaDiff;
			int wordIndex = 0;
			for (final Entry wordCount : doc.getVector().entries()) {
				final int word = wordCount.index;
				final int count = wordCount.value;
				final double logBeta = Math.log(vstate.state.topicWord[topicIndex][word])
						- Math.log(vstate.state.topicTotal[topicIndex]);
				// Multiply by count because these sums are over N, and the sum
				// of the counts of each unique word is N; each term inside also
				// multiplies by the current word's phi.
				likelihood += count * (vstate.phi[wordIndex][topicIndex] * (
						// eqn (15) line 2
						topicGammaDiff +
								// eqn (15) line 3
								logBeta -
								// eqn (15) line 5
								Math.log(vstate.phi[wordIndex][topicIndex])));
				wordIndex++;
			}
		}
		return likelihood;
	}
}