/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   *  Redistributions of source code must retain the above copyright notice,
 *      this list of conditions and the following disclaimer.
 *
 *   *  Redistributions in binary form must reproduce the above copyright notice,
 *      this list of conditions and the following disclaimer in the documentation
 *      and/or other materials provided with the distribution.
 *
 *   *  Neither the name of the University of Southampton nor the names of its
 *      contributors may be used to endorse or promote products derived from this
 *      software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.linear.learner;

import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.mtj.AbstractSparseMatrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrixFactoryMTJ;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import org.openimaj.io.ReadWriteableBinary;
import org.openimaj.math.matrix.CFMatrixUtils;
import org.openimaj.ml.linear.learner.init.ContextAwareInitStrategy;
import org.openimaj.ml.linear.learner.init.InitStrategy;
import org.openimaj.ml.linear.learner.init.SparseSingleValueInitStrat;
import org.openimaj.ml.linear.learner.loss.LossFunction;
import org.openimaj.ml.linear.learner.loss.MatLossFunction;
import org.openimaj.ml.linear.learner.regul.Regulariser;

/**
 * An implementation of stochastic gradient descent with proximal parameter
 * adjustment (for regularised parameters).
 *
 * Data is dealt with sequentially using a one-pass implementation of the
 * online proximal algorithm described in chapters 9 and 10 of:
 * The Geometry of Constrained Structured Prediction: Applications to Inference and
 * Learning of Natural Language Syntax, PhD thesis, Andre T. Martins
 *
 * The implementation does the following:
 *      - When an X,Y pair is received:
 *              - Update the currently held batch
 *              - If the batch is full:
 *                      - While there is a great deal of change in U and W:
 *                              - Calculate the gradient of W holding U fixed
 *                              - Proximal update of W
 *                              - Calculate the gradient of U holding W fixed
 *                              - Proximal update of U
 *                              - Calculate the gradient of the bias holding U and W fixed
 *                      - Flush the batch
 *              - Return the current U and W (the same as last time if the batch isn't full yet)
 *
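 * A minimal usage sketch (not part of the original source; the sizes and the
 * {@code nWords}/{@code nUsers}/{@code nTasks} names are illustrative assumptions):
 *
 * <pre>
 * {@code
 * BilinearSparseOnlineLearner learner = new BilinearSparseOnlineLearner();
 * Matrix X = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(nWords, nUsers); // features x users
 * Matrix Y = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(1, nTasks);      // one instance per round
 * // ... fill X and Y ...
 * learner.process(X, Y);            // one online round
 * Matrix yhat = learner.predict(X); // 1 x tasks row of predictions
 * }
 * </pre>
 *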
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 *
 */
public class BilinearSparseOnlineLearner implements OnlineLearner<Matrix,Matrix>, ReadWriteableBinary{

        static Logger logger = LogManager.getLogger(BilinearSparseOnlineLearner.class);

        protected BilinearLearnerParameters params;
        protected Matrix w;
        protected Matrix u;
        protected SparseMatrixFactoryMTJ smf = SparseMatrixFactoryMTJ.INSTANCE;
        protected LossFunction loss;
        protected Regulariser regul;
        protected Double lambda_w,lambda_u;
        protected Boolean biasMode;
        protected Matrix bias;
        protected Matrix diagX;
        protected Double eta0_u;
        protected Double eta0_w;

        private Boolean forceSparsity;

        private Boolean zStandardise;

        private boolean nodataseen;

        private double eta_gamma;

        private double biasEta0;

        /**
         * The default parameters. These won't work with your dataset, I promise.
         */
        public BilinearSparseOnlineLearner() {
                this(new BilinearLearnerParameters());
        }
        /**
         * @param params the parameters used by this learner
         */
        public BilinearSparseOnlineLearner(BilinearLearnerParameters params) {
                this.params = params;
                reinitParams();
        }

        /**
         * Must be called if any parameters are changed.
         */
        public void reinitParams() {
                this.loss = this.params.getTyped(BilinearLearnerParameters.LOSS);
                this.regul = this.params.getTyped(BilinearLearnerParameters.REGUL);
                this.lambda_w = this.params.getTyped(BilinearLearnerParameters.LAMBDA_W);
                this.lambda_u = this.params.getTyped(BilinearLearnerParameters.LAMBDA_U);
                this.biasMode = this.params.getTyped(BilinearLearnerParameters.BIAS);
                this.eta0_u = this.params.getTyped(BilinearLearnerParameters.ETA0_U);
                this.eta0_w = this.params.getTyped(BilinearLearnerParameters.ETA0_W);
                this.biasEta0 = this.params.getTyped(BilinearLearnerParameters.ETA0_BIAS);
                this.eta_gamma = params.getTyped(BilinearLearnerParameters.ETA_GAMMA);
                this.forceSparsity = this.params.getTyped(BilinearLearnerParameters.FORCE_SPARCITY);
                this.zStandardise = this.params.getTyped(BilinearLearnerParameters.Z_STANDARDISE);
                if(!this.loss.isMatrixLoss())
                        this.loss = new MatLossFunction(this.loss);
                this.nodataseen = true;
        }
        private void initParams(Matrix x, Matrix y, int xrows, int xcols, int ycols) {
                final InitStrategy wstrat = getInitStrat(BilinearLearnerParameters.WINITSTRAT,x,y);
                final InitStrategy ustrat = getInitStrat(BilinearLearnerParameters.UINITSTRAT,x,y);
                this.w = wstrat.init(xrows, ycols);
                this.u = ustrat.init(xcols, ycols);
                if(this.forceSparsity)
                {
                        this.u = CFMatrixUtils.asSparseColumn(this.u);
                        this.w = CFMatrixUtils.asSparseColumn(this.w);
                }

                this.bias = smf.createMatrix(ycols,ycols);
                if(this.biasMode){
                        final InitStrategy bstrat = getInitStrat(BilinearLearnerParameters.BIASINITSTRAT,x,y);
                        this.bias = bstrat.init(ycols, ycols);
                        this.diagX = smf.createIdentity(ycols, ycols);
                }
        }

        private InitStrategy getInitStrat(String initstrat, Matrix x, Matrix y) {
                final InitStrategy strat = this.params.getTyped(initstrat);
                if(strat instanceof ContextAwareInitStrategy){
                        final ContextAwareInitStrategy<Matrix, Matrix> cwStrat = this.params.getTyped(initstrat);
                        cwStrat.setLearner(this);
                        cwStrat.setContext(x, y);
                        return cwStrat;
                }
                return strat;
        }
        @Override
        public void process(Matrix X, Matrix Y){
                prepareNextRound(X, Y);
                int iter = 0;
                Matrix xt = X.transpose();
                Matrix xtrows = xt;
                if(xt instanceof AbstractSparseMatrix){
                        xtrows = CFMatrixUtils.asSparseRow(xt);
                }
                while(true) {
                        // We need to set the bias here because it is used in the loss calculation of U and W
                        if(this.biasMode) loss.setBias(this.bias);
                        iter += 1;

                        // Perform the bilinear operation
                        final Matrix neww = updateW(xt,eta0_w, lambda_w);
                        final Matrix newu = updateU(xtrows,neww,eta0_u, lambda_u);
                        Matrix newbias = null;
                        if(this.biasMode){
                                newbias = updateBias(xt, newu, neww, biasEta0);
                        }

                        // Check whether the bilinear steps can stop by measuring how much everything has changed proportionally

                        double ratioB = 0;
                        double totalbias = 0;

                        final double sumchangew = CFMatrixUtils.absSum(neww.minus(this.w));
                        final double totalw = CFMatrixUtils.absSum(this.w);

                        final double sumchangeu = CFMatrixUtils.absSum(newu.minus(this.u));
                        final double totalu = CFMatrixUtils.absSum(this.u);

                        double ratioU = 0;
                        if(totalu!=0) ratioU = sumchangeu/totalu;
                        double ratioW = 0;
                        if(totalw!=0) ratioW = sumchangew/totalw;
                        double ratio = ratioU + ratioW;
                        if(this.biasMode){
                                final double sumchangebias = CFMatrixUtils.absSum(newbias.minus(this.bias));
                                totalbias = CFMatrixUtils.absSum(this.bias);
                                if(totalbias!=0) ratioB = (sumchangebias/totalbias);
                                ratio += ratioB;
                                ratio /= 3;
                        } else {
                                ratio /= 2;
                        }

                        /**
                         * This is not simply a matter of type: the conversion also
                         * removes the explicit 0 values of the sparse matrix. Very important.
                         */
                        if(this.forceSparsity)
                        {
                                this.u = CFMatrixUtils.asSparseColumn(newu);
                                this.w = CFMatrixUtils.asSparseColumn(neww);
                        }
                        else{
                                this.w = neww;
                                this.u = newu;
                        }

                        if(this.biasMode){
                                this.bias = newbias;
                        }

                        final Double biconvextol = this.params.getTyped("biconvex_tol");
                        final Integer maxiter = this.params.getTyped("biconvex_maxiter");
                        if(iter%3 == 0){
                                logger.debug(String.format("Iter: %d. Last Ratio: %2.3f",iter,ratio));
                                logger.debug("W row sparsity: " + CFMatrixUtils.rowSparsity(w));
                                logger.debug("U row sparsity: " + CFMatrixUtils.rowSparsity(u));
                                logger.debug("Total U magnitude: " + totalu);
                                logger.debug("Total W magnitude: " + totalw);
                                logger.debug("Total Bias: " + totalbias);
                        }
                        if(biconvextol < 0 || ratio < biconvextol || iter >= maxiter) {
                                logger.debug("Tolerance reached after iteration: " + iter);
                                logger.debug("W row sparsity: " + CFMatrixUtils.rowSparsity(w));
                                logger.debug("U row sparsity: " + CFMatrixUtils.rowSparsity(u));
                                logger.debug("Total U magnitude: " + totalu);
                                logger.debug("Total W magnitude: " + totalw);
                                logger.debug("Total Bias: " + totalbias);
                                break;
                        }
                }
        }
        private void prepareNextRound(Matrix X, Matrix Y) {
                final int nfeatures = X.getNumRows();
                final int nusers = X.getNumColumns();
                final int ntasks = Y.getNumColumns();
//              int ninstances = Y.getNumRows(); // Assume 1 instance!

                // Only initialise when the current parameters are null
                if (this.w == null){
                        initParams(X,Y,nfeatures, nusers, ntasks); // Number of words, users and tasks
                }

                final Double dampening = this.params.getTyped(BilinearLearnerParameters.DAMPENING);
                final double weighting = 1.0 - dampening;

                logger.debug("... dampening w, u and bias by: " + weighting);

                // Adjust for weighting
                this.w.scaleEquals(weighting);
                this.u.scaleEquals(weighting);
                if(this.biasMode){
                        this.bias.scaleEquals(weighting);
                }
                // First expand Y s.t. blocks of rows contain the task values for each row of Y.
                // This means Yexp is ((n * t) x t)
                final SparseMatrix Yexp = expandY(Y);
                loss.setY(Yexp);
        }

        protected Matrix updateBias(Matrix xt, Matrix nu, Matrix nw, double biasLossWeight) {
                Matrix newut = nu.transpose();
                Matrix utxt = CFMatrixUtils.fastdot(newut,xt);
                Matrix utxtw = CFMatrixUtils.fastdot(utxt,nw);
                final Matrix mult = utxtw.plus(this.bias);
                // We must set bias to null!
                loss.setBias(null);
                loss.setX(diagX);
                // Calculate gradient of bias (don't regularise)
                final Matrix biasGrad = loss.gradient(mult);
                Matrix newbias = null;
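                // Backtracking line search: step by biasGrad/biasLossWeight and, if
                // loss.test_backtrack rejects the step, scale biasLossWeight by
                // eta_gamma and retry (up to 1000 times).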
                for (int i = 0; i < 1000; i++) {
                        logger.debug("... Line searching etab = " + biasLossWeight);
                        newbias = this.bias.clone();
                        Matrix scaledGradB = biasGrad.scale(1./biasLossWeight);
                        newbias = CFMatrixUtils.fastminus(newbias,scaledGradB);

                        if(loss.test_backtrack(this.bias, biasGrad, newbias, biasLossWeight))
                                break;
                        biasLossWeight *= eta_gamma;
                }
//              final Matrix newbias = this.bias.minus(
//                              CFMatrixUtils.timesInplace(
//                                              biasGrad,
//                                              biasLossWeight
//                              )
//              );
                return newbias;
        }
        protected Matrix updateW(Matrix xt, double wLossWeighted, double weightedLambda) {
                // Dprime is tasks x nwords
                Matrix Dprime = null;
                Matrix ut = this.u.transpose();
                if(this.nodataseen){
                        this.nodataseen = false;
                        Matrix fakeu = new SparseSingleValueInitStrat(1).init(this.u.getNumColumns(), this.u.getNumRows());
                        Dprime = CFMatrixUtils.fastdot(fakeu,xt);
                } else {
                        Dprime = CFMatrixUtils.fastdot(ut, xt);
                }

                // ... as is the cost function's X
                if(zStandardise){
                        Vector rowMean = CFMatrixUtils.rowMean(Dprime);
                        CFMatrixUtils.minusEqualsCol(Dprime,rowMean);
                }
                loss.setX(Dprime);
                final Matrix gradW = loss.gradient(this.w);
                logger.debug("Abs w_grad: " + CFMatrixUtils.absSum(gradW));
                Matrix neww = null;
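                // Proximal gradient step with backtracking: step by gradW/wLossWeighted,
                // apply the regulariser's proximal operator, and if loss.test_backtrack
                // rejects the step, scale wLossWeighted by eta_gamma and retry.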
                for (int i = 0; i < 1000; i++) {
                        logger.debug("... Line searching etaw = " + wLossWeighted);
                        neww = this.w.clone();
                        Matrix scaledGradW = gradW.scale(1./wLossWeighted);
                        neww = CFMatrixUtils.fastminus(neww,scaledGradW);
                        neww = regul.prox(neww, weightedLambda/wLossWeighted);
                        if(loss.test_backtrack(this.w, gradW, neww, wLossWeighted))
                                break;
                        wLossWeighted *= eta_gamma;
                }

                return neww;
        }
        protected Matrix updateU(Matrix xtrows, Matrix neww, double uLossWeight, double uWeightedLambda) {
                // Vprime is nusers x tasks
                final Matrix Vprime = CFMatrixUtils.fastdot(xtrows,neww);
                // ... so the loss function's X is (tasks x nusers)
                Matrix Vt = CFMatrixUtils.asSparseRow(Vprime.transpose());
                if(zStandardise){
                        Vector rowMean = CFMatrixUtils.rowMean(Vt);
                        CFMatrixUtils.minusEqualsCol(Vt,rowMean);
                }
                loss.setX(Vt);
                final Matrix gradU = loss.gradient(this.u);
                logger.debug("Abs u_grad: " + CFMatrixUtils.absSum(gradU));
//              CFMatrixUtils.timesInplace(gradU,uLossWeight);
//              newu = regul.prox(newu, uWeightedLambda);
                Matrix newu = null;
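                // Same backtracking proximal step as in updateW, applied to U.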
                for (int i = 0; i < 1000; i++) {
                        logger.debug("... Line searching etau = " + uLossWeight);
                        newu = this.u.clone();
                        Matrix scaledGradU = gradU.scale(1./uLossWeight);
                        newu = CFMatrixUtils.fastminus(newu,scaledGradU);
                        newu = regul.prox(newu, uWeightedLambda/uLossWeight);
                        if(loss.test_backtrack(this.u, gradU, newu, uLossWeight))
                                break;
                        uLossWeight *= eta_gamma;
                }

                return newu;
        }
        private double lambdat(int iter, double lambda) {
                return lambda/iter;
        }
        /**
         * Given a flat (single-row) value matrix, makes a diagonal sparse matrix containing the values as the diagonal.
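         *
         * For example (an illustrative sketch; the values are assumed): a 1 x 2 matrix
         * Y = [0.5, 1.0] expands to
         * <pre>
         * Yexp = [ 0.5  NaN ]
         *        [ NaN  1.0 ]
         * </pre>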
         * @param Y
         * @return the diagonalised Y
         */
        public static SparseMatrix expandY(Matrix Y) {
                final int ntasks = Y.getNumColumns();
                final SparseMatrix Yexp = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(ntasks, ntasks);
                for (int touter = 0; touter < ntasks; touter++) {
                        for (int tinner = 0; tinner < ntasks; tinner++) {
                                if(tinner == touter){
                                        Yexp.setElement(touter, tinner, Y.getElement(0, tinner));
                                }
                                else{
                                        Yexp.setElement(touter, tinner, Double.NaN);
                                }
                        }
                }
                return Yexp;
        }

        protected double etat(int iter,double eta0) {
                final Integer etaSteps = this.params.getTyped(BilinearLearnerParameters.ETASTEPS);
                final double sqrtCeil = Math.sqrt(Math.ceil(iter/(double)etaSteps));
                return eta(eta0) / sqrtCeil;
        }
        private double eta(double eta0) {
                return eta0;
        }

        /**
         * @return the current parameters
         */
        public BilinearLearnerParameters getParams() {
                return this.params;
        }

        /**
         * @return the current user matrix
         */
        public Matrix getU(){
                return this.u;
        }

        /**
         * @return the current word matrix
         */
        public Matrix getW(){
                return this.w;
        }
        /**
         * @return the current bias (null if {@link BilinearLearnerParameters#BIAS} is false)
         */
        public Matrix getBias() {
                if(this.biasMode)
                        return this.bias;
                else
                        return null;
        }

        /**
         * Expand the U parameters matrix by adding a set of rows.
         * If U is currently unset, this function does nothing (assuming U will be initialised in the first round).
         * The new U parameters are initialised using {@link BilinearLearnerParameters#EXPANDEDUINITSTRAT}.
         * @param newUsers the number of new users to add
         */
        public void addU(int newUsers) {
                if(this.u == null) return; // If u has not been initialised, it will be on the first process
                final InitStrategy ustrat = this.getInitStrat(BilinearLearnerParameters.EXPANDEDUINITSTRAT,null,null);
                final Matrix newU = ustrat.init(newUsers, this.u.getNumColumns());
                this.u = CFMatrixUtils.vstack(this.u,newU);
        }

        /**
         * Expand the W parameters matrix by adding a set of rows.
         * If W is currently unset, this function does nothing (assuming W will be initialised in the first round).
         * The new W parameters are initialised using {@link BilinearLearnerParameters#EXPANDEDWINITSTRAT}.
         * @param newWords the number of new words to add
         */
        public void addW(int newWords) {
                if(this.w == null) return; // If w has not been initialised, it will be on the first process
                final InitStrategy wstrat = this.getInitStrat(BilinearLearnerParameters.EXPANDEDWINITSTRAT,null,null);
                final Matrix newW = wstrat.init(newWords, this.w.getNumColumns());
                this.w = CFMatrixUtils.vstack(this.w,newW);
        }

        @Override
        public BilinearSparseOnlineLearner clone(){
                final BilinearSparseOnlineLearner ret = new BilinearSparseOnlineLearner(this.getParams());
                ret.u = this.u.clone();
                ret.w = this.w.clone();
                if(this.biasMode){
                        ret.bias = this.bias.clone();
                }
                return ret;
        }
        /**
         * @param newu set the model's U
         */
        public void setU(Matrix newu) {
                this.u = newu;
        }

        /**
         * @param neww set the model's W
         */
        public void setW(Matrix neww) {
                this.w = neww;
        }
        @Override
        public void readBinary(DataInput in) throws IOException {
                final int nwords = in.readInt();
                final int nusers = in.readInt();
                final int ntasks = in.readInt();

                this.w = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(nwords, ntasks);
                for (int t = 0; t < ntasks; t++) {
                        for (int r = 0; r < nwords; r++) {
                                final double readDouble = in.readDouble();
                                if(readDouble != 0){
                                        this.w.setElement(r, t, readDouble);
                                }
                        }
                }

                this.u = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(nusers, ntasks);
                for (int t = 0; t < ntasks; t++) {
                        for (int r = 0; r < nusers; r++) {
                                final double readDouble = in.readDouble();
                                if(readDouble != 0){
                                        this.u.setElement(r, t, readDouble);
                                }
                        }
                }

                this.bias = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(ntasks, ntasks);
                for (int t1 = 0; t1 < ntasks; t1++) {
                        for (int t2 = 0; t2 < ntasks; t2++) {
                                final double readDouble = in.readDouble();
                                if(readDouble != 0){
                                        this.bias.setElement(t1, t2, readDouble);
                                }
                        }
                }
        }
        @Override
        public byte[] binaryHeader() {
                return "".getBytes();
        }
        @Override
        public void writeBinary(DataOutput out) throws IOException {
                out.writeInt(w.getNumRows());
                out.writeInt(u.getNumRows());
                out.writeInt(u.getNumColumns());
                final double[] wdata = CFMatrixUtils.getData(w);
                for (int i = 0; i < wdata.length; i++) {
                        out.writeDouble(wdata[i]);
                }
                final double[] udata = CFMatrixUtils.getData(u);
                for (int i = 0; i < udata.length; i++) {
                        out.writeDouble(udata[i]);
                }
                final double[] biasdata = CFMatrixUtils.getData(bias);
                for (int i = 0; i < biasdata.length; i++) {
                        out.writeDouble(biasdata[i]);
                }
        }
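        /**
         * Predict the response for each task: computes the diagonal of
         * U^T X^T W (plus the bias, if enabled) and returns it as a 1 x tasks row matrix.
         */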
        @Override
        public Matrix predict(Matrix x) {
                final Matrix mult = this.u.transpose().times(x.transpose()).times(this.w);
                if(this.biasMode)mult.plusEquals(this.bias);
                final Vector ydiag = CFMatrixUtils.diag(mult);
                final Matrix result = SparseMatrixFactoryMTJ.INSTANCE.createIdentity(1, ydiag.getDimensionality());
                result.setRow(0, ydiag);
                return result;
        }
}