/*
 * NOTE(review): the original file header comment (licence block) was lost in
 * extraction, leaving only bare line-number residue here. Restore the
 * project's standard header before committing.
 */
30 package org.openimaj.pgm.vb.lda.mle;
31
32 import java.util.HashMap;
33 import java.util.Map;
34
35 import org.apache.commons.math.special.Gamma;
36 import org.openimaj.math.util.MathUtils;
37 import org.openimaj.pgm.util.Corpus;
38 import org.openimaj.pgm.util.Document;
39 import org.openimaj.util.array.SparseIntArray.Entry;
40
41
42
43
44
45
46
47 public class LDALearner {
48 private int ntopics;
49 private Map<LDAConfig, Object> config = new HashMap<LDAConfig, Object>();
50
51 enum LDAConfig {
52 MAX_ITERATIONS {
53 @Override
54 public Integer defaultValue() {
55 return 10;
56 }
57 },
58 ALPHA {
59 @Override
60 public Double defaultValue() {
61 return 0.3d;
62 }
63 },
64 VAR_MAX_ITERATIONS {
65 @Override
66 public Integer defaultValue() {
67 return 10;
68 }
69 },
70 INIT_STRATEGY {
71
72 @Override
73 public LDABetaInitStrategy defaultValue() {
74 return new LDABetaInitStrategy.RandomBetaInit();
75 }
76
77 },
78 EM_CONVERGED {
79
80 @Override
81 public Double defaultValue() {
82 return 1e-5;
83 }
84
85 },
86 VAR_EM_CONVERGED {
87
88 @Override
89 public Double defaultValue() {
90 return 1e-5;
91 }
92
93 };
94 public abstract Object defaultValue();
95 }
96
97
98
99
100 public LDALearner(int ntopics) {
101 this.ntopics = ntopics;
102 }
103
104
105
106
107
108 @SuppressWarnings("unchecked")
109 public <T> T getConfig(LDAConfig key) {
110 final T val = (T) this.config.get(key);
111 if (val == null)
112 return (T) key.defaultValue();
113 return val;
114 }
115
116
117
118
119
120
121 public void estimate(Corpus corpus) {
122 performEM(corpus);
123 }
124
125 private void performEM(Corpus corpus) {
126
127 final double initialAlpha = (Double) this.getConfig(LDAConfig.ALPHA);
128 final LDABetaInitStrategy initStrat = this.getConfig(LDAConfig.INIT_STRATEGY);
129
130
131 final LDAModel state = new LDAModel(this.ntopics);
132 state.prepare(corpus);
133 state.setAlpha(initialAlpha);
134 initStrat.initModel(state, corpus);
135
136 final LDAVariationlState vstate = new LDAVariationlState(state);
137 while (modelConverged(vstate.state)) {
138 final LDAModel nextState = vstate.state.newInstance();
139 nextState.setAlpha(initialAlpha);
140 for (final Document doc : corpus.getDocuments()) {
141 vstate.prepare(doc);
142 performE(doc, vstate);
143
144 performM(doc, vstate, nextState);
145
146
147 nextState.likelihood += vstate.likelihood;
148 }
149 nextState.iteration++;
150 vstate.state = nextState;
151 }
152 }
153
154 private LDAVariationlState performE(Document doc, LDAVariationlState vstate) {
155 vstate.prepare(doc);
156 while (!variationalStateConverged(vstate)) {
157 int docWordIndex = 0;
158 for (final Entry wordCount : doc.getVector().entries()) {
159 double phiSum = 0;
160 final int word = wordCount.index;
161 final int count = wordCount.value;
162 for (int topicIndex = 0; topicIndex < vstate.phi.length; topicIndex++) {
163 vstate.oldphi[topicIndex] = vstate.phi[docWordIndex][topicIndex];
164
165 if (vstate.state.topicWord[topicIndex][docWordIndex] > 0) {
166
167
168
169
170
171
172
173
174
175 final double logBeta =
176 Math.log(vstate.state.topicWord[topicIndex][word]) -
177 Math.log(vstate.state.topicTotal[topicIndex]);
178 vstate.phi[docWordIndex][topicIndex] =
179 logBeta +
180 Gamma.digamma(vstate.varGamma[topicIndex]);
181 } else {
182
183
184 vstate.phi[docWordIndex][topicIndex] = Gamma.digamma(vstate.varGamma[topicIndex]) - 100;
185 }
186 if (topicIndex == 0) {
187 phiSum = vstate.phi[docWordIndex][topicIndex];
188 } else {
189
190
191
192
193
194 phiSum = MathUtils.logSum(phiSum,
195 vstate.phi[docWordIndex][topicIndex]);
196 }
197 }
198 for (int topicIndex = 0; topicIndex < vstate.phi.length; topicIndex++) {
199
200
201 vstate.phi[docWordIndex][topicIndex] = Math.exp(
202 vstate.phi[docWordIndex][topicIndex] - phiSum
203 );
204
205
206
207
208
209 vstate.varGamma[topicIndex] += count
210 * (vstate.phi[docWordIndex][topicIndex] - vstate.oldphi[topicIndex]);
211 }
212 docWordIndex++;
213 }
214 vstate.oldLikelihood = vstate.likelihood;
215 vstate.likelihood = computeLikelihood(doc, vstate);
216 vstate.iteration++;
217 }
218 return vstate;
219 }
220
221 private boolean modelConverged(LDAModel model) {
222 final double EM_CONVERGED = (Double) this.getConfig(LDAConfig.EM_CONVERGED);
223 final int MAX_ITER = (Integer) this.getConfig(LDAConfig.MAX_ITERATIONS);
224
225 final double converged = (model.likelihood - model.oldLikelihood) / model.oldLikelihood;
226 final boolean liklihoodSettled = ((converged < EM_CONVERGED) || (model.iteration <= 2));
227 final boolean maxIterExceeded = model.iteration > MAX_ITER;
228
229 return liklihoodSettled || maxIterExceeded;
230 }
231
232 private boolean variationalStateConverged(LDAVariationlState vstate) {
233 final double EM_CONVERGED = (Double) this.getConfig(LDAConfig.VAR_EM_CONVERGED);
234 final int MAX_ITER = (Integer) this.getConfig(LDAConfig.VAR_MAX_ITERATIONS);
235
236 final double converged = (vstate.likelihood - vstate.oldLikelihood) / vstate.oldLikelihood;
237 final boolean liklihoodSettled = ((converged < EM_CONVERGED) || (vstate.iteration <= 2));
238 final boolean maxIterExceeded = vstate.iteration > MAX_ITER;
239
240 return liklihoodSettled || maxIterExceeded;
241 }
242
243
244
245
246
247
248
249
250
251 private void performM(Document d, LDAVariationlState vstate, LDAModel nextState) {
252
253 for (final Entry entry : d.values.entries()) {
254 for (int topicIndex = 0; topicIndex < ntopics; topicIndex++) {
255 final int wordIndex = entry.index;
256 final int count = entry.value;
257 nextState.incTopicWord(topicIndex, wordIndex, count * vstate.phi[wordIndex][topicIndex]);
258 nextState.incTopicTotal(topicIndex, count);
259 }
260 }
261 }
262
263
264
265
266
267
268
269
270
271
272
273
274 public double computeLikelihood(Document doc, LDAVariationlState vstate) {
275 double likelihood = 0;
276
277
278 double sumVarGamma = 0;
279 double sumDiGamma = 0;
280 for (int topicIndex = 0; topicIndex < ntopics; topicIndex++) {
281 sumVarGamma += vstate.varGamma[topicIndex];
282 vstate.digamma[topicIndex] = Gamma
283 .digamma(vstate.varGamma[topicIndex]);
284 sumDiGamma += vstate.digamma[topicIndex];
285 }
286
287
288
289
290 likelihood += Gamma.logGamma(vstate.state.alpha * ntopics) -
291
292
293 Gamma.logGamma(vstate.state.alpha) * ntopics +
294
295 Gamma.logGamma(sumVarGamma);
296 for (int topicIndex = 0; topicIndex < ntopics; topicIndex++) {
297
298
299 final double topicGammaDiff = vstate.digamma[topicIndex] - sumDiGamma;
300 likelihood += Gamma.logGamma(vstate.varGamma[topicIndex]) - (vstate.varGamma[topicIndex] - 1)
301 * topicGammaDiff;
302 int wordIndex = 0;
303 for (final Entry wordCount : doc.getVector().entries()) {
304 final int word = wordCount.index;
305 final int count = wordCount.value;
306 final double logBeta = Math.log(
307 vstate.state.topicWord[topicIndex][word]) -
308 Math.log(vstate.state.topicTotal[topicIndex]
309 );
310 likelihood +=
311
312
313 count * (
314
315
316 vstate.phi[wordIndex][topicIndex] * (
317
318 topicGammaDiff +
319
320 count * logBeta -
321
322 Math.log(vstate.phi[wordIndex][topicIndex]
323 )
324 )
325 );
326 wordIndex++;
327 }
328 }
329 return likelihood;
330 }
331 }