/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package org.openimaj.ml.linear.data;

import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrixFactoryMTJ;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.openimaj.util.filter.FilterUtils;
import org.openimaj.util.function.Predicate;
import org.openimaj.util.pair.Pair;

import com.jmatio.io.MatFileReader;
import com.jmatio.types.MLArray;
import com.jmatio.types.MLCell;
import com.jmatio.types.MLChar;
import com.jmatio.types.MLDouble;
import com.jmatio.types.MLSparse;

/**
 * Read data from Bill's MATLAB file format.
 * <p>
 * The main sparse matrix (by default {@code user_vsr_for_polls}) stacks
 * {@code ndays} blocks of users row-wise (row = day * nusers + user); each block
 * is turned into a sparse words x users {@link Matrix} for one day. Every MAT
 * variable whose name ends in {@code per_unique_extended} is treated as a poll
 * (task) whose first column holds one value per day.
 * <p>
 * Call {@link #setFold(int, Mode)} before {@link #generate()} or
 * {@link #generateAll()}.
 *
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 */
public class BillMatlabFileDataGenerator implements MatrixDataGenerator<Matrix> {
	/**
	 * The day indexes making up one fold. The indexes are used directly as
	 * (0-based) positions into the per-day data.
	 *
	 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
	 */
	public static class Fold {
		/**
		 * @param training
		 *            day indexes used for training
		 * @param test
		 *            day indexes used for testing
		 * @param validation
		 *            day indexes used for validation
		 */
		public Fold(int[] training, int[] test, int[] validation) {
			this.training = training;
			this.test = test;
			this.validation = validation;
		}

		int[] training;
		int[] test;
		int[] validation;
	}

	/**
	 * Selects which part of a {@link Fold} is iterated over.
	 *
	 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
	 */
	public enum Mode {
		/**
		 * The training days of the fold.
		 */
		TRAINING {
			@Override
			public int[] indexes(Fold fold) {
				return fold.training;
			}
		},
		/**
		 * The test days of the fold.
		 */
		TEST {
			@Override
			public int[] indexes(Fold fold) {
				return fold.test;
			}
		},
		/**
		 * The validation days of the fold.
		 */
		VALIDATION {
			@Override
			public int[] indexes(Fold fold) {
				return fold.validation;
			}
		},
		/**
		 * Every day. {@link #indexes(Fold)} returns {@code null}, which
		 * {@link BillMatlabFileDataGenerator#setFold(int, Mode)} interprets as
		 * "all days".
		 */
		ALL {

			@Override
			public int[] indexes(Fold fold) {
				return null;
			}

		};

		/**
		 * @param fold
		 *            the fold to select from
		 * @return the indexes of this fold, or {@code null} for {@link #ALL}
		 */
		public abstract int[] indexes(Fold fold);
	}

	private Map<String, MLArray> content;
	private List<Fold> folds;
	private int ndays;
	private int nusers;
	private int nwords;
	private List<Matrix> dayWords;
	private List<Matrix> dayPolls;
	private int currentIndex;
	private int ntasks;
	private int[] indexes;
	private Map<Integer, String> voc;
	private String[] tasks;
	private Set<Integer> keepIndex;
	private Map<Integer, Integer> indexToVoc;
	private boolean filter;

	String mainMatrixKey = "user_vsr_for_polls";

	/**
	 * Load words, vocabulary, folds and polls from a single MAT file using the
	 * default main matrix key ({@code user_vsr_for_polls}).
	 *
	 * @param matfile
	 *            the MAT file to read
	 * @param ndays
	 *            number of days stacked in the main matrix
	 * @param filter
	 *            if true only terms listed in {@code voc_keep_terms_index}
	 *            are kept
	 * @throws IOException
	 *             if the file cannot be read
	 */
	public BillMatlabFileDataGenerator(File matfile, int ndays, boolean filter)
			throws IOException
	{
		final MatFileReader reader = new MatFileReader(matfile);
		this.ndays = ndays;
		this.content = reader.getContent();
		this.currentIndex = 0;
		this.filter = filter;
		prepareVocabulary();
		prepareFolds();
		prepareDayUserWords();
		prepareDayPolls();
		// the raw MAT content is only needed while constructing; release it
		this.content = null;
	}

	/**
	 * Load words, vocabulary and folds from one MAT file and the polls from
	 * another.
	 *
	 * @param matfile
	 *            the MAT file holding the main matrix, vocabulary and folds
	 * @param mainMatrixName
	 *            the name of the main sparse matrix in {@code matfile}
	 * @param polls
	 *            the MAT file holding the poll variables
	 * @param ndays
	 *            number of days stacked in the main matrix
	 * @param filter
	 *            if true only terms listed in {@code voc_keep_terms_index}
	 *            are kept
	 * @throws IOException
	 *             if either file cannot be read
	 */
	public BillMatlabFileDataGenerator(File matfile, String mainMatrixName, File polls, int ndays, boolean filter)
			throws IOException
	{
		MatFileReader reader = new MatFileReader(matfile);
		this.mainMatrixKey = mainMatrixName;
		this.ndays = ndays;
		this.content = reader.getContent();
		this.currentIndex = 0;
		this.filter = filter;
		prepareVocabulary();
		prepareFolds();
		prepareDayUserWords();
		reader = new MatFileReader(polls);
		this.content = reader.getContent();
		prepareDayPolls();
		this.content = null;
	}

