/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.linear.learner.matlib;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.LogManager;

import org.openimaj.io.ReadWriteableBinary;
import org.openimaj.math.matrix.DiagonalMatrix;
import org.openimaj.math.matrix.MatlibMatrixUtils;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.OnlineLearner;
import org.openimaj.ml.linear.learner.matlib.init.InitStrategy;
import org.openimaj.ml.linear.learner.matlib.init.SparseSingleValueInitStrat;
import org.openimaj.ml.linear.learner.matlib.loss.LossFunction;
import org.openimaj.ml.linear.learner.matlib.loss.MatLossFunction;
import org.openimaj.ml.linear.learner.matlib.regul.Regulariser;

import ch.akuhn.matrix.Matrix;
import ch.akuhn.matrix.SparseMatrix;

/**
 * An implementation of stochastic gradient descent with proximal parameter
 * adjustment (for regularised parameters).
 *
 * Data is dealt with sequentially using a one-pass implementation of the
 * online proximal algorithm described in chapters 9 and 10 of:
 * The Geometry of Constrained Structured Prediction: Applications to Inference
 * and Learning of Natural Language Syntax, PhD thesis, Andre T. Martins.
 *
 * The implementation does the following:
 * - When an X,Y pair is received:
 * 	- Update the currently held batch
 * 	- If the batch is full:
 * 		- While there is a great deal of change in U and W:
 * 			- Calculate the gradient of W holding U fixed
 * 			- Proximal update of W
 * 			- Calculate the gradient of U holding W fixed
 * 			- Proximal update of U
 * 			- Calculate the gradient of the bias holding U and W fixed
 * 		- Flush the batch
 * 	- Return the current U and W (same as last time if the batch isn't filled yet)
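 *
 * A minimal usage sketch follows; it is illustrative only. The default
 * {@link BilinearLearnerParameters} are not expected to suit any particular
 * dataset, and the <code>x</code> and <code>y</code> matrices are assumed to
 * be built elsewhere:
 * <pre>
 * final BilinearLearnerParameters params = new BilinearLearnerParameters();
 * final MatlibBilinearSparseOnlineLearner learner = new MatlibBilinearSparseOnlineLearner(params);
 * // x is (nfeatures x nusers), y is (1 x ntasks)
 * learner.process(x, y);
 * final Matrix prediction = learner.predict(x);
 * </pre>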
 *
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 *
 */
public class MatlibBilinearSparseOnlineLearner implements OnlineLearner<Matrix, Matrix>, ReadWriteableBinary {

	static Logger logger = LogManager.getLogger(MatlibBilinearSparseOnlineLearner.class);

	protected BilinearLearnerParameters params;
	protected Matrix w;
	protected Matrix u;
	protected LossFunction loss;
	protected Regulariser regul;
	protected Double lambda_w, lambda_u;
	protected Boolean biasMode;
	protected Matrix bias;
	protected Matrix diagX;
	protected Double eta0_u;
	protected Double eta0_w;

	private Boolean forceSparcity;

	private Boolean zStandardise;

	private boolean nodataseen;

	/**
	 * The default parameters. These won't work with your dataset, I promise.
	 */
	public MatlibBilinearSparseOnlineLearner() {
		this(new BilinearLearnerParameters());
	}

	/**
	 * @param params the parameters used by this learner
	 */
	public MatlibBilinearSparseOnlineLearner(BilinearLearnerParameters params) {
		this.params = params;
		reinitParams();
	}

	/**
	 * Must be called if any parameters are changed.
	 */
	public void reinitParams() {
		this.loss = this.params.getTyped(BilinearLearnerParameters.LOSS);
		this.regul = this.params.getTyped(BilinearLearnerParameters.REGUL);
		this.lambda_w = this.params.getTyped(BilinearLearnerParameters.LAMBDA_W);
		this.lambda_u = this.params.getTyped(BilinearLearnerParameters.LAMBDA_U);
		this.biasMode = this.params.getTyped(BilinearLearnerParameters.BIAS);
		this.eta0_u = this.params.getTyped(BilinearLearnerParameters.ETA0_U);
		this.eta0_w = this.params.getTyped(BilinearLearnerParameters.ETA0_W);
		this.forceSparcity = this.params.getTyped(BilinearLearnerParameters.FORCE_SPARCITY);
		this.zStandardise = this.params.getTyped(BilinearLearnerParameters.Z_STANDARDISE);
		if (!this.loss.isMatrixLoss())
			this.loss = new MatLossFunction(this.loss);
		this.nodataseen = true;
	}

	private void initParams(Matrix x, Matrix y, int xrows, int xcols, int ycols) {
		final InitStrategy wstrat = getInitStrat(BilinearLearnerParameters.WINITSTRAT, x, y);
		final InitStrategy ustrat = getInitStrat(BilinearLearnerParameters.UINITSTRAT, x, y);
		this.w = wstrat.init(xrows, ycols);
		this.u = ustrat.init(xcols, ycols);

		this.bias = SparseMatrix.sparse(ycols, ycols);
		if (this.biasMode) {
			final InitStrategy bstrat = getInitStrat(BilinearLearnerParameters.BIASINITSTRAT, x, y);
			this.bias = bstrat.init(ycols, ycols);
			this.diagX = new DiagonalMatrix(ycols, 1);
		}
	}

	private InitStrategy getInitStrat(String initstrat, Matrix x, Matrix y) {
		final InitStrategy strat = this.params.getTyped(initstrat);
		return strat;
	}

	@Override
	public void process(Matrix X, Matrix Y) {
		final int nfeatures = X.rowCount();
		final int nusers = X.columnCount();
		final int ntasks = Y.columnCount();
		// final int ninstances = Y.rowCount(); // Assume 1 instance!

		// Only initialise the parameters if they are currently null (i.e. on the first call)
		if (this.w == null) {
			initParams(X, Y, nfeatures, nusers, ntasks); // Number of words, users and tasks
		}

		final Double dampening = this.params.getTyped(BilinearLearnerParameters.DAMPENING);
		final double weighting = 1.0 - dampening;

		logger.debug("... dampening w, u and bias by: " + weighting);

		// Adjust for the weighting
		MatlibMatrixUtils.scaleInplace(this.w, weighting);
		MatlibMatrixUtils.scaleInplace(this.u, weighting);
		if (this.biasMode) {
			MatlibMatrixUtils.scaleInplace(this.bias, weighting);
		}

		// First expand Y s.t. blocks of rows contain the task values for each row of Y.
		// This means Yexp is (n * t x t)
		final SparseMatrix Yexp = expandY(Y);
		loss.setY(Yexp);
		int iter = 0;
		while (true) {
			// We need to set the bias here because it is used in the loss calculation of U and W
			if (this.biasMode)
				loss.setBias(this.bias);
			iter += 1;

			final double uLossWeight = etat(iter, eta0_u);
			final double wLossWeighted = etat(iter, eta0_w);
			final double weightedLambda_u = lambdat(iter, lambda_u);
			final double weightedLambda_w = lambdat(iter, lambda_w);

			// Dprime is tasks x nwords
			Matrix Dprime = null;
			if (this.nodataseen) {
				this.nodataseen = false;
				// No data seen yet, so stand in for U with a sparse single-value initialisation
				Matrix fakeut = new SparseSingleValueInitStrat(1).init(this.u.columnCount(), this.u.rowCount());
				Dprime = MatlibMatrixUtils.dotProductTranspose(fakeut, X); // i.e. fakeut . X^T
			} else {
				Dprime = MatlibMatrixUtils.dotProductTransposeTranspose(u, X); // i.e. u^T . X^T
			}

			// ... as is the cost function's X
			if (zStandardise) {
				// Vector rowMean = CFMatrixUtils.rowMean(Dprime);
				// CFMatrixUtils.minusEqualsCol(Dprime, rowMean);
			}
			loss.setX(Dprime);
			final Matrix neww = updateW(this.w, wLossWeighted, weightedLambda_w);

			// Vprime is nusers x tasks
			final Matrix Vt = MatlibMatrixUtils.transposeDotProduct(neww, X); // i.e. (X^T . neww)^T
			// ... so the loss function's X is (tasks x nusers)
			loss.setX(Vt);
			final Matrix newu = updateU(this.u, uLossWeight, weightedLambda_u);

			final double sumchangew = MatlibMatrixUtils.normF(MatlibMatrixUtils.minus(neww, this.w));
			final double totalw = MatlibMatrixUtils.normF(this.w);

			final double sumchangeu = MatlibMatrixUtils.normF(MatlibMatrixUtils.minus(newu, this.u));
			final double totalu = MatlibMatrixUtils.normF(this.u);

			// Relative change in each parameter matrix; used to test for convergence
			double ratioU = 0;
			if (totalu != 0)
				ratioU = sumchangeu / totalu;
			double ratioW = 0;
			if (totalw != 0)
				ratioW = sumchangew / totalw;
			double ratioB = 0;
			double ratio = ratioU + ratioW;
			double totalbias = 0;
			if (this.biasMode) {
				Matrix mult = MatlibMatrixUtils.dotProductTransposeTranspose(newu, X);
				mult = MatlibMatrixUtils.dotProduct(mult, neww);
				MatlibMatrixUtils.plusInplace(mult, bias);
				// We must set the bias to null!
				loss.setBias(null);
				loss.setX(diagX);
				// Calculate the gradient of the bias (don't regularise)
				final Matrix biasGrad = loss.gradient(mult);
				final double biasLossWeight = biasEtat(iter);
				final Matrix newbias = updateBias(biasGrad, biasLossWeight);

				final double sumchangebias = MatlibMatrixUtils.normF(MatlibMatrixUtils.minus(newbias, bias));
				totalbias = MatlibMatrixUtils.normF(this.bias);
				if (totalbias != 0)
					ratioB = (sumchangebias / totalbias);
				this.bias = newbias;
				ratio += ratioB;
				ratio /= 3;
			} else {
				ratio /= 2;
			}

			// Commit the updated parameters so the next iteration (and getW()/getU()) see them
			this.w = neww;
			this.u = newu;

			final Double biconvextol = this.params.getTyped("biconvex_tol");
			final Integer maxiter = this.params.getTyped("biconvex_maxiter");
			if (iter % 3 == 0) {
				logger.debug(String.format("Iter: %d. Last Ratio: %2.3f", iter, ratio));
				logger.debug("W row sparsity: " + MatlibMatrixUtils.sparsity(w));
				logger.debug("U row sparsity: " + MatlibMatrixUtils.sparsity(u));
				logger.debug("Total U magnitude: " + totalu);
				logger.debug("Total W magnitude: " + totalw);
				logger.debug("Total Bias: " + totalbias);
			}
			if (biconvextol < 0 || ratio < biconvextol || iter >= maxiter) {
				logger.debug("Tolerance reached after iteration: " + iter);
				logger.debug("W row sparsity: " + MatlibMatrixUtils.sparsity(w));
				logger.debug("U row sparsity: " + MatlibMatrixUtils.sparsity(u));
				logger.debug("Total U magnitude: " + totalu);
				logger.debug("Total W magnitude: " + totalw);
				logger.debug("Total Bias: " + totalbias);
				break;
			}
		}
	}

	protected Matrix updateBias(Matrix biasGrad, double biasLossWeight) {
		// Simple (un-regularised) gradient step on the bias
		final Matrix newbias = MatlibMatrixUtils.minus(
				this.bias,
				MatlibMatrixUtils.scaleInplace(
						biasGrad,
						biasLossWeight
						)
				);
		return newbias;
	}

	protected Matrix updateW(Matrix currentW, double wLossWeighted, double weightedLambda) {
		// Gradient step on W followed by the regulariser's proximal operator
		final Matrix gradW = loss.gradient(currentW);
		MatlibMatrixUtils.scaleInplace(gradW, wLossWeighted);

		Matrix neww = MatlibMatrixUtils.minus(currentW, gradW);
		neww = regul.prox(neww, weightedLambda);
		return neww;
	}

	protected Matrix updateU(Matrix currentU, double uLossWeight, double uWeightedLambda) {
		// Gradient step on U followed by the regulariser's proximal operator
		final Matrix gradU = loss.gradient(currentU);
		MatlibMatrixUtils.scaleInplace(gradU, uLossWeight);
		Matrix newu = MatlibMatrixUtils.minus(currentU, gradU);
		newu = regul.prox(newu, uWeightedLambda);
		return newu;
	}

	private double lambdat(int iter, double lambda) {
		return lambda / iter;
	}

	/**
	 * Given a flat value matrix, makes a diagonal sparse matrix containing the
	 * values on the diagonal; off-diagonal entries are set to {@link Double#NaN}.
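	 *
	 * For example (illustrative values only), a 1x2 value matrix
	 * <code>[0.5, 1.5]</code> expands to the 2x2 matrix
	 * <code>[[0.5, NaN], [NaN, 1.5]]</code>.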
	 * @param Y
	 * @return the diagonalised Y
	 */
	public static SparseMatrix expandY(Matrix Y) {
		final int ntasks = Y.columnCount();
		final SparseMatrix Yexp = SparseMatrix.sparse(ntasks, ntasks);
		for (int touter = 0; touter < ntasks; touter++) {
			for (int tinner = 0; tinner < ntasks; tinner++) {
				if (tinner == touter) {
					Yexp.put(touter, tinner, Y.get(0, tinner));
				} else {
					Yexp.put(touter, tinner, Double.NaN);
				}
			}
		}
		return Yexp;
	}

	private double biasEtat(int iter) {
		final Double biasEta0 = this.params.getTyped(BilinearLearnerParameters.ETA0_BIAS);
		return biasEta0 / Math.sqrt(iter);
	}

	private double etat(int iter, double eta0) {
		// Step-size schedule: eta0 / sqrt(ceil(iter / etaSteps))
		final Integer etaSteps = this.params.getTyped(BilinearLearnerParameters.ETASTEPS);
		final double sqrtCeil = Math.sqrt(Math.ceil(iter / (double) etaSteps));
		return eta(eta0) / sqrtCeil;
	}

	private double eta(double eta0) {
		return eta0;
	}

	/**
	 * @return the current parameters
	 */
	public BilinearLearnerParameters getParams() {
		return this.params;
	}

	/**
	 * @return the current user matrix
	 */
	public Matrix getU() {
		return this.u;
	}

	/**
	 * @return the current word matrix
	 */
	public Matrix getW() {
		return this.w;
	}

	/**
	 * @return the current bias (null if {@link BilinearLearnerParameters#BIAS} is false)
	 */
	public Matrix getBias() {
		if (this.biasMode)
			return this.bias;
		else
			return null;
	}

	/**
	 * Expand the U parameters matrix by adding a set of rows.
	 * If U is currently unset, this function does nothing (assuming U will be initialised on the first round).
	 * The new U parameters are initialised using {@link BilinearLearnerParameters#EXPANDEDUINITSTRAT}.
	 * @param newUsers the number of new users to add
	 */
	public void addU(int newUsers) {
		if (this.u == null)
			return; // If u has not been initialised, it will be on the first call to process
		final InitStrategy ustrat = this.getInitStrat(BilinearLearnerParameters.EXPANDEDUINITSTRAT, null, null);
		final Matrix newU = ustrat.init(newUsers, this.u.columnCount());
		this.u = MatlibMatrixUtils.vstack(this.u, newU);
	}

	/**
	 * Expand the W parameters matrix by adding a set of rows.
	 * If W is currently unset, this function does nothing (assuming W will be initialised on the first round).
	 * The new W parameters are initialised using {@link BilinearLearnerParameters#EXPANDEDWINITSTRAT}.
	 * @param newWords the number of new words to add
	 */
	public void addW(int newWords) {
		if (this.w == null)
			return; // If w has not been initialised, it will be on the first call to process
		final InitStrategy wstrat = this.getInitStrat(BilinearLearnerParameters.EXPANDEDWINITSTRAT, null, null);
		final Matrix newW = wstrat.init(newWords, this.w.columnCount());
		this.w = MatlibMatrixUtils.vstack(this.w, newW);
	}

	@Override
	public MatlibBilinearSparseOnlineLearner clone() {
		final MatlibBilinearSparseOnlineLearner ret = new MatlibBilinearSparseOnlineLearner(this.getParams());
		ret.u = MatlibMatrixUtils.copy(this.u);
		ret.w = MatlibMatrixUtils.copy(this.w);
		if (this.biasMode) {
			ret.bias = MatlibMatrixUtils.copy(this.bias);
		}
		return ret;
	}

	/**
	 * @param newu set the model's U
	 */
	public void setU(Matrix newu) {
		this.u = newu;
	}

	/**
	 * @param neww set the model's W
	 */
	public void setW(Matrix neww) {
		this.w = neww;
	}

	@Override
	public void readBinary(DataInput in) throws IOException {
		final int nwords = in.readInt();
		final int nusers = in.readInt();
		final int ntasks = in.readInt();

		this.w = SparseMatrix.sparse(nwords, ntasks);
		for (int t = 0; t < ntasks; t++) {
			for (int r = 0; r < nwords; r++) {
				final double readDouble = in.readDouble();
				if (readDouble != 0) {
					this.w.put(r, t, readDouble);
				}
			}
		}

		this.u = SparseMatrix.sparse(nusers, ntasks);
		for (int t = 0; t < ntasks; t++) {
			for (int r = 0; r < nusers; r++) {
				final double readDouble = in.readDouble();
				if (readDouble != 0) {
					this.u.put(r, t, readDouble);
				}
			}
		}

		this.bias = SparseMatrix.sparse(ntasks, ntasks);
		for (int t1 = 0; t1 < ntasks; t1++) {
			for (int t2 = 0; t2 < ntasks; t2++) {
				final double readDouble = in.readDouble();
				if (readDouble != 0) {
					this.bias.put(t1, t2, readDouble);
				}
			}
		}
	}

	@Override
	public byte[] binaryHeader() {
		return "".getBytes();
	}

	@Override
	public void writeBinary(DataOutput out) throws IOException {
		out.writeInt(w.rowCount());
		out.writeInt(u.rowCount());
		out.writeInt(u.columnCount());
		final double[] wdata = w.asColumnMajorArray();
		for (int i = 0; i < wdata.length; i++) {
			out.writeDouble(wdata[i]);
		}
		final double[] udata = u.asColumnMajorArray();
		for (int i = 0; i < udata.length; i++) {
			out.writeDouble(udata[i]);
		}
		final double[] biasdata = bias.asColumnMajorArray();
		for (int i = 0; i < biasdata.length; i++) {
			out.writeDouble(biasdata[i]);
		}
	}

	@Override
	public Matrix predict(Matrix x) {
		Matrix xt = MatlibMatrixUtils.transpose(x);
		final Matrix mult = MatlibMatrixUtils.dotProduct(MatlibMatrixUtils.dotProduct(MatlibMatrixUtils.transpose(u), xt), this.w);
		if (this.biasMode)
			MatlibMatrixUtils.plusInplace(mult, this.bias);
		Matrix ydiag = new DiagonalMatrix(mult);
		return ydiag;
	}
}