001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.ml.linear.learner; 031 032import gov.sandia.cognition.math.matrix.Matrix; 033import gov.sandia.cognition.math.matrix.mtj.SparseMatrixFactoryMTJ; 034 035import java.util.HashMap; 036import java.util.Map; 037import java.util.Map.Entry; 038 039import org.openimaj.util.pair.IndependentPair; 040import org.openimaj.util.pair.Pair; 041 042import com.google.common.collect.BiMap; 043import com.google.common.collect.HashBiMap; 044 045/** 046 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 047 * 048 */ 049public class IncrementalBilinearSparseOnlineLearner 050 implements 051 OnlineLearner<Map<String, Map<String, Double>>, Map<String, Double>> 052{ 053 static class IncrementalBilinearSparseOnlineLearnerParams extends BilinearLearnerParameters { 054 055 /** 056 * 057 */ 058 private static final long serialVersionUID = -1847045895118918210L; 059 060 } 061 062 private BiMap<String, Integer> vocabulary; 063 private BiMap<String, Integer> users; 064 private BiMap<String, Integer> values; 065 private BilinearSparseOnlineLearner bilinearLearner; 066 private BilinearLearnerParameters params; 067 068 /** 069 * Instantiates with the default params 070 */ 071 public IncrementalBilinearSparseOnlineLearner() { 072 init(new IncrementalBilinearSparseOnlineLearnerParams()); 073 } 074 075 /** 076 * @param params 077 */ 078 public IncrementalBilinearSparseOnlineLearner(BilinearLearnerParameters params) { 079 init(params); 080 } 081 082 /** 083 * 084 */ 085 public void reinitParams() { 086 init(this.params); 087 088 } 089 090 private void init(BilinearLearnerParameters params) { 091 vocabulary = HashBiMap.create(); 092 users = HashBiMap.create(); 093 values = HashBiMap.create(); 094 this.params = params; 095 bilinearLearner = new BilinearSparseOnlineLearner(params); 096 } 097 098 /** 099 * @return the current parameters 100 */ 101 public BilinearLearnerParameters getParams() { 102 return this.params; 103 } 104 105 @Override 106 public void process(Map<String, Map<String, Double>> x, Map<String, Double> y) { 107 updateUserValues(x, y); 108 final Matrix yMat = constructYMatrix(y); 109 final Matrix xMat = constructXMatrix(x); 110 111 this.bilinearLearner.process(xMat, yMat); 112 } 113 114 /** 115 * Update the incremental learner and underlying weight matricies to reflect 116 * potentially novel users, words and values to learn against 117 * 118 * @param x 119 * @param y 120 */ 121 public void updateUserValues(Map<String, Map<String, Double>> x, Map<String, Double> y) { 122 updateUserWords(x); 123 updateValues(y); 124 } 125 126 private void updateValues(Map<String, Double> y) { 127 for (final String value : y.keySet()) { 128 if (!values.containsKey(value)) { 129 values.put(value, values.size()); 130 } 131 } 132 } 133 134 private Matrix constructYMatrix(Map<String, Double> y) { 135 final Matrix mat = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(1, values.size()); 136 for (final Entry<String, Double> ent : y.entrySet()) { 137 mat.setElement(0, values.get(ent.getKey()), ent.getValue()); 138 } 139 return mat; 140 } 141 142 private Map<String, Double> constructYMap(Matrix y) { 143 final Map<String, Double> ret = new HashMap<String, Double>(); 144 for (final String key : values.keySet()) { 145 final Integer index = values.get(key); 146 final double yvalue = y.getElement(0, index); 147 ret.put(key, yvalue); 148 } 149 return ret; 150 } 151 152 private Matrix constructXMatrix(Map<String, Map<String, Double>> x) { 153 final Matrix mat = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(vocabulary.size(), users.size()); 154 for (final Entry<String, Map<String, Double>> userwords : x.entrySet()) { 155 final int userindex = this.users.get(userwords.getKey()); 156 for (final Entry<String, Double> ent : userwords.getValue().entrySet()) { 157 mat.setElement(vocabulary.get(ent.getKey()), userindex, ent.getValue()); 158 } 159 } 160 return mat; 161 } 162 163 private void updateUserWords(Map<String, Map<String, Double>> x) { 164 int newUsers = 0; 165 int newWords = 0; 166 for (final Entry<String, Map<String, Double>> userWords : x.entrySet()) { 167 final String user = userWords.getKey(); 168 if (!users.containsKey(user)) { 169 users.put(user, users.size()); 170 newUsers++; 171 } 172 newWords += updateWords(userWords.getValue()); 173 } 174 175 this.bilinearLearner.addU(newUsers); 176 this.bilinearLearner.addW(newWords); 177 } 178 179 private int updateWords(Map<String, Double> value) { 180 int newWords = 0; 181 for (final String word : value.keySet()) { 182 if (!vocabulary.containsKey(word)) { 183 vocabulary.put(word, vocabulary.size()); 184 newWords++; 185 } 186 } 187 return newWords; 188 } 189 190 /** 191 * Construct a learner with the desired number of users and words. If users 192 * and words beyond those in the current model are asked for their 193 * parameters are set to 0 194 * 195 * @param nusers 196 * @param nwords 197 * @return a new {@link BilinearSparseOnlineLearner} 198 */ 199 public BilinearSparseOnlineLearner getBilinearLearner(int nusers, int nwords) { 200 final BilinearSparseOnlineLearner ret = this.bilinearLearner.clone(); 201 202 final Matrix u = ret.getU(); 203 final Matrix w = ret.getW(); 204 final Matrix newu = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(nusers, u.getNumColumns()); 205 final Matrix neww = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(nwords, w.getNumColumns()); 206 207 newu.setSubMatrix(0, 0, u); 208 neww.setSubMatrix(0, 0, w); 209 210 ret.setU(newu); 211 ret.setW(neww); 212 return ret; 213 } 214 215 /** 216 * @return the underlying {@link BilinearSparseOnlineLearner} with the 217 * current number of users and words 218 */ 219 public BilinearSparseOnlineLearner getBilinearLearner() { 220 return this.bilinearLearner.clone(); 221 } 222 223 /** 224 * Given a sparse pair of user/words and value construct a pair of matricies 225 * using the current mappings of words and users to matrix rows. 226 * 227 * Note: if users or words which have not yet be 228 * 229 * @param xy 230 * @param nfeatures 231 * the number of words total in the returned X matrix 232 * @param nusers 233 * the number of users total in the returned X matrix 234 * @param ntasks 235 * the number of tasks in the returned Y matrix 236 * @return the matrix pair representing the X and Y input 237 */ 238 public Pair<Matrix> asMatrixPair( 239 IndependentPair<Map<String, Map<String, Double>>, Map<String, Double>> xy, 240 int nfeatures, int nusers, int ntasks) 241 { 242 final Matrix y = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(1, ntasks); 243 final Matrix x = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(nfeatures, nusers); 244 final Map<String, Double> ymap = xy.secondObject(); 245 final Map<String, Map<String, Double>> userFeatureMap = xy.firstObject(); 246 for (final Entry<String, Double> yent : ymap.entrySet()) { 247 y.setElement(0, this.values.get(yent.getKey()), yent.getValue()); 248 } 249 for (final Entry<String, Map<String, Double>> xent : userFeatureMap.entrySet()) { 250 final int userind = this.users.get(xent.getKey()); 251 for (final Entry<String, Double> fent : xent.getValue().entrySet()) { 252 x.setElement(this.vocabulary.get(fent.getKey()), userind, fent.getValue()); 253 } 254 } 255 return new Pair<Matrix>(x, y); 256 } 257 258 @Override 259 public Map<String, Double> predict(Map<String, Map<String, Double>> x) { 260 final Matrix xMat = constructXMatrix(x); 261 final Matrix yMat = this.bilinearLearner.predict(xMat); 262 return this.constructYMap(yMat); 263 } 264 265 /** 266 * @return the vocabulary 267 */ 268 public BiMap<String, Integer> getVocabulary() { 269 return vocabulary; 270 } 271 272 /** 273 * @param in 274 * @return calls {@link #asMatrixPair(IndependentPair, int, int, int)} with 275 * the current number of words, users and value 276 */ 277 public Pair<Matrix> asMatrixPair(IndependentPair<Map<String, Map<String, Double>>, Map<String, Double>> in) { 278 return this.asMatrixPair(in, this.vocabulary.size(), this.users.size(), this.values.size()); 279 } 280 281 /** 282 * @param x 283 * @param y 284 * @return calls {@link #asMatrixPair(IndependentPair, int, int, int)} with 285 * the current number of words, users and value 286 */ 287 public Pair<Matrix> asMatrixPair(Map<String, Map<String, Double>> x, Map<String, Double> y) { 288 return this.asMatrixPair(IndependentPair.pair(x, y), this.vocabulary.size(), this.users.size(), 289 this.values.size()); 290 } 291 292 /** 293 * @return the current map of dependent values to indexes 294 */ 295 public BiMap<String, Integer> getDependantValues() { 296 return this.values; 297 } 298 299 public BiMap<String, Integer> getUsers() { 300 return this.users; 301 } 302}