/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   *  Redistributions of source code must retain the above copyright notice,
 *      this list of conditions and the following disclaimer.
 *
 *   *  Redistributions in binary form must reproduce the above copyright notice,
 *      this list of conditions and the following disclaimer in the documentation
 *      and/or other materials provided with the distribution.
 *
 *   *  Neither the name of the University of Southampton nor the names of its
 *      contributors may be used to endorse or promote products derived from this
 *      software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.linear.learner.matlib;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import org.openimaj.io.ReadWriteableBinary;
import org.openimaj.math.matrix.DiagonalMatrix;
import org.openimaj.math.matrix.MatlibMatrixUtils;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.OnlineLearner;
import org.openimaj.ml.linear.learner.matlib.init.InitStrategy;
import org.openimaj.ml.linear.learner.matlib.init.SparseSingleValueInitStrat;
import org.openimaj.ml.linear.learner.matlib.loss.LossFunction;
import org.openimaj.ml.linear.learner.matlib.loss.MatLossFunction;
import org.openimaj.ml.linear.learner.matlib.regul.Regulariser;

import ch.akuhn.matrix.Matrix;
import ch.akuhn.matrix.SparseMatrix;

/**
 * An implementation of stochastic gradient descent with proximal parameter
 * adjustment (for regularised parameters).
 *
 * Data is dealt with sequentially using a one-pass implementation of the
 * online proximal algorithm described in chapters 9 and 10 of:
 * The Geometry of Constrained Structured Prediction: Applications to Inference
 * and Learning of Natural Language Syntax, PhD thesis, Andre T. Martins
 *
 * The implementation does the following:
 *      - When an X,Y pair is received:
 *              - Update the currently held batch
 *              - If the batch is full:
 *                      - While there is significant change in U and W:
 *                              - Calculate the gradient of W holding U fixed
 *                              - Proximal update of W
 *                              - Calculate the gradient of U holding W fixed
 *                              - Proximal update of U
 *                              - Calculate the gradient of the bias holding U and W fixed
 *                      - Flush the batch
 *              - Return the current U and W (same as last time if the batch isn't full yet)
 *
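 * <p>
 * A minimal usage sketch (the bare-bones parameter setup here is an
 * illustrative assumption; a real dataset needs a properly configured
 * {@link BilinearLearnerParameters} instance, and X, Y, Xnew are placeholders
 * for your data):
 * <pre>
 * BilinearLearnerParameters params = new BilinearLearnerParameters();
 * MatlibBilinearSparseOnlineLearner learner = new MatlibBilinearSparseOnlineLearner(params);
 * learner.process(X, Y); // X: (nfeatures x nusers), Y: (1 x ntasks)
 * Matrix scores = learner.predict(Xnew); // one score per task
 * </pre>
 *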
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 *
 */
public class MatlibBilinearSparseOnlineLearner implements OnlineLearner<Matrix,Matrix>, ReadWriteableBinary {

        static Logger logger = LogManager.getLogger(MatlibBilinearSparseOnlineLearner.class);

        protected BilinearLearnerParameters params;
        protected Matrix w;
        protected Matrix u;
        protected LossFunction loss;
        protected Regulariser regul;
        protected Double lambda_w, lambda_u;
        protected Boolean biasMode;
        protected Matrix bias;
        protected Matrix diagX;
        protected Double eta0_u;
        protected Double eta0_w;

        private Boolean forceSparcity;

        private Boolean zStandardise;

        private boolean nodataseen;
        /**
         * The default parameters. These won't work with your dataset, I promise.
         */
        public MatlibBilinearSparseOnlineLearner() {
                this(new BilinearLearnerParameters());
        }
        /**
         * @param params the parameters used by this learner
         */
        public MatlibBilinearSparseOnlineLearner(BilinearLearnerParameters params) {
                this.params = params;
                reinitParams();
        }
        /**
         * Must be called if any parameters are changed.
         */
        public void reinitParams() {
                this.loss = this.params.getTyped(BilinearLearnerParameters.LOSS);
                this.regul = this.params.getTyped(BilinearLearnerParameters.REGUL);
                this.lambda_w = this.params.getTyped(BilinearLearnerParameters.LAMBDA_W);
                this.lambda_u = this.params.getTyped(BilinearLearnerParameters.LAMBDA_U);
                this.biasMode = this.params.getTyped(BilinearLearnerParameters.BIAS);
                this.eta0_u = this.params.getTyped(BilinearLearnerParameters.ETA0_U);
                this.eta0_w = this.params.getTyped(BilinearLearnerParameters.ETA0_W);
                this.forceSparcity = this.params.getTyped(BilinearLearnerParameters.FORCE_SPARCITY);
                this.zStandardise = this.params.getTyped(BilinearLearnerParameters.Z_STANDARDISE);
                if(!this.loss.isMatrixLoss())
                        this.loss = new MatLossFunction(this.loss);
                this.nodataseen = true;
        }
        private void initParams(Matrix x, Matrix y, int xrows, int xcols, int ycols) {
                final InitStrategy wstrat = getInitStrat(BilinearLearnerParameters.WINITSTRAT, x, y);
                final InitStrategy ustrat = getInitStrat(BilinearLearnerParameters.UINITSTRAT, x, y);
                this.w = wstrat.init(xrows, ycols);
                this.u = ustrat.init(xcols, ycols);

                this.bias = SparseMatrix.sparse(ycols, ycols);
                if(this.biasMode){
                        final InitStrategy bstrat = getInitStrat(BilinearLearnerParameters.BIASINITSTRAT, x, y);
                        this.bias = bstrat.init(ycols, ycols);
                        this.diagX = new DiagonalMatrix(ycols, 1);
                }
        }

        private InitStrategy getInitStrat(String initstrat, Matrix x, Matrix y) {
                final InitStrategy strat = this.params.getTyped(initstrat);
                return strat;
        }
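        /**
         * Update the model with a single (X, Y) observation. X is expected to be
         * (nfeatures x nusers) and Y (1 x ntasks). W and U (and the bias, when
         * {@link BilinearLearnerParameters#BIAS} is set) are refined by alternating
         * proximal gradient steps until the relative parameter change drops below
         * the biconvex tolerance or the maximum number of iterations is reached.
         */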
        @Override
        public void process(Matrix X, Matrix Y){
                final int nfeatures = X.rowCount();
                final int nusers = X.columnCount();
                final int ntasks = Y.columnCount();
//              int ninstances = Y.rowCount(); // Assume 1 instance!

                // only initialise the parameters on the first call
                if (this.w == null){
                        initParams(X, Y, nfeatures, nusers, ntasks); // Number of words, users and tasks
                }

                final Double dampening = this.params.getTyped(BilinearLearnerParameters.DAMPENING);
                final double weighting = 1.0 - dampening;

                logger.debug("... dampening w, u and bias by: " + weighting);

                // Adjust for weighting
                MatlibMatrixUtils.scaleInplace(this.w, weighting);
                MatlibMatrixUtils.scaleInplace(this.u, weighting);
                if(this.biasMode){
                        MatlibMatrixUtils.scaleInplace(this.bias, weighting);
                }
                // First expand Y s.t. blocks of rows contain the task values for each row of Y.
                // Since we assume a single instance, Yexp is (ntasks x ntasks).
                final SparseMatrix Yexp = expandY(Y);
                loss.setY(Yexp);
                int iter = 0;
                while(true) {
                        // We need to set the bias here because it is used in the loss calculation of U and W
                        if(this.biasMode) loss.setBias(this.bias);
                        iter += 1;

                        final double uLossWeight = etat(iter, eta0_u);
                        final double wLossWeight = etat(iter, eta0_w);
                        final double weightedLambda_u = lambdat(iter, lambda_u);
                        final double weightedLambda_w = lambdat(iter, lambda_w);
                        // Dprime is (ntasks x nwords)
                        Matrix Dprime = null;
                        if(this.nodataseen){
                                this.nodataseen = false;
                                final Matrix fakeut = new SparseSingleValueInitStrat(1).init(this.u.columnCount(), this.u.rowCount());
                                Dprime = MatlibMatrixUtils.dotProductTranspose(fakeut, X); // i.e. fakeut . X^T
                        } else {
                                Dprime = MatlibMatrixUtils.dotProductTransposeTranspose(u, X); // i.e. u^T . X^T
                        }

                        // ... as is the cost function's X
                        if(zStandardise){
//                              Vector rowMean = CFMatrixUtils.rowMean(Dprime);
//                              CFMatrixUtils.minusEqualsCol(Dprime,rowMean);
                        }
                        loss.setX(Dprime);
                        final Matrix neww = updateW(this.w, wLossWeight, weightedLambda_w);

                        // Vt is (ntasks x nusers)
                        final Matrix Vt = MatlibMatrixUtils.transposeDotProduct(neww, X); // i.e. (X^T . neww)^T
                        // ... so the loss function's X is (ntasks x nusers)
                        loss.setX(Vt);
                        final Matrix newu = updateU(this.u, uLossWeight, weightedLambda_u);

                        final double sumchangew = MatlibMatrixUtils.normF(MatlibMatrixUtils.minus(neww, this.w));
                        final double totalw = MatlibMatrixUtils.normF(this.w);

                        final double sumchangeu = MatlibMatrixUtils.normF(MatlibMatrixUtils.minus(newu, this.u));
                        final double totalu = MatlibMatrixUtils.normF(this.u);

                        double ratioU = 0;
                        if(totalu != 0) ratioU = sumchangeu / totalu;
                        double ratioW = 0;
                        if(totalw != 0) ratioW = sumchangew / totalw;
                        double ratioB = 0;
                        double ratio = ratioU + ratioW;
                        double totalbias = 0;
                        if(this.biasMode){
                                Matrix mult = MatlibMatrixUtils.dotProductTransposeTranspose(newu, X);
                                mult = MatlibMatrixUtils.dotProduct(mult, neww);
                                MatlibMatrixUtils.plusInplace(mult, bias);
                                // The bias is already folded into mult, so it must be cleared
                                // from the loss before the gradient is calculated
                                loss.setBias(null);
                                loss.setX(diagX);
                                // Calculate gradient of bias (don't regularise)
                                final Matrix biasGrad = loss.gradient(mult);
                                final double biasLossWeight = biasEtat(iter);
                                final Matrix newbias = updateBias(biasGrad, biasLossWeight);

                                final double sumchangebias = MatlibMatrixUtils.normF(MatlibMatrixUtils.minus(newbias, bias));
                                totalbias = MatlibMatrixUtils.normF(this.bias);
                                if(totalbias != 0) ratioB = sumchangebias / totalbias;
                                this.bias = newbias;
                                ratio += ratioB;
                                ratio /= 3;
                        }
                        else{
                                ratio /= 2;
                        }

                        // Commit the proximal updates; without this the loop compares
                        // against stale parameters and can never converge
                        this.w = neww;
                        this.u = newu;

                        final Double biconvextol = this.params.getTyped("biconvex_tol");
                        final Integer maxiter = this.params.getTyped("biconvex_maxiter");
                        if(iter % 3 == 0){
                                logger.debug(String.format("Iter: %d. Last Ratio: %2.3f", iter, ratio));
                                logger.debug("W row sparsity: " + MatlibMatrixUtils.sparsity(w));
                                logger.debug("U row sparsity: " + MatlibMatrixUtils.sparsity(u));
                                logger.debug("Total U magnitude: " + totalu);
                                logger.debug("Total W magnitude: " + totalw);
                                logger.debug("Total Bias: " + totalbias);
                        }
                        if(biconvextol < 0 || ratio < biconvextol || iter >= maxiter) {
                                logger.debug("tolerance reached after iteration: " + iter);
                                logger.debug("W row sparsity: " + MatlibMatrixUtils.sparsity(w));
                                logger.debug("U row sparsity: " + MatlibMatrixUtils.sparsity(u));
                                logger.debug("Total U magnitude: " + totalu);
                                logger.debug("Total W magnitude: " + totalw);
                                logger.debug("Total Bias: " + totalbias);
                                break;
                        }
                }
        }

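        /**
         * Take a plain gradient step on the bias. Note that, unlike W and U, the
         * bias is not regularised, so no proximal step is applied.
         *
         * @param biasGrad the gradient of the loss with respect to the bias
         * @param biasLossWeight the step size
         * @return the updated bias
         */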
        protected Matrix updateBias(Matrix biasGrad, double biasLossWeight) {
                final Matrix newbias = MatlibMatrixUtils.minus(
                                this.bias,
                                MatlibMatrixUtils.scaleInplace(
                                                biasGrad,
                                                biasLossWeight
                                )
                );
                return newbias;
        }
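        /**
         * A proximal gradient update of W: take a step against the loss gradient,
         * then apply the regulariser's proximal operator.
         *
         * @param currentW the current W
         * @param wLossWeight the step size
         * @param weightedLambda the regularisation weight for this iteration
         * @return the updated W
         */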
        protected Matrix updateW(Matrix currentW, double wLossWeight, double weightedLambda) {
                final Matrix gradW = loss.gradient(currentW);
                MatlibMatrixUtils.scaleInplace(gradW, wLossWeight);

                Matrix neww = MatlibMatrixUtils.minus(currentW, gradW);
                neww = regul.prox(neww, weightedLambda);
                return neww;
        }
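        /**
         * A proximal gradient update of U: take a step against the loss gradient,
         * then apply the regulariser's proximal operator.
         *
         * @param currentU the current U
         * @param uLossWeight the step size
         * @param uWeightedLambda the regularisation weight for this iteration
         * @return the updated U
         */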
        protected Matrix updateU(Matrix currentU, double uLossWeight, double uWeightedLambda) {
                final Matrix gradU = loss.gradient(currentU);
                MatlibMatrixUtils.scaleInplace(gradU, uLossWeight);
                Matrix newu = MatlibMatrixUtils.minus(currentU, gradU);
                newu = regul.prox(newu, uWeightedLambda);
                return newu;
        }
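        /**
         * Decay the regularisation weight over iterations: lambda_t = lambda / t.
         */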
        private double lambdat(int iter, double lambda) {
                return lambda / iter;
        }
        /**
         * Given a flat value matrix (a single row of task values), construct a
         * sparse matrix with the values on the diagonal and NaN everywhere else.
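         * <pre>
         * // e.g. Y = [y1, y2] becomes:
         * //   [ y1   NaN ]
         * //   [ NaN  y2  ]
         * </pre>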
         * @param Y the flat value matrix
         * @return the diagonalised Y
         */
        public static SparseMatrix expandY(Matrix Y) {
                final int ntasks = Y.columnCount();
                final SparseMatrix Yexp = SparseMatrix.sparse(ntasks, ntasks);
                for (int touter = 0; touter < ntasks; touter++) {
                        for (int tinner = 0; tinner < ntasks; tinner++) {
                                if(tinner == touter){
                                        Yexp.put(touter, tinner, Y.get(0, tinner));
                                }
                                else{
                                        Yexp.put(touter, tinner, Double.NaN);
                                }
                        }
                }
                return Yexp;
        }
        private double biasEtat(int iter){
                final Double biasEta0 = this.params.getTyped(BilinearLearnerParameters.ETA0_BIAS);
                return biasEta0 / Math.sqrt(iter);
        }

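        /**
         * The step-size schedule: eta0 / sqrt(ceil(iter / ETASTEPS)), i.e. the
         * learning rate shrinks once every ETASTEPS iterations.
         */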
        private double etat(int iter, double eta0) {
                final Integer etaSteps = this.params.getTyped(BilinearLearnerParameters.ETASTEPS);
                final double sqrtCeil = Math.sqrt(Math.ceil(iter / (double)etaSteps));
                return eta(eta0) / sqrtCeil;
        }
        private double eta(double eta0) {
                return eta0;
        }

        /**
         * @return the current parameters
         */
        public BilinearLearnerParameters getParams() {
                return this.params;
        }

        /**
         * @return the current user matrix
         */
        public Matrix getU(){
                return this.u;
        }

        /**
         * @return the current word matrix
         */
        public Matrix getW(){
                return this.w;
        }
        /**
         * @return the current bias (null if {@link BilinearLearnerParameters#BIAS} is false)
         */
        public Matrix getBias() {
                if(this.biasMode)
                        return this.bias;
                else
                        return null;
        }

        /**
         * Expand the U parameters matrix by adding a set of rows.
         * If U is currently unset, this function does nothing (assuming U will be
         * initialised in the first round).
         * The new U parameters are initialised using {@link BilinearLearnerParameters#EXPANDEDUINITSTRAT}.
         * @param newUsers the number of new users to add
         */
        public void addU(int newUsers) {
                if(this.u == null) return; // if u has not been initialised, it will be on the first process
                final InitStrategy ustrat = this.getInitStrat(BilinearLearnerParameters.EXPANDEDUINITSTRAT, null, null);
                final Matrix newU = ustrat.init(newUsers, this.u.columnCount());
                this.u = MatlibMatrixUtils.vstack(this.u, newU);
        }

        /**
         * Expand the W parameters matrix by adding a set of rows.
         * If W is currently unset, this function does nothing (assuming W will be
         * initialised in the first round).
         * The new W parameters are initialised using {@link BilinearLearnerParameters#EXPANDEDWINITSTRAT}.
         * @param newWords the number of new words to add
         */
        public void addW(int newWords) {
                if(this.w == null) return; // if w has not been initialised, it will be on the first process
                final InitStrategy wstrat = this.getInitStrat(BilinearLearnerParameters.EXPANDEDWINITSTRAT, null, null);
                final Matrix newW = wstrat.init(newWords, this.w.columnCount());
                this.w = MatlibMatrixUtils.vstack(this.w, newW);
        }

        @Override
        public MatlibBilinearSparseOnlineLearner clone(){
                final MatlibBilinearSparseOnlineLearner ret = new MatlibBilinearSparseOnlineLearner(this.getParams());
                ret.u = MatlibMatrixUtils.copy(this.u);
                ret.w = MatlibMatrixUtils.copy(this.w);
                if(this.biasMode){
                        ret.bias = MatlibMatrixUtils.copy(this.bias);
                }
                return ret;
        }
        /**
         * @param newu set the model's U
         */
        public void setU(Matrix newu) {
                this.u = newu;
        }

        /**
         * @param neww set the model's W
         */
        public void setW(Matrix neww) {
                this.w = neww;
        }
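        /*
         * The binary format used by readBinary/writeBinary: three ints (the number
         * of words, users and tasks) followed by the values of w, u and the bias
         * as doubles in column-major order. Zero entries are skipped on read so
         * the reconstructed matrices stay sparse.
         */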
        @Override
        public void readBinary(DataInput in) throws IOException {
                final int nwords = in.readInt();
                final int nusers = in.readInt();
                final int ntasks = in.readInt();

                this.w = SparseMatrix.sparse(nwords, ntasks);
                for (int t = 0; t < ntasks; t++) {
                        for (int r = 0; r < nwords; r++) {
                                final double readDouble = in.readDouble();
                                if(readDouble != 0){
                                        this.w.put(r, t, readDouble);
                                }
                        }
                }

                this.u = SparseMatrix.sparse(nusers, ntasks);
                for (int t = 0; t < ntasks; t++) {
                        for (int r = 0; r < nusers; r++) {
                                final double readDouble = in.readDouble();
                                if(readDouble != 0){
                                        this.u.put(r, t, readDouble);
                                }
                        }
                }

                this.bias = SparseMatrix.sparse(ntasks, ntasks);
                for (int t1 = 0; t1 < ntasks; t1++) {
                        for (int t2 = 0; t2 < ntasks; t2++) {
                                final double readDouble = in.readDouble();
                                if(readDouble != 0){
                                        this.bias.put(t1, t2, readDouble);
                                }
                        }
                }
        }
        @Override
        public byte[] binaryHeader() {
                return "".getBytes();
        }
        @Override
        public void writeBinary(DataOutput out) throws IOException {
                out.writeInt(w.rowCount());
                out.writeInt(u.rowCount());
                out.writeInt(u.columnCount());
                final double[] wdata = w.asColumnMajorArray();
                for (int i = 0; i < wdata.length; i++) {
                        out.writeDouble(wdata[i]);
                }
                final double[] udata = u.asColumnMajorArray();
                for (int i = 0; i < udata.length; i++) {
                        out.writeDouble(udata[i]);
                }
                final double[] biasdata = bias.asColumnMajorArray();
                for (int i = 0; i < biasdata.length; i++) {
                        out.writeDouble(biasdata[i]);
                }
        }
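        /**
         * Predict the response for a new (nfeatures x nusers) input: computes
         * U^T . x^T . W (plus the bias when bias mode is enabled) and returns
         * the diagonal, i.e. one value per task.
         */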
        @Override
        public Matrix predict(Matrix x) {
                final Matrix xt = MatlibMatrixUtils.transpose(x);
                final Matrix mult = MatlibMatrixUtils.dotProduct(MatlibMatrixUtils.dotProduct(MatlibMatrixUtils.transpose(u), xt), this.w);
                if(this.biasMode) MatlibMatrixUtils.plusInplace(mult, this.bias);
                final Matrix ydiag = new DiagonalMatrix(mult);
                return ydiag;
        }
}