001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.text.nlp.language;
031
032import java.io.IOException;
033import java.io.InputStream;
034import java.io.UnsupportedEncodingException;
035import java.util.HashMap;
036import java.util.Locale;
037import java.util.Map;
038import java.util.Random;
039
040import Jama.Matrix;
041
/**
 * Code to train and generate language specific text by building a first order
 * Markov chain over the bytes of the training text.
 * <p>
 * Every language owns a table of 257 states: state 0 is an artificial start
 * state (the state the chain is in before the first byte of an example) and
 * state {@code b + 1} stands for the unsigned byte value {@code b}. For each
 * state the number of observed transitions to every other state is counted;
 * generation samples successors proportionally to those counts.
 * <p>
 * The model is backed by plain {@link HashMap}s and performs no
 * synchronisation.
 * 
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 * 
 */
public class MarkovChainLanguageModel {

        /** Index of the artificial state the chain is in before any byte is seen. */
        private static final int START_STATE = 0;

        /** One state per unsigned byte value plus the start state. */
        private static final int STATE_COUNT = 256 + 1;

        /** Size of the buffer used when training from a stream. */
        private static final int READ_BUFFER_SIZE = 8192;

        /** Per language: {@code [from][to]} transition counts. */
        private final Map<Locale, long[][]> chains;

        /** Per language: total number of transitions leaving each state. */
        private final Map<Locale, long[]> chainCounts;

        /**
         * Generate a new empty markov chain language model
         */
        public MarkovChainLanguageModel() {
                chains = new HashMap<Locale, long[][]>();
                chainCounts = new HashMap<Locale, long[]>();
        }

        /**
         * 
         * Add an example to a language's markov chain
         * 
         * @param language
         *            the language the example is being added to
         * @param example
         *            the new example to learn from
         * @param encoding
         *            the encoding used to turn the example into bytes
         * @throws UnsupportedEncodingException
         *             if the named encoding is not supported; in that case the
         *             model is left untouched
         */
        public void train(Locale language, String example, String encoding) throws UnsupportedEncodingException {
                // Encode before touching the model so that a bad encoding name does
                // not register an empty chain for the language.
                final byte[] data = example.getBytes(encoding);

                final long[][] transitions = transitionsFor(language);
                final long[] totals = chainCounts.get(language);

                // Each example starts from the start state.
                int state = START_STATE;
                for (final byte b : data) {
                        state = recordTransition(transitions, totals, state, (b & 0xff) + 1);
                }
        }

        /**
         * Train a given language on a stream of text. The stream is consumed up
         * to its end as one single example; it is not closed by this method.
         * 
         * @param language
         *            the language the stream's content is being added to
         * @param stream
         *            the raw bytes to learn from
         * @throws IOException
         *             if reading from the stream fails; transitions read before
         *             the failure remain in the model
         */
        public void train(Locale language, InputStream stream) throws IOException {
                final long[][] transitions = transitionsFor(language);
                final long[] totals = chainCounts.get(language);

                final byte[] buffer = new byte[READ_BUFFER_SIZE];
                int state = START_STATE;
                int read;
                // Bulk reads avoid one (possibly unbuffered) read call per byte.
                while ((read = stream.read(buffer)) != -1) {
                        for (int i = 0; i < read; i++) {
                                state = recordTransition(transitions, totals, state, (buffer[i] & 0xff) + 1);
                        }
                }
        }

        /**
         * Generate a string using this model of the desired length.
         * <p>
         * Generation starts in the start state and repeatedly samples the next
         * byte proportionally to the transition counts learnt for the current
         * state. When a state is reached for which no outgoing transition was
         * ever observed (for instance the final byte of the training text) the
         * chain restarts from the start state, exactly as at the beginning of a
         * new example.
         * 
         * @param language
         *            the language whose model should be used
         * 
         * @param length
         *            the number of bytes to generate; the decoded string may
         *            contain a different number of characters depending on the
         *            encoding
         * @param encoding
         *            the encoding used to decode the generated bytes
         * @return the generated string; {@code null} if the language was never
         *         trained; an empty string if the language was only trained on
         *         empty input
         * @throws UnsupportedEncodingException
         *             if the named encoding is not supported
         */
        public String generate(Locale language, int length, String encoding) throws UnsupportedEncodingException {

                final long[][] transitions = this.chains.get(language);
                if (transitions == null)
                        return null;
                final long[] totals = this.chainCounts.get(language);

                // Every non-empty example records a transition out of the start
                // state, so a zero total here means nothing has been learnt and
                // there is nothing to sample from.
                final boolean nothingLearnt = totals[START_STATE] == 0;
                final byte[] generated = new byte[nothingLearnt ? 0 : length];

                final Random random = new Random();
                int state = START_STATE;
                for (int i = 0; i < generated.length; i++) {
                        if (totals[state] == 0) {
                                // Dead end: no successor was ever observed for this byte.
                                state = START_STATE;
                        }
                        state = sampleNextState(transitions[state], totals[state], random);
                        // Transitions never lead back to the start state, so state >= 1.
                        generated[i] = (byte) (state - 1);
                }

                return new String(generated, encoding);
        }

        /**
         * Fetch the transition table of a language, creating empty tables for
         * the language if it has not been seen before.
         */
        private long[][] transitionsFor(Locale language) {
                long[][] transitions = chains.get(language);
                if (transitions == null) {
                        transitions = new long[STATE_COUNT][STATE_COUNT];
                        chains.put(language, transitions);
                        chainCounts.put(language, new long[STATE_COUNT]);
                }
                return transitions;
        }

        /**
         * Count one transition {@code from -> to} and return the new current
         * state.
         */
        private static int recordTransition(long[][] transitions, long[] totals, int from, int to) {
                transitions[from][to]++;
                totals[from]++;
                return to;
        }

        /**
         * Pick a successor state with probability {@code row[next] / total}.
         * Works on the integer counts directly so that no rounding error can
         * select a state with a zero count or run past the end of the row.
         */
        private static int sampleNextState(long[] row, long total, Random random) {
                long target = (long) (random.nextDouble() * total);
                if (target >= total) {
                        // Guard against the product rounding up to total.
                        target = total - 1;
                }

                long cumulative = 0;
                for (int next = 0; next < row.length; next++) {
                        cumulative += row[next];
                        if (cumulative > target) {
                                return next;
                        }
                }
                // Unreachable while total equals the sum of the row.
                throw new IllegalStateException("transition counts do not add up to the state total");
        }

}