/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */ 030package org.openimaj.text.nlp.language; 031 032import java.io.IOException; 033import java.io.InputStream; 034import java.io.UnsupportedEncodingException; 035import java.util.HashMap; 036import java.util.Locale; 037import java.util.Map; 038import java.util.Random; 039 040import Jama.Matrix; 041 042/** 043 * Code to train, classify and generate language specific text by building a 044 * first order Markov chain. 045 * 046 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 047 * 048 */ 049public class MarkovChainLanguageModel { 050 051 private Map<Locale, Matrix> chains = new HashMap<Locale, Matrix>(); 052 private Map<Locale, long[]> chainCounts = new HashMap<Locale, long[]>(); 053 054 /** 055 * Generate a new empty markov chain language model 056 */ 057 public MarkovChainLanguageModel() { 058 chains = new HashMap<Locale, Matrix>(); 059 chainCounts = new HashMap<Locale, long[]>(); 060 } 061 062 /** 063 * 064 * Add an example to a language's markov chain 065 * 066 * @param language 067 * the language the example is being added to 068 * @param example 069 * the new example to learn from 070 * @param encoding 071 * the encoding of the example 072 * @throws UnsupportedEncodingException 073 */ 074 public void train(Locale language, String example, String encoding) throws UnsupportedEncodingException { 075 if (!chains.containsKey(language)) { 076 chains.put(language, new Matrix(256 + 1, 256 + 1)); 077 chainCounts.put(language, new long[256 + 1]); 078 } 079 080 final Matrix chain = chains.get(language); 081 final long[] chainCount = chainCounts.get(language); 082 final byte[] data = example.getBytes(encoding); 083 084 int currentIndex = 0; 085 final double[][] chainData = chain.getArray(); 086 for (final byte b : data) { 087 final int newIndex = (b & 0xff) + 1; 088 chainData[currentIndex][newIndex] = chainData[currentIndex][newIndex] + 1; 089 chainCount[currentIndex] += 1; 090 currentIndex = newIndex; 091 } 092 093 } 094 095 /** 096 * Train a given language on a stream 
of text 097 * 098 * @param language 099 * @param stream 100 * @throws IOException 101 */ 102 public void train(Locale language, InputStream stream) throws IOException { 103 if (!chains.containsKey(language)) { 104 chains.put(language, new Matrix(256 + 1, 256 + 1)); 105 chainCounts.put(language, new long[256 + 1]); 106 } 107 108 final Matrix chain = chains.get(language); 109 final long[] chainCount = chainCounts.get(language); 110 111 int currentIndex = 0; 112 final double[][] chainData = chain.getArray(); 113 int newIndex = -1; 114 while ((newIndex = stream.read()) != -1) { 115 newIndex += 1; 116 chainData[currentIndex][newIndex] = chainData[currentIndex][newIndex] + 1; 117 chainCount[currentIndex] += 1; 118 currentIndex = newIndex; 119 } 120 } 121 122 /** 123 * Generate a string using this model of the desired length 124 * 125 * @param language 126 * 127 * @param length 128 * @param encoding 129 * @return the generated string 130 * @throws UnsupportedEncodingException 131 */ 132 public String generate(Locale language, int length, String encoding) throws UnsupportedEncodingException { 133 134 final Matrix chain = this.chains.get(language); 135 if (chain == null) 136 return null; 137 final double[][] chainData = chain.getArray(); 138 final long[] chainCount = this.chainCounts.get(language); 139 140 int currentIndex = 0; 141 final byte[] newString = new byte[length]; 142 final Random r = new Random(); 143 for (int i = 0; i < length; i++) { 144 final double prob = r.nextDouble(); 145 final double[] currentLine = chainData[currentIndex]; 146 double probSum = 0.0; 147 int newIndex = 0; 148 // System.out.println("CURRENT STATE:" + (char)(currentIndex-1)); 149 while (probSum + (currentLine[newIndex] / chainCount[currentIndex]) < prob) { 150 final double probForIndex = (currentLine[newIndex++] / chainCount[currentIndex]); 151 // System.out.println(probForIndex); 152 // if(probForIndex > 0){ 153 // System.out.println("Prob to go to:" + (char)(newIndex-2) + 154 // " = " + 
probForIndex); 155 // } 156 probSum += probForIndex; 157 } 158 // System.out.println("NEW STATE:" + (char)(newIndex-1)); 159 newString[i] = (byte) (newIndex - 1); 160 currentIndex = newIndex; 161 } 162 163 return new String(newString, encoding); 164 } 165 166}