/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.linear.learner;

import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.mtj.AbstractSparseMatrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrixFactoryMTJ;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.LogManager;

import org.openimaj.io.ReadWriteableBinary;
import org.openimaj.math.matrix.CFMatrixUtils;
import org.openimaj.ml.linear.learner.init.ContextAwareInitStrategy;
import org.openimaj.ml.linear.learner.init.InitStrategy;
import org.openimaj.ml.linear.learner.init.SparseSingleValueInitStrat;
import org.openimaj.ml.linear.learner.loss.LossFunction;
import org.openimaj.ml.linear.learner.loss.MatLossFunction;
import org.openimaj.ml.linear.learner.regul.Regulariser;

/**
 * An implementation of stochastic gradient descent with proximal parameter
 * adjustment (for regularised parameters).
 *
 * Data is dealt with sequentially using a one-pass implementation of the
 * online proximal algorithm described in chapters 9 and 10 of:
 * "The Geometry of Constrained Structured Prediction: Applications to
 * Inference and Learning of Natural Language Syntax", PhD thesis,
 * Andre T. Martins.
 *
 * The implementation does the following:
 *  - When an X,Y pair is received:
 *    - Update the currently held batch
 *    - If the batch is full:
 *      - While there is a great deal of change in U and W:
 *        - Calculate the gradient of W holding U fixed
 *        - Proximal update of W
 *        - Calculate the gradient of U holding W fixed
 *        - Proximal update of U
 *        - Calculate the gradient of the bias holding U and W fixed
 *      - Flush the batch
 *    - Return the current U and W (same as last time if the batch isn't filled yet)
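 *
 * A minimal usage sketch (X, Y and Xnew are hypothetical matrices; the shapes
 * follow {@link #process(Matrix, Matrix)} and {@link #predict(Matrix)}):
 *
 * <pre>
 * {@code
 * final BilinearSparseOnlineLearner learner = new BilinearSparseOnlineLearner();
 * learner.process(X, Y); // X is (nfeatures x nusers), Y is (1 x ntasks)
 * final Matrix scores = learner.predict(Xnew); // a (1 x ntasks) row of predictions
 * }
 * </pre>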
 *
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 *
 */
public class BilinearSparseOnlineLearner implements OnlineLearner<Matrix, Matrix>, ReadWriteableBinary {

    static Logger logger = LogManager.getLogger(BilinearSparseOnlineLearner.class);

    protected BilinearLearnerParameters params;
    protected Matrix w;
    protected Matrix u;
    protected SparseMatrixFactoryMTJ smf = SparseMatrixFactoryMTJ.INSTANCE;
    protected LossFunction loss;
    protected Regulariser regul;
    protected Double lambda_w, lambda_u;
    protected Boolean biasMode;
    protected Matrix bias;
    protected Matrix diagX;
    protected Double eta0_u;
    protected Double eta0_w;

    private Boolean forceSparcity;

    private Boolean zStandardise;

    private boolean nodataseen;

    private double eta_gamma;

    private double biasEta0;

    /**
     * The default parameters. These won't work with your dataset, I promise.
     */
    public BilinearSparseOnlineLearner() {
        this(new BilinearLearnerParameters());
    }

    /**
     * @param params the parameters used by this learner
     */
    public BilinearSparseOnlineLearner(BilinearLearnerParameters params) {
        this.params = params;
        reinitParams();
    }

    /**
     * Must be called if any parameters are changed.
     */
    public void reinitParams() {
        this.loss = this.params.getTyped(BilinearLearnerParameters.LOSS);
        this.regul = this.params.getTyped(BilinearLearnerParameters.REGUL);
        this.lambda_w = this.params.getTyped(BilinearLearnerParameters.LAMBDA_W);
        this.lambda_u = this.params.getTyped(BilinearLearnerParameters.LAMBDA_U);
        this.biasMode = this.params.getTyped(BilinearLearnerParameters.BIAS);
        this.eta0_u = this.params.getTyped(BilinearLearnerParameters.ETA0_U);
        this.eta0_w = this.params.getTyped(BilinearLearnerParameters.ETA0_W);
        this.biasEta0 = this.params.getTyped(BilinearLearnerParameters.ETA0_BIAS);
        this.eta_gamma = params.getTyped(BilinearLearnerParameters.ETA_GAMMA);
        this.forceSparcity = this.params.getTyped(BilinearLearnerParameters.FORCE_SPARCITY);
        this.zStandardise = this.params.getTyped(BilinearLearnerParameters.Z_STANDARDISE);
        if (!this.loss.isMatrixLoss())
            this.loss = new MatLossFunction(this.loss);
        this.nodataseen = true;
    }

    private void initParams(Matrix x, Matrix y, int xrows, int xcols, int ycols) {
        final InitStrategy wstrat = getInitStrat(BilinearLearnerParameters.WINITSTRAT, x, y);
        final InitStrategy ustrat = getInitStrat(BilinearLearnerParameters.UINITSTRAT, x, y);
        this.w = wstrat.init(xrows, ycols);
        this.u = ustrat.init(xcols, ycols);
        if (this.forceSparcity) {
            this.u = CFMatrixUtils.asSparseColumn(this.u);
            this.w = CFMatrixUtils.asSparseColumn(this.w);
        }

        this.bias = smf.createMatrix(ycols, ycols);
        if (this.biasMode) {
            final InitStrategy bstrat = getInitStrat(BilinearLearnerParameters.BIASINITSTRAT, x, y);
            this.bias = bstrat.init(ycols, ycols);
            this.diagX = smf.createIdentity(ycols, ycols);
        }
    }

    private InitStrategy getInitStrat(String initstrat, Matrix x, Matrix y) {
        final InitStrategy strat = this.params.getTyped(initstrat);
        if (strat instanceof ContextAwareInitStrategy) {
            final ContextAwareInitStrategy<Matrix, Matrix> cwStrat = this.params.getTyped(initstrat);
            cwStrat.setLearner(this);
            cwStrat.setContext(x, y);
            return cwStrat;
        }
        return strat;
    }
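    /**
     * Deal with a single X,Y observation: alternate proximal-gradient updates
     * of W and U (and of the bias when {@link BilinearLearnerParameters#BIAS}
     * is set) until the average relative change in the parameters falls below
     * the biconvex tolerance, or the maximum number of biconvex iterations is
     * reached.
     */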
    @Override
    public void process(Matrix X, Matrix Y) {
        prepareNextRound(X, Y);
        int iter = 0;
        final Matrix xt = X.transpose();
        Matrix xtrows = xt;
        if (xt instanceof AbstractSparseMatrix) {
            xtrows = CFMatrixUtils.asSparseRow(xt);
        }
        while (true) {
            // We need to set the bias here because it is used in the loss calculation of U and W
            if (this.biasMode)
                loss.setBias(this.bias);
            iter += 1;

            // Perform the bilinear operation
            final Matrix neww = updateW(xt, eta0_w, lambda_w);
            final Matrix newu = updateU(xtrows, neww, eta0_u, lambda_u);
            Matrix newbias = null;
            if (this.biasMode) {
                newbias = updateBias(xt, newu, neww, biasEta0);
            }

            // Check whether the bilinear steps can stop by measuring how much
            // everything has changed, proportional to its magnitude
            double ratioB = 0;
            double totalbias = 0;

            final double sumchangew = CFMatrixUtils.absSum(neww.minus(this.w));
            final double totalw = CFMatrixUtils.absSum(this.w);

            final double sumchangeu = CFMatrixUtils.absSum(newu.minus(this.u));
            final double totalu = CFMatrixUtils.absSum(this.u);

            double ratioU = 0;
            if (totalu != 0)
                ratioU = sumchangeu / totalu;
            double ratioW = 0;
            if (totalw != 0)
                ratioW = sumchangew / totalw;
            double ratio = ratioU + ratioW;
            if (this.biasMode) {
                final double sumchangebias = CFMatrixUtils.absSum(newbias.minus(this.bias));
                totalbias = CFMatrixUtils.absSum(this.bias);
                if (totalbias != 0)
                    ratioB = sumchangebias / totalbias;
                ratio += ratioB;
                ratio /= 3;
            } else {
                ratio /= 2;
            }

            // This is not simply a matter of type: the conversion also removes
            // the explicitly-stored 0 values of the sparse matrix. Very important.
            if (this.forceSparcity) {
                this.u = CFMatrixUtils.asSparseColumn(newu);
                this.w = CFMatrixUtils.asSparseColumn(neww);
            } else {
                this.w = neww;
                this.u = newu;
            }

            if (this.biasMode) {
                this.bias = newbias;
            }

            final Double biconvextol = this.params.getTyped("biconvex_tol");
            final Integer maxiter = this.params.getTyped("biconvex_maxiter");
            if (iter % 3 == 0) {
                logger.debug(String.format("Iter: %d. Last Ratio: %2.3f", iter, ratio));
                logger.debug("W row sparsity: " + CFMatrixUtils.rowSparsity(w));
                logger.debug("U row sparsity: " + CFMatrixUtils.rowSparsity(u));
                logger.debug("Total U magnitude: " + totalu);
                logger.debug("Total W magnitude: " + totalw);
                logger.debug("Total Bias: " + totalbias);
            }
            if (biconvextol < 0 || ratio < biconvextol || iter >= maxiter) {
                logger.debug("Tolerance reached after iteration: " + iter);
                logger.debug("W row sparsity: " + CFMatrixUtils.rowSparsity(w));
                logger.debug("U row sparsity: " + CFMatrixUtils.rowSparsity(u));
                logger.debug("Total U magnitude: " + totalu);
                logger.debug("Total W magnitude: " + totalw);
                logger.debug("Total Bias: " + totalbias);
                break;
            }
        }
    }
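    /**
     * Prepare a round: lazily initialise W, U and the bias on the first call,
     * dampen the current parameters by (1 - DAMPENING), and pass the expanded
     * Y to the loss function.
     */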
    private void prepareNextRound(Matrix X, Matrix Y) {
        final int nfeatures = X.getNumRows();
        final int nusers = X.getNumColumns();
        final int ntasks = Y.getNumColumns();
        // int ninstances = Y.getNumRows(); // Assume 1 instance!

        // only initialise when the current parameters are null
        if (this.w == null) {
            initParams(X, Y, nfeatures, nusers, ntasks); // Number of words, users and tasks
        }

        final Double dampening = this.params.getTyped(BilinearLearnerParameters.DAMPENING);
        final double weighting = 1.0 - dampening;

        logger.debug("... dampening w, u and bias by: " + weighting);

        // Adjust for weighting
        this.w.scaleEquals(weighting);
        this.u.scaleEquals(weighting);
        if (this.biasMode) {
            this.bias.scaleEquals(weighting);
        }
        // First expand Y s.t. blocks of rows contain the task values for each row of Y.
        // This means Yexp is (n * t) x t
        final SparseMatrix Yexp = expandY(Y);
        loss.setY(Yexp);
    }

    protected Matrix updateBias(Matrix xt, Matrix nu, Matrix nw, double biasLossWeight) {
        final Matrix newut = nu.transpose();
        final Matrix utxt = CFMatrixUtils.fastdot(newut, xt);
        final Matrix utxtw = CFMatrixUtils.fastdot(utxt, nw);
        final Matrix mult = utxtw.plus(this.bias);
        // We must set the bias to null: it is already incorporated in mult
        loss.setBias(null);
        loss.setX(diagX);
        // Calculate the gradient of the bias (don't regularise)
        final Matrix biasGrad = loss.gradient(mult);
        Matrix newbias = null;
        for (int i = 0; i < 1000; i++) {
            logger.debug("... Line searching etab = " + biasLossWeight);
            newbias = this.bias.clone();
            final Matrix scaledGradB = biasGrad.scale(1. / biasLossWeight);
            newbias = CFMatrixUtils.fastminus(newbias, scaledGradB);

            if (loss.test_backtrack(this.bias, biasGrad, newbias, biasLossWeight))
                break;
            biasLossWeight *= eta_gamma;
        }
        // final Matrix newbias = this.bias.minus(
        //         CFMatrixUtils.timesInplace(
        //                 biasGrad,
        //                 biasLossWeight
        //         )
        // );
        return newbias;
    }
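    /**
     * A proximal-gradient step on W with U held fixed: take a step against
     * the loss gradient, apply the regulariser's proximal operator, and
     * re-weight the step by eta_gamma (a backtracking line search) until
     * the loss accepts the step.
     */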
    protected Matrix updateW(Matrix xt, double wLossWeighted, double weightedLambda) {
        // Dprime is tasks x nwords
        Matrix Dprime = null;
        final Matrix ut = this.u.transpose();
        if (this.nodataseen) {
            this.nodataseen = false;
            // before any data is seen, a surrogate U (initialised with the
            // single value 1) stands in for the real U
            final Matrix fakeu = new SparseSingleValueInitStrat(1).init(this.u.getNumColumns(), this.u.getNumRows());
            Dprime = CFMatrixUtils.fastdot(fakeu, xt);
        } else {
            Dprime = CFMatrixUtils.fastdot(ut, xt);
        }

        // ... as is the cost function's X
        if (zStandardise) {
            final Vector rowMean = CFMatrixUtils.rowMean(Dprime);
            CFMatrixUtils.minusEqualsCol(Dprime, rowMean);
        }
        loss.setX(Dprime);
        final Matrix gradW = loss.gradient(this.w);
        logger.debug("Abs w_grad: " + CFMatrixUtils.absSum(gradW));
        Matrix neww = null;
        for (int i = 0; i < 1000; i++) {
            logger.debug("... Line searching etaw = " + wLossWeighted);
            neww = this.w.clone();
            final Matrix scaledGradW = gradW.scale(1. / wLossWeighted);
            neww = CFMatrixUtils.fastminus(neww, scaledGradW);
            neww = regul.prox(neww, weightedLambda / wLossWeighted);
            if (loss.test_backtrack(this.w, gradW, neww, wLossWeighted))
                break;
            wLossWeighted *= eta_gamma;
        }

        return neww;
    }

    protected Matrix updateU(Matrix xtrows, Matrix neww, double uLossWeight, double uWeightedLambda) {
        // Vprime is nusers x tasks
        final Matrix Vprime = CFMatrixUtils.fastdot(xtrows, neww);
        // ... so the loss function's X is (tasks x nusers)
        final Matrix Vt = CFMatrixUtils.asSparseRow(Vprime.transpose());
        if (zStandardise) {
            final Vector rowMean = CFMatrixUtils.rowMean(Vt);
            CFMatrixUtils.minusEqualsCol(Vt, rowMean);
        }
        loss.setX(Vt);
        final Matrix gradU = loss.gradient(this.u);
        logger.debug("Abs u_grad: " + CFMatrixUtils.absSum(gradU));
        // CFMatrixUtils.timesInplace(gradU, uLossWeight);
        // newu = regul.prox(newu, uWeightedLambda);
        Matrix newu = null;
        for (int i = 0; i < 1000; i++) {
            logger.debug("... Line searching etau = " + uLossWeight);
            newu = this.u.clone();
            final Matrix scaledGradU = gradU.scale(1. / uLossWeight);
            newu = CFMatrixUtils.fastminus(newu, scaledGradU);
            newu = regul.prox(newu, uWeightedLambda / uLossWeight);
            if (loss.test_backtrack(this.u, gradU, newu, uLossWeight))
                break;
            uLossWeight *= eta_gamma;
        }

        return newu;
    }

    private double lambdat(int iter, double lambda) {
        return lambda / iter;
    }

    /**
     * Given a flat value matrix, makes a diagonal sparse matrix containing the
     * values as the diagonal.
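     * <p>
     * For illustration (a hypothetical 1 x 3 input):
     *
     * <pre>
     * Y = [y1 y2 y3]   becomes   [  y1 NaN NaN ]
     *                            [ NaN  y2 NaN ]
     *                            [ NaN NaN  y3 ]
     * </pre>
     *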
     * @param Y the flat (1 x tasks) value matrix
     * @return the diagonalised Y
     */
    public static SparseMatrix expandY(Matrix Y) {
        final int ntasks = Y.getNumColumns();
        final SparseMatrix Yexp = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(ntasks, ntasks);
        for (int touter = 0; touter < ntasks; touter++) {
            for (int tinner = 0; tinner < ntasks; tinner++) {
                if (tinner == touter) {
                    Yexp.setElement(touter, tinner, Y.getElement(0, tinner));
                } else {
                    Yexp.setElement(touter, tinner, Double.NaN);
                }
            }
        }
        return Yexp;
    }

    protected double etat(int iter, double eta0) {
        final Integer etaSteps = this.params.getTyped(BilinearLearnerParameters.ETASTEPS);
        final double sqrtCeil = Math.sqrt(Math.ceil(iter / (double) etaSteps));
        return eta(eta0) / sqrtCeil;
    }

    private double eta(double eta0) {
        return eta0;
    }

    /**
     * @return the current parameters
     */
    public BilinearLearnerParameters getParams() {
        return this.params;
    }

    /**
     * @return the current user matrix
     */
    public Matrix getU() {
        return this.u;
    }

    /**
     * @return the current word matrix
     */
    public Matrix getW() {
        return this.w;
    }

    /**
     * @return the current bias (null if {@link BilinearLearnerParameters#BIAS} is false)
     */
    public Matrix getBias() {
        if (this.biasMode)
            return this.bias;
        else
            return null;
    }

    /**
     * Expand the U parameters matrix by adding a set of rows.
     * If U is currently unset, this function does nothing (assuming U will be
     * initialised in the first round).
     * The new U parameters are initialised using
     * {@link BilinearLearnerParameters#EXPANDEDUINITSTRAT}.
     *
     * @param newUsers the number of new users to add
     */
    public void addU(int newUsers) {
        if (this.u == null)
            return; // If u has not been initialised, it will be on the first process
        final InitStrategy ustrat = this.getInitStrat(BilinearLearnerParameters.EXPANDEDUINITSTRAT, null, null);
        final Matrix newU = ustrat.init(newUsers, this.u.getNumColumns());
        this.u = CFMatrixUtils.vstack(this.u, newU);
    }

    /**
     * Expand the W parameters matrix by adding a set of rows.
     * If W is currently unset, this function does nothing (assuming W will be
     * initialised in the first round).
     * The new W parameters are initialised using
     * {@link BilinearLearnerParameters#EXPANDEDWINITSTRAT}.
     *
     * @param newWords the number of new words to add
     */
    public void addW(int newWords) {
        if (this.w == null)
            return; // If w has not been initialised, it will be on the first process
        final InitStrategy wstrat = this.getInitStrat(BilinearLearnerParameters.EXPANDEDWINITSTRAT, null, null);
        final Matrix newW = wstrat.init(newWords, this.w.getNumColumns());
        this.w = CFMatrixUtils.vstack(this.w, newW);
    }

    @Override
    public BilinearSparseOnlineLearner clone() {
        final BilinearSparseOnlineLearner ret = new BilinearSparseOnlineLearner(this.getParams());
        ret.u = this.u.clone();
        ret.w = this.w.clone();
        if (this.biasMode) {
            ret.bias = this.bias.clone();
        }
        return ret;
    }

    /**
     * @param newu set the model's U
     */
    public void setU(Matrix newu) {
        this.u = newu;
    }

    /**
     * @param neww set the model's W
     */
    public void setW(Matrix neww) {
        this.w = neww;
    }

    @Override
    public void readBinary(DataInput in) throws IOException {
        final int nwords = in.readInt();
        final int nusers = in.readInt();
        final int ntasks = in.readInt();

        this.w = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(nwords, ntasks);
        for (int t = 0; t < ntasks; t++) {
            for (int r = 0; r < nwords; r++) {
                final double readDouble = in.readDouble();
                if (readDouble != 0) {
                    this.w.setElement(r, t, readDouble);
                }
            }
        }

        this.u = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(nusers, ntasks);
        for (int t = 0; t < ntasks; t++) {
            for (int r = 0; r < nusers; r++) {
                final double readDouble = in.readDouble();
                if (readDouble != 0) {
                    this.u.setElement(r, t, readDouble);
                }
            }
        }

        this.bias = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(ntasks, ntasks);
        for (int t1 = 0; t1 < ntasks; t1++) {
            for (int t2 = 0; t2 < ntasks; t2++) {
                final double readDouble = in.readDouble();
                if (readDouble != 0) {
                    this.bias.setElement(t1, t2, readDouble);
                }
            }
        }
    }

    @Override
    public byte[] binaryHeader() {
        return "".getBytes();
    }

    @Override
    public void writeBinary(DataOutput out) throws IOException {
        out.writeInt(w.getNumRows());
        out.writeInt(u.getNumRows());
        out.writeInt(u.getNumColumns());
        final double[] wdata = CFMatrixUtils.getData(w);
        for (int i = 0; i < wdata.length; i++) {
            out.writeDouble(wdata[i]);
        }
        final double[] udata = CFMatrixUtils.getData(u);
        for (int i = 0; i < udata.length; i++) {
            out.writeDouble(udata[i]);
        }
        final double[] biasdata = CFMatrixUtils.getData(bias);
        for (int i = 0; i < biasdata.length; i++) {
            out.writeDouble(biasdata[i]);
        }
    }
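    /**
     * Predict task values for x by computing the diagonal of
     * U^T . x^T . W (plus the bias when enabled), returned as a
     * (1 x tasks) row matrix.
     */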
    @Override
    public Matrix predict(Matrix x) {
        final Matrix mult = this.u.transpose().times(x.transpose()).times(this.w);
        if (this.biasMode)
            mult.plusEquals(this.bias);
        final Vector ydiag = CFMatrixUtils.diag(mult);
        final Matrix createIdentity = SparseMatrixFactoryMTJ.INSTANCE.createIdentity(1, ydiag.getDimensionality());
        createIdentity.setRow(0, ydiag);
        return createIdentity;
    }
}