/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   *  Redistributions of source code must retain the above copyright notice,
 *      this list of conditions and the following disclaimer.
 *
 *   *  Redistributions in binary form must reproduce the above copyright notice,
 *      this list of conditions and the following disclaimer in the documentation
 *      and/or other materials provided with the distribution.
 *
 *   *  Neither the name of the University of Southampton nor the names of its
 *      contributors may be used to endorse or promote products derived from this
 *      software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.pgm.vb.lda.mle;

import java.util.HashMap;
import java.util.Map;

import org.apache.commons.math.special.Gamma;
import org.openimaj.math.util.MathUtils;
import org.openimaj.pgm.util.Corpus;
import org.openimaj.pgm.util.Document;
import org.openimaj.util.array.SparseIntArray.Entry;
/**
 * An implementation of maximum-likelihood LDA trained by variational
 * inference. The learnt model can be saved and loaded.
 *
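 * A minimal usage sketch (corpus construction is elided here; any
 * {@link Corpus} implementation can be used):
 *
 * <pre>
 * {@code
 * Corpus corpus = ...; // build or load a corpus of documents
 * LDALearner learner = new LDALearner(20); // learn 20 topics
 * learner.estimate(corpus);
 * }
 * </pre>
 *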
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 */
public class LDALearner {
	private int ntopics;
	private Map<LDAConfig, Object> config = new HashMap<LDAConfig, Object>();

	enum LDAConfig {
		MAX_ITERATIONS {
			@Override
			public Integer defaultValue() {
				return 10;
			}
		},
		ALPHA {
			@Override
			public Double defaultValue() {
				return 0.3d;
			}
		},
		VAR_MAX_ITERATIONS {
			@Override
			public Integer defaultValue() {
				return 10;
			}
		},
		INIT_STRATEGY {
			@Override
			public LDABetaInitStrategy defaultValue() {
				return new LDABetaInitStrategy.RandomBetaInit();
			}
		},
		EM_CONVERGED {
			@Override
			public Double defaultValue() {
				return 1e-5;
			}
		},
		VAR_EM_CONVERGED {
			@Override
			public Double defaultValue() {
				return 1e-5;
			}
		};

		public abstract Object defaultValue();
	}

	/**
	 * @param ntopics
	 *            the number of topics to learn
	 */
	public LDALearner(int ntopics) {
		this.ntopics = ntopics;
	}

	/**
	 * @param key
	 *            the configuration parameter
	 * @return the configured value, or the parameter's default if unset
	 */
	@SuppressWarnings("unchecked")
	public <T> T getConfig(LDAConfig key) {
		final T val = (T) this.config.get(key);
		if (val == null)
			return (T) key.defaultValue();
		return val;
	}

	/**
	 * Initiates the EM algorithm on the documents in the corpus.
	 *
	 * @param corpus
	 *            the corpus of documents to learn from
	 */
	public void estimate(Corpus corpus) {
		performEM(corpus);
	}

	private void performEM(Corpus corpus) {
		// some variables
		final double initialAlpha = (Double) this.getConfig(LDAConfig.ALPHA);
		final LDABetaInitStrategy initStrat = this.getConfig(LDAConfig.INIT_STRATEGY);

		// initialise the first state
		final LDAModel state = new LDAModel(this.ntopics);
		state.prepare(corpus);
		state.setAlpha(initialAlpha);
		initStrat.initModel(state, corpus);

		final LDAVariationlState vstate = new LDAVariationlState(state);
		while (!modelConverged(vstate.state)) {
			final LDAModel nextState = vstate.state.newInstance();
			nextState.setAlpha(initialAlpha);
			for (final Document doc : corpus.getDocuments()) {
				// update the variational parameters given the current beta
				performE(doc, vstate);
				// update nextState's sufficient statistics given the new
				// variational parameters
				performM(doc, vstate, nextState);
				nextState.likelihood += vstate.likelihood;
			}
			// carry the iteration count and previous likelihood over so the
			// convergence test compares successive EM rounds
			nextState.oldLikelihood = vstate.state.likelihood;
			nextState.iteration = vstate.state.iteration + 1;
			vstate.state = nextState;
		}
	}
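
	/*
	 * The E-step performs coordinate ascent on the variational parameters
	 * (eqns 16 and 17 of Blei et al. 2003):
	 *
	 *   phi_{ni} proportional to beta_{i,w_n} * exp(digamma(gamma_i))
	 *   gamma_i = alpha_i + Sum_n{phi_{ni}}
	 *
	 * The code below works with log phi for numerical stability and
	 * normalises each word's phi with a log-sum-exp before exponentiating.
	 */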
	private LDAVariationlState performE(Document doc, LDAVariationlState vstate) {
		vstate.prepare(doc);
		while (!variationalStateConverged(vstate)) {
			int docWordIndex = 0;
			for (final Entry wordCount : doc.getVector().entries()) {
				double phiSum = 0;
				final int word = wordCount.index;
				final int count = wordCount.value;
				for (int topicIndex = 0; topicIndex < ntopics; topicIndex++) {
					vstate.oldphi[topicIndex] = vstate.phi[docWordIndex][topicIndex];
					// If this word has been seen in this topic before
					if (vstate.state.topicWord[topicIndex][word] > 0) {
						// Update phi. Remember this phi is actually the same
						// value for every instance of this particular word;
						// wherever phi is used it is likely to be multiplied
						// by the number of times this word appears in this
						// document.
						// From eqn (16) in Blei et al. 2003; the digamma of
						// the gamma sum cancels when phi is normalised for a
						// given word.
						final double logBeta =
								Math.log(vstate.state.topicWord[topicIndex][word]) -
										Math.log(vstate.state.topicTotal[topicIndex]);
						vstate.phi[docWordIndex][topicIndex] =
								logBeta +
										Gamma.digamma(vstate.varGamma[topicIndex]);
					} else {
						// if not, beta_wi is a vanishingly small constant ETA,
						// so log beta_wi is approximated by -100
						vstate.phi[docWordIndex][topicIndex] = Gamma.digamma(vstate.varGamma[topicIndex]) - 100;
					}
					if (topicIndex == 0) {
						phiSum = vstate.phi[docWordIndex][topicIndex];
					} else {
						// We need phiSum = Sum_i{phi_i}, but what we hold is
						// log phi, so we must calculate log(a + b) from
						// log(a) and log(b). This is the normaliser for
						// eqn (16).
						phiSum = MathUtils.logSum(phiSum,
								vstate.phi[docWordIndex][topicIndex]);
					}
				}
				for (int topicIndex = 0; topicIndex < ntopics; topicIndex++) {
					// Replace log phi with the normalised phi (normalising a
					// given word's phi over all topics in eqn (16))
					vstate.phi[docWordIndex][topicIndex] = Math.exp(
							vstate.phi[docWordIndex][topicIndex] - phiSum
							);
					// Update gamma incrementally (eqn (17) of Blei et al.
					// 2003): take away the old phi, add the new phi, and do
					// this count times for the number of times this
					// particular word appears in this document.
					vstate.varGamma[topicIndex] += count
							* (vstate.phi[docWordIndex][topicIndex] - vstate.oldphi[topicIndex]);
				}
				docWordIndex++;
			}
			vstate.oldLikelihood = vstate.likelihood;
			vstate.likelihood = computeLikelihood(doc, vstate);
			vstate.iteration++;
		}
		return vstate;
	}
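
	/*
	 * Both convergence tests below use the relative change in the lower
	 * bound, (L_t - L_{t-1}) / L_{t-1}: once its magnitude falls below the
	 * configured threshold (after a minimum of two iterations), or the
	 * iteration budget is exhausted, the corresponding loop stops.
	 */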

	private boolean modelConverged(LDAModel model) {
		final double EM_CONVERGED = (Double) this.getConfig(LDAConfig.EM_CONVERGED);
		final int MAX_ITER = (Integer) this.getConfig(LDAConfig.MAX_ITERATIONS);
		// if likelihood ~= oldLikelihood then this value approaches 0
		final double converged = (model.likelihood - model.oldLikelihood) / model.oldLikelihood;
		final boolean likelihoodSettled = Math.abs(converged) < EM_CONVERGED && model.iteration > 2;
		final boolean maxIterExceeded = model.iteration > MAX_ITER;

		return likelihoodSettled || maxIterExceeded;
	}

	private boolean variationalStateConverged(LDAVariationlState vstate) {
		final double VAR_CONVERGED = (Double) this.getConfig(LDAConfig.VAR_EM_CONVERGED);
		final int MAX_ITER = (Integer) this.getConfig(LDAConfig.VAR_MAX_ITERATIONS);
		// if likelihood ~= oldLikelihood then this value approaches 0
		final double converged = (vstate.likelihood - vstate.oldLikelihood) / vstate.oldLikelihood;
		final boolean likelihoodSettled = Math.abs(converged) < VAR_CONVERGED && vstate.iteration > 2;
		final boolean maxIterExceeded = vstate.iteration > MAX_ITER;

		return likelihoodSettled || maxIterExceeded;
	}

	/**
	 * Given the current state of the variational parameters, update the
	 * maximum likelihood beta parameter by updating its sufficient
	 * statistics.
	 *
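	 * The statistics follow eqn (9) of Blei et al. 2003: beta_{ij} is
	 * proportional to Sum_d Sum_n phi_{dni} w_{dn}^j. topicWord accumulates
	 * the numerator and topicTotal the per-topic normaliser, so that
	 * beta = topicWord / topicTotal.
	 *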
	 * @param d
	 *            the document
	 * @param vstate
	 *            the variational state holding the current phi
	 * @param nextState
	 *            the model whose sufficient statistics are updated
	 */
	private void performM(Document d, LDAVariationlState vstate, LDAModel nextState) {
		int docWordIndex = 0;
		for (final Entry entry : d.getVector().entries()) {
			final int wordIndex = entry.index;
			final int count = entry.value;
			for (int topicIndex = 0; topicIndex < ntopics; topicIndex++) {
				// phi is indexed by position in the document, not by
				// vocabulary index; the normaliser accumulates count * phi
				// so that each topic's beta sums to one
				final double phi = vstate.phi[docWordIndex][topicIndex];
				nextState.incTopicWord(topicIndex, wordIndex, count * phi);
				nextState.incTopicTotal(topicIndex, count * phi);
			}
			docWordIndex++;
		}
	}

	/**
	 * Calculates a lower bound for the log likelihood of a document given
	 * the current parameters. Maximising this bound minimises the KL
	 * divergence between the variational posterior and the true posterior.
	 *
	 * The derivation can be seen in the appendix of Blei et al.'s 2003 LDA
	 * paper.
	 *
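	 * Concretely this is the bound of eqn (15) in that paper,
	 * E_q[log p(theta, z, w | alpha, beta)] - E_q[log q(theta, z)], whose
	 * five lines are referenced in the comments below.
	 *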
	 * @param doc
	 *            the document
	 * @param vstate
	 *            the variational state
	 * @return the lower bound on the log likelihood
	 */
	public double computeLikelihood(Document doc, LDAVariationlState vstate) {
		double likelihood = 0;

		// Prepare some variables we need
		double sumVarGamma = 0;
		for (int topicIndex = 0; topicIndex < ntopics; topicIndex++) {
			sumVarGamma += vstate.varGamma[topicIndex];
			vstate.digamma[topicIndex] = Gamma
					.digamma(vstate.varGamma[topicIndex]);
		}
		// eqn (15) needs digamma(Sum_i{gamma_i}): the digamma of the sum,
		// not the sum of the digammas
		final double digammaSumVarGamma = Gamma.digamma(sumVarGamma);

		// First sum the terms which don't rely on iteration through the
		// topics or through the documents
		likelihood += Gamma.logGamma(vstate.state.alpha * ntopics) // eqn (15) line 1
				- Gamma.logGamma(vstate.state.alpha) * ntopics // eqn (15) line 1
				- Gamma.logGamma(sumVarGamma); // eqn (15) line 4
		for (int topicIndex = 0; topicIndex < ntopics; topicIndex++) {
			// Now add the terms that just need an iteration over k
			final double topicGammaDiff = vstate.digamma[topicIndex] - digammaSumVarGamma;
			likelihood += (vstate.state.alpha - 1) * topicGammaDiff // eqn (15) line 1
					+ Gamma.logGamma(vstate.varGamma[topicIndex]) // eqn (15) line 4
					- (vstate.varGamma[topicIndex] - 1) * topicGammaDiff; // eqn (15) line 4
			int docWordIndex = 0;
			for (final Entry wordCount : doc.getVector().entries()) {
				final int word = wordCount.index;
				final int count = wordCount.value;
				final double phi = vstate.phi[docWordIndex][topicIndex];
				if (phi > 0) {
					final double logBeta = Math.log(
							vstate.state.topicWord[topicIndex][word]) -
							Math.log(vstate.state.topicTotal[topicIndex]);
					// Multiply by count because these sums are over all N
					// word slots and the counts of each unique word sum to N
					likelihood += count * phi * (
							topicGammaDiff // eqn (15) line 2
									+ logBeta // eqn (15) line 3
									- Math.log(phi) // eqn (15) line 5
							);
				}
				docWordIndex++;
			}
		}
		return likelihood;
	}
}