	/**
	 * Load words and vocabulary from one MAT file and the polls from another,
	 * using caller-supplied folds instead of the file's {@code set_fold}.
	 *
	 * @param matfile
	 *            the MAT file holding the main matrix and vocabulary
	 * @param mainMatrixName
	 *            the name of the main sparse matrix in {@code matfile}
	 * @param polls
	 *            the MAT file holding the poll variables
	 * @param ndays
	 *            number of days stacked in the main matrix
	 * @param filter
	 *            if true only terms listed in {@code voc_keep_terms_index}
	 *            are kept
	 * @param folds
	 *            the folds to use
	 * @throws IOException
	 *             if either file cannot be read
	 */
	public BillMatlabFileDataGenerator(File matfile, String mainMatrixName, File polls, int ndays, boolean filter,
			List<Fold> folds)
			throws IOException
	{
		MatFileReader reader = new MatFileReader(matfile);
		this.mainMatrixKey = mainMatrixName;
		this.ndays = ndays;
		this.content = reader.getContent();
		this.currentIndex = 0;
		this.filter = filter;
		prepareVocabulary();
		this.folds = folds;
		prepareDayUserWords();
		reader = new MatFileReader(polls);
		this.content = reader.getContent();
		prepareDayPolls();
		this.content = null;
	}

	/**
	 * @return the vocabulary, mapping (compacted) word index to word; null if
	 *         the MAT file had no {@code voc} variable
	 */
	public Map<Integer, String> getVocabulary() {
		return voc;
	}

	/**
	 * Read the optional {@code voc_keep_terms_index} and {@code voc} variables.
	 * Keep indexes are converted to 0-based by subtracting 1. When a vocabulary
	 * exists, {@link #voc} maps compacted word index to word and
	 * {@link #indexToVoc} maps original column index to compacted index. Only
	 * kept terms are included when filtering; otherwise every term is.
	 */
	private void prepareVocabulary() {
		this.keepIndex = new HashSet<Integer>();

		final MLDouble keepIndexArr = (MLDouble) this.content.get("voc_keep_terms_index");
		if (keepIndexArr != null) {
			final double[] filterIndexArr = keepIndexArr.getArray()[0];
			for (final double d : filterIndexArr) {
				this.keepIndex.add((int) d - 1);
			}
		}

		final MLCell vocLoaded = (MLCell) this.content.get("voc");
		if (vocLoaded == null)
			return;

		this.indexToVoc = new HashMap<Integer, Integer>();
		this.voc = new HashMap<Integer, String>();
		final ArrayList<MLArray> vocArr = vocLoaded.cells();
		int index = 0;
		int vocIndex = 0;
		for (final MLArray vocArrItem : vocArr) {
			final MLChar vocChar = (MLChar) vocArrItem;
			final String vocString = vocChar.getString(0);
			// Without filtering every term must be kept: an empty vocabulary
			// would give 0-row day matrices and a failed (null) index lookup in
			// prepareDayUserWords.
			if (!filter || this.keepIndex.contains(index)) {
				this.voc.put(vocIndex, vocString);
				this.indexToVoc.put(index, vocIndex);
				vocIndex++;
			}
			index++;
		}
	}

	/**
	 * Select the days that {@link #generate()} will iterate over and rewind
	 * the iteration.
	 *
	 * @param fold
	 *            the fold number, or -1 to use every day (mode is then
	 *            ignored)
	 * @param mode
	 *            which part of the fold to use; {@link Mode#ALL} uses every
	 *            day
	 * @throws IllegalStateException
	 *             if a fold other than -1 is requested but no folds are
	 *             available
	 */
	public void setFold(int fold, Mode mode) {
		int[] selected = null;
		if (fold != -1) {
			if (this.folds == null)
				throw new IllegalStateException("No folds available; use fold -1 to select every day");
			selected = mode.indexes(this.folds.get(fold));
		}
		// a null selection (fold -1 or Mode.ALL) means every day
		this.indexes = selected == null ? allDayIndexes() : selected;
		this.currentIndex = 0;
	}

	/** Returns the indexes 0..(number of days - 1). */
	private int[] allDayIndexes() {
		final int[] all = new int[this.dayWords.size()];
		for (int i = 0; i < all.length; i++) {
			all[i] = i;
		}
		return all;
	}

