001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.linear.data;
031
032import gov.sandia.cognition.math.matrix.Matrix;
033import gov.sandia.cognition.math.matrix.mtj.SparseMatrixFactoryMTJ;
034
035import java.io.File;
036import java.io.IOException;
037import java.util.ArrayList;
038import java.util.HashMap;
039import java.util.HashSet;
040import java.util.List;
041import java.util.Map;
042import java.util.Set;
043
044import org.openimaj.util.filter.FilterUtils;
045import org.openimaj.util.function.Predicate;
046import org.openimaj.util.pair.Pair;
047
048import com.jmatio.io.MatFileReader;
049import com.jmatio.types.MLArray;
050import com.jmatio.types.MLCell;
051import com.jmatio.types.MLChar;
052import com.jmatio.types.MLDouble;
053import com.jmatio.types.MLSparse;
054
055/**
 * Reads data from Bill's MATLAB file format.
057 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
058 */
059public class BillMatlabFileDataGenerator implements MatrixDataGenerator<Matrix> {
060        /**
061         *
062         * @author Sina Samangooei (ss@ecs.soton.ac.uk)
063         */
064        public static class Fold {
065                /**
066                 * @param training
067                 * @param test
068                 * @param validation
069                 */
070                public Fold(int[] training, int[] test, int[] validation) {
071                        this.training = training;
072                        this.test = test;
073                        this.validation = validation;
074                }
075
076                int[] training;
077                int[] test;
078                int[] validation;
079        }
080
081        /**
082         * The modes
083         * @author Sina Samangooei (ss@ecs.soton.ac.uk)
084         */
085        public enum Mode {
086                /**
087                 * 
088                 */
089                TRAINING {
090                        @Override
091                        public int[] indexes(Fold fold) {
092                                return fold.training;
093                        }
094                },
095                /**
096                 * 
097                 */
098                TEST {
099                        @Override
100                        public int[] indexes(Fold fold) {
101                                return fold.test;
102                        }
103                },
104                /**
105                 * 
106                 */
107                VALIDATION {
108                        @Override
109                        public int[] indexes(Fold fold) {
110                                return fold.validation;
111                        }
112                },
113                /**
114                 * 
115                 */
116                ALL {
117
118                        @Override
119                        public int[] indexes(Fold fold) {
120                                return null;
121                        }
122
123                };
124                /**
125                 * @param fold
126                 * @return the indexes of this fold
127                 */
128                public abstract int[] indexes(Fold fold);
129        }
130
131        private Map<String, MLArray> content;
132        private List<Fold> folds;
133        private int ndays;
134        private int nusers;
135        private int nwords;
136        private List<Matrix> dayWords;
137        private List<Matrix> dayPolls;
138        private int currentIndex;
139        private int ntasks;
140        private int[] indexes;
141        private Map<Integer, String> voc;
142        private String[] tasks;
143        private Set<Integer> keepIndex;
144        private Map<Integer, Integer> indexToVoc ;
145        private boolean filter;
146
147        
148        String mainMatrixKey = "user_vsr_for_polls";
149        
150        /**
151         * @param matfile
152         * @param ndays
153         * @param filter
154         * @throws IOException
155         */
156        public BillMatlabFileDataGenerator(File matfile, int ndays, boolean filter)
157                        throws IOException
158        {
159                final MatFileReader reader = new MatFileReader(matfile);
160                this.ndays = ndays;
161                this.content = reader.getContent();
162                this.currentIndex = 0;
163                this.filter = filter;
164                prepareVocabulary();
165                prepareFolds();
166                prepareDayUserWords();
167                prepareDayPolls();
168
169        }
170        
171        /**
172         * @param matfile
173         * @param mainMatrixName 
174         * @param polls 
175         * @param ndays
176         * @param filter
177         * @throws IOException
178         */
179        public BillMatlabFileDataGenerator(File matfile, String mainMatrixName, File polls, int ndays, boolean filter)
180                        throws IOException
181        {
182                MatFileReader reader = new MatFileReader(matfile);
183                this.mainMatrixKey = mainMatrixName;
184                this.ndays = ndays;
185                this.content = reader.getContent();
186                this.currentIndex = 0;
187                this.filter = filter;
188                prepareVocabulary();
189                prepareFolds();
190                prepareDayUserWords();
191                reader = new MatFileReader(polls);
192                this.content = reader.getContent();
193                prepareDayPolls();
194                this.content = null;
195
196        }
197        
198        /**
199         * @param matfile
200         * @param mainMatrixName 
201         * @param polls 
202         * @param ndays
203         * @param filter
204         * @param folds 
205         * @throws IOException
206         */
207        public BillMatlabFileDataGenerator(File matfile, String mainMatrixName, File polls, int ndays, boolean filter, List<Fold> folds)
208                        throws IOException
209        {
210                MatFileReader reader = new MatFileReader(matfile);
211                this.mainMatrixKey = mainMatrixName;
212                this.ndays = ndays;
213                this.content = reader.getContent();
214                this.currentIndex = 0;
215                this.filter = filter;
216                prepareVocabulary();
217                this.folds = folds;
218                prepareDayUserWords();
219                reader = new MatFileReader(polls);
220                this.content = reader.getContent();
221                prepareDayPolls();
222                this.content = null;
223
224        }
225
226
227
228        /**
229         * @return the vocabulary
230         */
231        public Map<Integer, String> getVocabulary() {
232                return voc;
233        }
234
235        private void prepareVocabulary() {
236                this.keepIndex = new HashSet<Integer>();
237
238                final MLDouble keepIndex = (MLDouble) this.content.get("voc_keep_terms_index");
239                if(keepIndex != null){
240                        final double[] filterIndexArr = keepIndex.getArray()[0];
241                        
242                        for (final double d : filterIndexArr) {
243                                this.keepIndex.add((int) d - 1);
244                        }
245                        
246                }
247                
248                final MLCell vocLoaded = (MLCell) this.content.get("voc");
249                if(vocLoaded!=null){
250                        this.indexToVoc = new HashMap<Integer, Integer>();
251                        final ArrayList<MLArray> vocArr = vocLoaded.cells();
252                        int index = 0;
253                        int vocIndex = 0;
254                        this.voc = new HashMap<Integer, String>();
255                        for (final MLArray vocArrItem : vocArr) {
256                                final MLChar vocChar = (MLChar) vocArrItem;
257                                final String vocString = vocChar.getString(0);
258                                if (filter && this.keepIndex.contains(index)) {
259                                        this.voc.put(vocIndex, vocString);
260                                        this.indexToVoc.put(index, vocIndex);
261                                        vocIndex++;
262                                }
263                                index++;
264                        }
265                } else {
266                        
267                }
268        }
269
270        /**
271         * @param fold
272         * @param mode
273         */
274        public void setFold(int fold, Mode mode) {
275                if (fold == -1) {
276                        this.indexes = new int[this.dayWords.size()];
277                        for (int i = 0; i < indexes.length; i++) {
278                                indexes[i] = i;
279                        }
280                }
281                else {
282                        final Fold f = this.folds.get(fold);
283                        this.indexes = mode.indexes(f);
284                }
285                this.currentIndex = 0;
286        }
287
288        private void prepareDayPolls() {
289                final ArrayList<String> pollKeys = FilterUtils.filter(this.content.keySet(),
290                                new Predicate<String>() {
291
292                                        @Override
293                                        public boolean test(String object) {
294                                                return object.endsWith("per_unique_extended");
295                                        }
296                                });
297                this.ntasks = pollKeys.size();
298                dayPolls = new ArrayList<Matrix>();
299                for (int i = 0; i < this.ndays; i++) {
300                        dayPolls.add(SparseMatrixFactoryMTJ.INSTANCE.createMatrix(1,
301                                        this.ntasks));
302                }
303
304                this.tasks = new String[this.ntasks];
305
306                for (int t = 0; t < this.ntasks; t++) {
307                        final String pollKey = pollKeys.get(t);
308                        this.tasks[t] = pollKey;
309                        final MLDouble arr = (MLDouble) this.content.get(pollKey);
310                        for (int i = 0; i < this.ndays; i++) {
311                                final Matrix dayPoll = dayPolls.get(i);
312                                dayPoll.setElement(0, t, arr.get(i, 0));
313                        }
314                }
315        }
316
317        /**
318         * @return the tasks
319         */
320        public String[] getTasks() {
321                return this.tasks;
322        }
323
324        
325        private void prepareDayUserWords() {
326                final MLSparse arr = (MLSparse) this.content.get(mainMatrixKey);
327                final Double[] realVals = arr.exportReal();
328                final int[] rows = arr.getIR();
329                final int[] cols = arr.getIC();
330                if(voc == null){
331                        this.nwords = arr.getN();
332                }
333                else{
334                        this.nwords = this.voc.size();                  
335                }
336                this.nusers = arr.getM() / this.ndays;
337                dayWords = new ArrayList<Matrix>();
338                for (int i = 0; i < ndays; i++) {
339                        final Matrix userWord = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(this.nwords, this.nusers);
340                        dayWords.add(userWord);
341                }
342                for (int i = 0; i < rows.length; i++) {
343                        if (filter && !this.keepIndex.contains(cols[i]))
344                                continue;
345                        
346                        int wordIndex = cols[i];
347                        if(this.indexToVoc!=null){
348                                wordIndex = this.indexToVoc.get(wordIndex);
349                        }
350                        final int dayIndex = rows[i] / this.nusers;
351                        final int userIndex = rows[i] - (dayIndex * this.nusers);
352
353                        dayWords.get(dayIndex).setElement(wordIndex, userIndex, realVals[i]);
354
355                }
356        }
357
358        private void prepareFolds() {
359
360                final MLArray setfolds = this.content.get("set_fold");
361                if(setfolds==null) return;
362                if (setfolds.isCell()) {
363                        this.folds = new ArrayList<Fold>();
364                        final MLCell foldcells = (MLCell) setfolds;
365                        final int nfolds = foldcells.getM();
366                        System.out.println(String.format("Found %d folds", nfolds));
367                        for (int i = 0; i < nfolds; i++) {
368                                final MLDouble training = (MLDouble) foldcells.get(i, 0);
369                                final MLDouble test = (MLDouble) foldcells.get(i, 1);
370                                final MLDouble validation = (MLDouble) foldcells.get(i, 2);
371                                final Fold f = new Fold(toIntArray(training), toIntArray(test),
372                                                toIntArray(validation));
373                                folds.add(f);
374                        }
375                } else {
376                        throw new RuntimeException(
377                                        "Can't find set_folds in expected format");
378                }
379        }
380
381        private int[] toIntArray(MLDouble training) {
382                final int[] arr = new int[training.getN()];
383                for (int i = 0; i < arr.length; i++) {
384                        arr[i] = training.get(0, i).intValue();
385                }
386                return arr;
387        }
388
389        @Override
390        public Pair<Matrix> generate() {
391                if (currentIndex >= this.indexes.length)
392                        return null;
393                final int dayIndex = this.indexes[currentIndex];
394                final Pair<Matrix> pair = new Pair<Matrix>(this.dayWords.get(dayIndex), this.dayPolls.get(dayIndex));
395                currentIndex++;
396                return pair;
397        }
398
399        /**
400         * @return number of folds
401         */
402        public int nFolds() {
403                return this.folds.size();
404        }
405
406        /**
407         * @return {@link #generate()} until there is nothing to consume
408         */
409        public List<Pair<Matrix>> generateAll() {
410                List<Pair<Matrix>> ret = new ArrayList<Pair<Matrix>>();
411                Pair<Matrix> pair;
412                while((pair = generate()) != null){
413                        ret.add(pair);
414                }
415                return ret;
416        }
417}