001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.text.nlp.language; 031 032 033import gnu.trove.map.hash.TIntObjectHashMap; 034import gnu.trove.procedure.TIntObjectProcedure; 035 036import java.io.DataInput; 037import java.io.DataOutput; 038import java.io.IOException; 039import java.util.Arrays; 040import java.util.List; 041import java.util.Map; 042import java.util.Map.Entry; 043 044import no.uib.cipr.matrix.DenseMatrix; 045 046import org.openimaj.io.ReadWriteableBinary; 047import org.openimaj.io.wrappers.Readable2DArrayBinary; 048import org.openimaj.io.wrappers.ReadableArrayBinary; 049import org.openimaj.io.wrappers.Writeable2DArrayBinary; 050import org.openimaj.io.wrappers.WriteableArrayBinary; 051import org.openimaj.math.matrix.MatrixUtils; 052 053 054/** 055 * The data used by {@link LanguageDetector} 056 * 057 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 058 * 059 * 060 */ 061 062public class LanguageModel implements ReadWriteableBinary{ 063 DenseMatrix naiveBayesPC; // N x 1 064 DenseMatrix naiveBayesPTC; // N x M 065 String[] naiveBayesClasses; // the language classes 066 TIntObjectHashMap<int[]> tk_output; 067 int[] tk_nextmove; 068 int naiveBayesNFeats; 069 070 /** 071 * do nothing 072 */ 073 public LanguageModel(){} 074 075 /** 076 * @param languageModel 077 */ 078 @SuppressWarnings("unchecked") 079 public LanguageModel(Map<String,Object> languageModel){ 080 List<Double> nb_pc_list = (List<Double>) languageModel.get("nb_pc"); 081 double[][] nb_pc_darr = new double[1][nb_pc_list.size()]; 082 int i = 0; 083 for (double value : nb_pc_list) { 084 nb_pc_darr[0][i++] = value; 085 } 086 naiveBayesPC = new DenseMatrix(nb_pc_darr); 087 088 List<List<Double>> nb_ptc_list = (List<List<Double>>) languageModel.get("nb_ptc"); 089 double[][] nb_ptc_darr = new double[nb_ptc_list.size()][nb_ptc_list.get(0).size()]; 090 i = 0; 091 for (List<Double> row: nb_ptc_list) { 092 int j = 0; 093 for(double val : row){ 094 nb_ptc_darr[i][j++] = val; 095 } 096 i++; 097 } 098 naiveBayesPTC = new DenseMatrix(nb_ptc_darr); 099 100 this.naiveBayesNFeats = (naiveBayesPTC.numColumns() * naiveBayesPTC.numRows()) / naiveBayesPC.numColumns(); 101 102 List<String> nb_classes_list = (List<String>)languageModel.get("nb_classes"); 103 naiveBayesClasses = nb_classes_list.toArray(new String[nb_classes_list.size()]); 104 105 tk_output = new TIntObjectHashMap<int[]>(); 106 Map<String,List<Double>> tk_output_map = (Map<String, List<Double>>) languageModel.get("tk_outp"); 107 for (Entry<String,List<Double>> entry : tk_output_map .entrySet()) { 108 i = 0; 109 int[] entryArr = new int[entry.getValue().size()]; 110 for (double entryVal : entry.getValue()) { 111 entryArr[i++] = (int) entryVal; 112 } 113 tk_output.put(Integer.parseInt(entry.getKey()),entryArr ); 114 } 115 List<Double> tk_nextmove_list = (List<Double>) languageModel.get("tk_nextmove"); 116 tk_nextmove = new int[tk_nextmove_list.size()]; 117 i = 0; 118 for (double val : tk_nextmove_list) { 119 tk_nextmove[i++] = (int)val; 120 } 121 } 122 123 @Override 124 public void writeBinary(final DataOutput out) throws IOException { 125 new Writeable2DArrayBinary(MatrixUtils.mtjToDoubleArray(naiveBayesPC)).writeBinary(out); 126 new Writeable2DArrayBinary(MatrixUtils.mtjToDoubleArray(naiveBayesPTC)).writeBinary(out); 127 WriteableArrayBinary<String> stringWriter = new WriteableArrayBinary<String>(naiveBayesClasses) { 128 @Override 129 protected void writeValue(String v, DataOutput out) throws IOException { 130 out.writeUTF(v); 131 } 132 }; 133 stringWriter.writeBinary(out); 134 out.writeInt(tk_output.size()); 135 this.tk_output.forEachEntry(new TIntObjectProcedure<int[]>() { 136 137 @Override 138 public boolean execute(int key, int[] value) { 139 try { 140 out.writeInt(key); 141 out.writeInt(value.length); 142 for (int i : value) { 143 out.writeInt(i); 144 } 145 } catch (IOException e) { 146 return false; 147 } 148 return true; 149 } 150 }); 151 out.writeInt(this.tk_nextmove.length); 152 for (int nextmove : this.tk_nextmove) { 153 out.writeInt(nextmove); 154 } 155 } 156 157 @Override 158 public byte[] binaryHeader() { 159 return "LANGMODEL".getBytes(); 160 } 161 162 @Override 163 public void readBinary(DataInput in) throws IOException { 164 Readable2DArrayBinary matrixReader = new Readable2DArrayBinary(null); 165 matrixReader.readBinary(in); 166 naiveBayesPC = new DenseMatrix(matrixReader.value); 167 168 matrixReader.readBinary(in); 169 naiveBayesPTC = new DenseMatrix(matrixReader.value); 170 171 this.naiveBayesNFeats = (naiveBayesPTC.numColumns() * naiveBayesPTC.numRows()) / naiveBayesPC.numColumns(); 172 173 ReadableArrayBinary<String> readableClasses = new ReadableArrayBinary<String>(null){ 174 175 @Override 176 protected String readValue(DataInput in) throws IOException { 177 return in.readUTF(); 178 } 179 180 @Override 181 protected String[] createEmpty(int sz) throws IOException { 182 return new String[sz]; 183 } 184 }; 185 readableClasses.readBinary(in); 186 this.naiveBayesClasses = readableClasses.value; 187 188 int nTKOut = in.readInt(); 189 this.tk_output = new TIntObjectHashMap<int[]>(nTKOut); 190 for (int i = 0; i < nTKOut; i++) { 191 int key = in.readInt(); 192 int length = in.readInt(); 193 int[] data = new int[length]; 194 for (int j = 0; j < length; j++) { 195 data[j] = in.readInt(); 196 } 197 this.tk_output.put(key, data); 198 } 199 int nextMoveLength = in.readInt(); 200 this.tk_nextmove = new int[nextMoveLength]; 201 for (int i = 0; i < nextMoveLength; i++) { 202 this.tk_nextmove[i] = in.readInt(); 203 } 204 } 205 206 @Override 207 public boolean equals(Object other){ 208 if(!(other instanceof LanguageModel)) return false; 209 final LanguageModel that = (LanguageModel) other; 210 211 boolean equal = true; 212 equal = Arrays.deepEquals(this.naiveBayesClasses, that.naiveBayesClasses); if(!equal) return false; 213 equal = this.naiveBayesNFeats == that.naiveBayesNFeats; if(!equal) return false; 214 equal = Arrays.deepEquals(MatrixUtils.mtjToDoubleArray(this.naiveBayesPC),MatrixUtils.mtjToDoubleArray(that.naiveBayesPC)); if(!equal) return false; 215 equal = Arrays.deepEquals(MatrixUtils.mtjToDoubleArray(this.naiveBayesPTC),MatrixUtils.mtjToDoubleArray(that.naiveBayesPTC)); if(!equal) return false; 216 equal = Arrays.equals(this.tk_nextmove,that.tk_nextmove); if(!equal) return false; 217 equal = this.tk_output.forEachEntry(new TIntObjectProcedure<int[]>() { 218 219 @Override 220 public boolean execute(int key, int[] value) { 221 return Arrays.equals(value, that.tk_output.get(key)); 222 } 223 });if(!equal) return false; 224 return equal; 225 } 226}