	/**
	 * Build one 1 x ntasks poll matrix per day from every variable whose name
	 * ends in {@code per_unique_extended}; element (0, t) of day i is row i,
	 * column 0 of poll variable t. Task order follows the iteration order of
	 * the MAT content key set.
	 */
	private void prepareDayPolls() {
		final ArrayList<String> pollKeys = FilterUtils.filter(this.content.keySet(),
				new Predicate<String>() {

					@Override
					public boolean test(String object) {
						return object.endsWith("per_unique_extended");
					}
				});
		this.ntasks = pollKeys.size();
		dayPolls = new ArrayList<Matrix>();
		for (int i = 0; i < this.ndays; i++) {
			dayPolls.add(SparseMatrixFactoryMTJ.INSTANCE.createMatrix(1, this.ntasks));
		}

		this.tasks = new String[this.ntasks];

		for (int t = 0; t < this.ntasks; t++) {
			final String pollKey = pollKeys.get(t);
			this.tasks[t] = pollKey;
			final MLDouble arr = (MLDouble) this.content.get(pollKey);
			for (int i = 0; i < this.ndays; i++) {
				final Matrix dayPoll = dayPolls.get(i);
				dayPoll.setElement(0, t, arr.get(i, 0));
			}
		}
	}

	/**
	 * @return the task (poll variable) names, in column order of the poll
	 *         matrices
	 */
	public String[] getTasks() {
		return this.tasks;
	}

	/**
	 * Split the main sparse matrix into ndays words x users matrices. Row r of
	 * the main matrix belongs to day r / nusers and user r % nusers, where
	 * nusers = rows / ndays. NOTE(review): assumes the row count is a multiple
	 * of ndays; otherwise trailing rows map past the last day.
	 */
	private void prepareDayUserWords() {
		final MLSparse arr = (MLSparse) this.content.get(mainMatrixKey);
		final Double[] realVals = arr.exportReal();
		final int[] rows = arr.getIR();
		final int[] cols = arr.getIC();
		if (voc == null) {
			this.nwords = arr.getN();
		}
		else {
			this.nwords = this.voc.size();
		}
		this.nusers = arr.getM() / this.ndays;
		dayWords = new ArrayList<Matrix>();
		for (int i = 0; i < ndays; i++) {
			final Matrix userWord = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(this.nwords, this.nusers);
			dayWords.add(userWord);
		}
		for (int i = 0; i < rows.length; i++) {
			if (filter && !this.keepIndex.contains(cols[i]))
				continue;

			int wordIndex = cols[i];
			if (this.indexToVoc != null) {
				// translate the original column into the compacted vocabulary index
				wordIndex = this.indexToVoc.get(wordIndex);
			}
			final int dayIndex = rows[i] / this.nusers;
			final int userIndex = rows[i] - (dayIndex * this.nusers);

			dayWords.get(dayIndex).setElement(wordIndex, userIndex, realVals[i]);
		}
	}

	/**
	 * Read the optional {@code set_fold} cell array: one row per fold, with the
	 * training, test and validation index vectors in columns 0, 1 and 2.
	 * Indexes are used as-is (no 1-based conversion). NOTE(review): confirm the
	 * file stores 0-based day indexes.
	 *
	 * @throws RuntimeException
	 *             if {@code set_fold} exists but is not a cell array
	 */
	private void prepareFolds() {
		final MLArray setfolds = this.content.get("set_fold");
		if (setfolds == null)
			return;
		if (setfolds.isCell()) {
			this.folds = new ArrayList<Fold>();
			final MLCell foldcells = (MLCell) setfolds;
			final int nfolds = foldcells.getM();
			System.out.println(String.format("Found %d folds", nfolds));
			for (int i = 0; i < nfolds; i++) {
				final MLDouble training = (MLDouble) foldcells.get(i, 0);
				final MLDouble test = (MLDouble) foldcells.get(i, 1);
				final MLDouble validation = (MLDouble) foldcells.get(i, 2);
				final Fold f = new Fold(toIntArray(training), toIntArray(test), toIntArray(validation));
				folds.add(f);
			}
		} else {
			throw new RuntimeException("Can't find set_folds in expected format");
		}
	}

	/** Truncate the first row of a MATLAB double array to ints. */
	private int[] toIntArray(MLDouble values) {
		final int[] arr = new int[values.getN()];
		for (int i = 0; i < arr.length; i++) {
			arr[i] = values.get(0, i).intValue();
		}
		return arr;
	}

	/**
	 * @return the (day words, day polls) pair for the next selected day, or
	 *         null when every selected day has been returned
	 * @throws IllegalStateException
	 *             if {@link #setFold(int, Mode)} has not been called
	 */
	@Override
	public Pair<Matrix> generate() {
		if (this.indexes == null)
			throw new IllegalStateException("setFold must be called before generate");
		if (currentIndex >= this.indexes.length)
			return null;
		final int dayIndex = this.indexes[currentIndex];
		final Pair<Matrix> pair = new Pair<Matrix>(this.dayWords.get(dayIndex), this.dayPolls.get(dayIndex));
		currentIndex++;
		return pair;
	}

	/**
	 * @return number of folds, or 0 if no folds were loaded or supplied
	 */
	public int nFolds() {
		return this.folds == null ? 0 : this.folds.size();
	}

	/**
	 * @return {@link #generate()} until there is nothing to consume
	 */
	public List<Pair<Matrix>> generateAll() {
		final List<Pair<Matrix>> ret = new ArrayList<Pair<Matrix>>();
		Pair<Matrix> pair;
		while ((pair = generate()) != null) {
			ret.add(pair);
		}
		return ret;
	}
}