001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.demos.sandbox.ml.linear.learner.stream;
031
032import gov.sandia.cognition.math.matrix.Matrix;
033
034import java.util.ArrayList;
035import java.util.HashMap;
036import java.util.List;
037import java.util.Map;
038
039import org.openimaj.math.matrix.CFMatrixUtils;
040import org.openimaj.ml.linear.evaluation.BilinearEvaluator;
041import org.openimaj.ml.linear.learner.BilinearSparseOnlineLearner;
042import org.openimaj.ml.linear.learner.IncrementalBilinearSparseOnlineLearner;
043import org.openimaj.util.data.Context;
044import org.openimaj.util.pair.Pair;
045
046import com.google.common.collect.BiMap;
047
048/**
049 * Some stats of a model
050 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
051 *
052 */
053public class ModelStats {
054
055        /**
056         * The score of the model estimating the the data before this round
057         */
058        public double score;
059        /**
060         * The sorted important words of the model after this round
061         */
062        public Map<String, SortedImportantWords> importantWords;
063        /**
064         * The current model
065         */
066        public IncrementalBilinearSparseOnlineLearner learner;
067
068        /**
069         * the min/max params of each task
070         */
071        public Map<String,Pair<Double>> taskWordMinMax;
072        /**
073         * The value of Y for this round of the model
074         */
075        public Matrix correctY;
076        /**
077         * The estimated value of Y for this round of the model
078         */
079        public Matrix estimatedY;
080
081        /**
082         * The bias from the model
083         */
084        public Matrix bias;
085        /**
086         * The min and max values for each user
087         */
088        public Map<String, Pair<Double>> userMinMax;
089
090        /**
091         * A new learner, no meaningful important words and a loss of 0
092         */
093        public ModelStats() {
094                this.score = 0;
095                this.learner = null;
096                this.importantWords = new HashMap<String, SortedImportantWords>();
097                this.taskWordMinMax = new HashMap<String, Pair<Double>>();
098        }
099
100        /**
101         * The model and its associated loss
102         * @param eval
103         * @param learner
104         * @param in
105         */
106        public ModelStats(BilinearEvaluator eval, IncrementalBilinearSparseOnlineLearner learner, Context in) {
107
108//              IndependentPair<Map<String, Map<String, Double>>, Map<String, Double>> in = inaggr.getPayload();
109                Map<String, Map<String, Double>> bagofwords = in.getTyped("bagofwords");
110                Map<String, Double> averageticks = in.getTyped("averageticks");
111                this.learner = learner;
112                this.learner.updateUserValues(bagofwords, averageticks);
113                BilinearSparseOnlineLearner bilinearLearner = this.learner.getBilinearLearner();
114
115                // Evaluate the learner on the current data
116                eval.setLearner(bilinearLearner);
117                List<Pair<Matrix>> testList = new ArrayList<Pair<Matrix>>();
118                Pair<Matrix> xy = this.learner.asMatrixPair(bagofwords,averageticks);
119                testList.add(xy);
120                this.score = eval.evaluate(testList);
121
122                // Extract other statistics
123                this.importantWords = importantWords();
124                this.taskWordMinMax = minMaxWords();
125                this.userMinMax = minMaxUsers();
126                this.correctY = xy.secondObject();
127                this.bias = bilinearLearner.getBias();
128                this.estimatedY = bilinearLearner.predict(xy.firstObject());
129        }
130
131        private Map<String, Pair<Double>> minMaxUsers() {
132                Map<String, Pair<Double>> ret = new HashMap<String, Pair<Double>>();
133                if(this.learner == null) return ret;
134                BiMap<String, Integer> depvals = this.learner.getDependantValues();
135                BilinearSparseOnlineLearner bilearner = this.learner.getBilinearLearner();
136                for (String task : depvals.keySet()) {
137                        Integer taskCol = this.learner.getDependantValues().get(task);
138                        ret.put(
139                                task,
140                                new Pair<Double>(
141                                        CFMatrixUtils.min(bilearner.getU().getColumn(taskCol)),
142                                        CFMatrixUtils.max(bilearner.getU().getColumn(taskCol))
143                                )
144                        );
145                }
146
147                return ret;
148        }
149
150        private Map<String, Pair<Double>> minMaxWords() {
151                Map<String, Pair<Double>> ret = new HashMap<String, Pair<Double>>();
152                if(this.learner == null) return ret;
153                BiMap<String, Integer> depvals = this.learner.getDependantValues();
154                BilinearSparseOnlineLearner bilearner = this.learner.getBilinearLearner();
155                for (String task : depvals.keySet()) {
156                        Integer taskCol = this.learner.getDependantValues().get(task);
157                        ret.put(
158                                task,
159                                new Pair<Double>(
160                                        CFMatrixUtils.min(bilearner.getW().getColumn(taskCol)),
161                                        CFMatrixUtils.max(bilearner.getW().getColumn(taskCol))
162                                )
163                        );
164                }
165
166                return ret;
167        }
168
169        private Map<String, SortedImportantWords> importantWords() {
170                Map<String, SortedImportantWords> ret = new HashMap<String, SortedImportantWords>();
171                if(this.learner == null) return ret;
172                BiMap<String, Integer> depvals = this.learner.getDependantValues();
173                BilinearSparseOnlineLearner bilearner = this.learner.getBilinearLearner();
174                for (String task : depvals.keySet()) {
175                        ret.put(
176                                task,
177                                new SortedImportantWords(task, learner, bilearner, 10)
178                        );
179                }
180
181                return ret;
182        }
183
184        public void printSummary() {
185                if(learner == null){
186                        System.out.println("No loss!");
187                        return;
188                }
189                System.out.println("Loss: " + this.score);
190                System.out.println("Important words: ");
191                BilinearSparseOnlineLearner bilinearLearner = this.learner.getBilinearLearner();
192                BiMap<Integer, String> inversewords = this.learner.getVocabulary().inverse();
193                for (String task : this.importantWords.keySet()) {
194                        Pair<Double> minmax = this.taskWordMinMax.get(task);
195                        SortedImportantWords sortedImportantWords = this.importantWords.get(task);
196                        for (int wordIndex : sortedImportantWords.indexes) {
197                                System.out.println("Word: " + inversewords.get(wordIndex) + " index " + wordIndex);
198                                System.out.println(bilinearLearner.getW().getRow(wordIndex));
199                        }
200                        System.out.printf("... %s (%1.4f->%1.4f) %s\n",
201                                        task,
202                                        minmax.firstObject(),
203                                        minmax.secondObject(),
204                                        sortedImportantWords
205                                        );
206                }
207                System.out.println("User importance: ");
208                for (String task : this.importantWords.keySet()) {
209                        Pair<Double> minmax = this.userMinMax.get(task);
210                        System.out.printf("... %s (%1.4f->%1.4f)\n",
211                                        task,
212                                        minmax.firstObject(),
213                                        minmax.secondObject()
214                                        );
215                }
216                System.out.println("Model Bias: \n" + this.bias);
217                System.out.println("Correct Y: \n" + this.correctY);
218                System.out.println("Estimated Y: \n" + this.estimatedY);
219        }
220
221}