/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.demos.sandbox.ml.linear.learner.stream;

import gov.sandia.cognition.math.matrix.Matrix;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.openimaj.math.matrix.CFMatrixUtils;
import org.openimaj.ml.linear.evaluation.BilinearEvaluator;
import org.openimaj.ml.linear.learner.BilinearSparseOnlineLearner;
import org.openimaj.ml.linear.learner.IncrementalBilinearSparseOnlineLearner;
import org.openimaj.util.data.Context;
import org.openimaj.util.pair.Pair;

import com.google.common.collect.BiMap;

/**
 * Some stats of a model: the evaluation score for one round of data together
 * with per-task summaries extracted from the learner (important words, min/max
 * of the W and U columns, bias, and the correct and estimated Y).
 *
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 *
 */
public class ModelStats {

    /**
     * The score of the model estimating the data before this round
     */
    public double score;
    /**
     * The sorted important words of the model after this round
     */
    public Map<String, SortedImportantWords> importantWords;
    /**
     * The current model
     */
    public IncrementalBilinearSparseOnlineLearner learner;

    /**
     * the min/max params of each task
     */
    public Map<String, Pair<Double>> taskWordMinMax;
    /**
     * The value of Y for this round of the model
     */
    public Matrix correctY;
    /**
     * The estimated value of Y for this round of the model
     */
    public Matrix estimatedY;

    /**
     * The bias from the model
     */
    public Matrix bias;
    /**
     * The min and max values for each user
     */
    public Map<String, Pair<Double>> userMinMax;

    /**
     * Empty stats: no learner, a score of 0 and empty (non-null) maps. The
     * matrices ({@link #correctY}, {@link #estimatedY}, {@link #bias}) are
     * left as <code>null</code>.
     */
    public ModelStats() {
        this.score = 0;
        this.learner = null;
        this.importantWords = new HashMap<String, SortedImportantWords>();
        this.taskWordMinMax = new HashMap<String, Pair<Double>>();
        // Initialised like the sibling maps so that readers of the public
        // field never observe null.
        this.userMinMax = new HashMap<String, Pair<Double>>();
    }

    /**
     * The model and its associated score. The learner's user values are
     * updated from the data in the context, the learner is then evaluated on
     * that same data and the remaining statistics are extracted from it.
     *
     * @param eval
     *            the evaluator used to score the learner on the current data
     * @param learner
     *            the learner to update and summarise
     * @param in
     *            the context; read for the keys "bagofwords" and
     *            "averageticks"
     */
    public ModelStats(BilinearEvaluator eval, IncrementalBilinearSparseOnlineLearner learner, Context in) {
        Map<String, Map<String, Double>> bagofwords = in.getTyped("bagofwords");
        Map<String, Double> averageticks = in.getTyped("averageticks");
        this.learner = learner;
        this.learner.updateUserValues(bagofwords, averageticks);
        BilinearSparseOnlineLearner bilinearLearner = this.learner.getBilinearLearner();

        // Evaluate the learner on the current data
        eval.setLearner(bilinearLearner);
        Pair<Matrix> xy = this.learner.asMatrixPair(bagofwords, averageticks);
        List<Pair<Matrix>> testList = new ArrayList<Pair<Matrix>>();
        testList.add(xy);
        this.score = eval.evaluate(testList);

        // Extract other statistics
        this.importantWords = importantWords();
        this.taskWordMinMax = minMaxWords();
        this.userMinMax = minMaxUsers();
        this.correctY = xy.secondObject();
        this.bias = bilinearLearner.getBias();
        this.estimatedY = bilinearLearner.predict(xy.firstObject());
    }

    /**
     * For each dependant value (task) of the learner, the min and max of the
     * corresponding column of the bilinear learner's U matrix.
     *
     * @return task name to (min, max); empty if there is no learner
     */
    private Map<String, Pair<Double>> minMaxUsers() {
        Map<String, Pair<Double>> ret = new HashMap<String, Pair<Double>>();
        if (this.learner == null) {
            return ret;
        }
        BiMap<String, Integer> depvals = this.learner.getDependantValues();
        BilinearSparseOnlineLearner bilearner = this.learner.getBilinearLearner();
        // Iterate entries directly instead of looking every key up again.
        for (Map.Entry<String, Integer> taskEntry : depvals.entrySet()) {
            int taskCol = taskEntry.getValue();
            ret.put(
                    taskEntry.getKey(),
                    new Pair<Double>(
                            CFMatrixUtils.min(bilearner.getU().getColumn(taskCol)),
                            CFMatrixUtils.max(bilearner.getU().getColumn(taskCol))
                    )
            );
        }

        return ret;
    }

    /**
     * For each dependant value (task) of the learner, the min and max of the
     * corresponding column of the bilinear learner's W matrix.
     *
     * @return task name to (min, max); empty if there is no learner
     */
    private Map<String, Pair<Double>> minMaxWords() {
        Map<String, Pair<Double>> ret = new HashMap<String, Pair<Double>>();
        if (this.learner == null) {
            return ret;
        }
        BiMap<String, Integer> depvals = this.learner.getDependantValues();
        BilinearSparseOnlineLearner bilearner = this.learner.getBilinearLearner();
        // Iterate entries directly instead of looking every key up again.
        for (Map.Entry<String, Integer> taskEntry : depvals.entrySet()) {
            int taskCol = taskEntry.getValue();
            ret.put(
                    taskEntry.getKey(),
                    new Pair<Double>(
                            CFMatrixUtils.min(bilearner.getW().getColumn(taskCol)),
                            CFMatrixUtils.max(bilearner.getW().getColumn(taskCol))
                    )
            );
        }

        return ret;
    }

    /**
     * Builds a {@link SortedImportantWords} for each dependant value (task)
     * of the learner.
     *
     * @return task name to its sorted important words; empty if there is no
     *         learner
     */
    private Map<String, SortedImportantWords> importantWords() {
        Map<String, SortedImportantWords> ret = new HashMap<String, SortedImportantWords>();
        if (this.learner == null) {
            return ret;
        }
        BiMap<String, Integer> depvals = this.learner.getDependantValues();
        BilinearSparseOnlineLearner bilearner = this.learner.getBilinearLearner();
        for (String task : depvals.keySet()) {
            // NOTE(review): 10 is presumably the number of words to keep per
            // task - confirm against SortedImportantWords.
            ret.put(task, new SortedImportantWords(task, learner, bilearner, 10));
        }

        return ret;
    }

    /**
     * Prints a summary of these stats to standard out: the score, for each
     * task its important words (with the matching row of W) and word min/max,
     * the per-task user min/max, and finally the bias, correct Y and
     * estimated Y. If there is no learner only "No loss!" is printed.
     */
    public void printSummary() {
        if (learner == null) {
            System.out.println("No loss!");
            return;
        }
        System.out.println("Loss: " + this.score);
        System.out.println("Important words: ");
        BilinearSparseOnlineLearner bilinearLearner = this.learner.getBilinearLearner();
        BiMap<Integer, String> inversewords = this.learner.getVocabulary().inverse();
        for (String task : this.importantWords.keySet()) {
            Pair<Double> minmax = this.taskWordMinMax.get(task);
            SortedImportantWords sortedImportantWords = this.importantWords.get(task);
            for (int wordIndex : sortedImportantWords.indexes) {
                System.out.println("Word: " + inversewords.get(wordIndex) + " index " + wordIndex);
                System.out.println(bilinearLearner.getW().getRow(wordIndex));
            }
            System.out.printf("... %s (%1.4f->%1.4f) %s\n",
                    task,
                    minmax.firstObject(),
                    minmax.secondObject(),
                    sortedImportantWords
            );
        }
        System.out.println("User importance: ");
        for (String task : this.importantWords.keySet()) {
            Pair<Double> minmax = this.userMinMax.get(task);
            System.out.printf("... %s (%1.4f->%1.4f)\n",
                    task,
                    minmax.firstObject(),
                    minmax.secondObject()
            );
        }
        System.out.println("Model Bias: \n" + this.bias);
        System.out.println("Correct Y: \n" + this.correctY);
        System.out.println("Estimated Y: \n" + this.estimatedY);
    }

}