001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.linear.learner;
031
import gov.sandia.cognition.math.matrix.Matrix;

import java.util.function.UnaryOperator;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import org.openimaj.citation.annotation.Reference;
import org.openimaj.citation.annotation.ReferenceType;
import org.openimaj.math.matrix.CFMatrixUtils;
040
041/**
 * An implementation of a stochastic gradient descent with proximal parameter
 * adjustment (for regularised parameters).
044 * <p>
045 * Data is dealt with sequentially using a one pass implementation of the online
046 * proximal algorithm described in chapter 9 and 10 of: The Geometry of
047 * Constrained Structured Prediction: Applications to Inference and Learning of
048 * Natural Language Syntax, PhD, Andre T. Martins
049 * <p>
050 * This is a direct extension of the {@link BilinearSparseOnlineLearner} but
051 * instead of a mixed update scheme (i.e. for a number of iterations W and U are
052 * updated synchronously) we have an unmixed scheme where W is updated for a
053 * number of iterations, followed by U for a number of iterations continuing as
054 * a whole for a number of iterations
055 * <p>
056 * The implementation does the following:
057 * <ul>
058 * <li>When an X,Y is received:
059 * <ul>
060 * <li>Update currently held batch
061 * <li>If the batch is full:
062 * <ul>
063 * <li>While There is a great deal of change in U and W:
064 * <ul>
065 * <li>While There is a great deal of change in W:
066 * <ul>
067 * <li>Calculate the gradient of W holding U fixed
068 * <li>Proximal update of W
069 * <li>Calculate the gradient of Bias holding U and W fixed
070 * </ul>
071 * <li>While There is a great deal of change in U:
072 * <ul>
073 * <li>Calculate the gradient of U holding W fixed
074 * <li>Proximal update of U
075 * <li>Calculate the gradient of Bias holding U and W fixed
076 * </ul>
077 * </ul>
078 * <li>flush the batch
079 * </ul>
 * <li>return current U and W (same as last time if the batch isn't filled yet)
081 * </ul>
082 * </ul>
083 * 
084 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
085 * 
086 */
087@Reference(
088                author = { "Andre F. T. Martins" },
089                title = "The Geometry of Constrained Structured Prediction: Applications to Inference and Learning of Natural Language Syntax",
090                type = ReferenceType.Phdthesis,
091                year = "2012")
092public class BilinearUnmixedSparseOnlineLearner extends BilinearSparseOnlineLearner {
093
094        static Logger logger = LogManager.getLogger(BilinearUnmixedSparseOnlineLearner.class);
095
096        @Override
097        protected Matrix updateW(Matrix currentW, double wLossWeighted, double weightedLambda) {
098                Matrix current = currentW;
099                int iter = 0;
100                final Double biconvextol = this.params.getTyped(BilinearLearnerParameters.BICONVEX_TOL);
101                final Integer maxiter = this.params.getTyped(BilinearLearnerParameters.BICONVEX_MAXITER);
102                while (true) {
103                        final Matrix newcurrent = super.updateW(current, wLossWeighted, weightedLambda);
104                        final double sumchange = CFMatrixUtils.absSum(current.minus(newcurrent));
105                        final double total = CFMatrixUtils.absSum(current);
106                        final double ratio = sumchange / total;
107                        current = newcurrent;
108                        if (ratio < biconvextol || iter >= maxiter) {
109                                logger.debug("W tolerance reached after iteration: " + iter);
110                                break;
111                        }
112                        iter++;
113                }
114                return current;
115        }
116
117        @Override
118        protected Matrix updateU(Matrix currentU, Matrix neww, double uLossWeighted, double weightedLambda) {
119                Matrix current = currentU;
120                int iter = 0;
121                final Double biconvextol = this.params.getTyped(BilinearLearnerParameters.BICONVEX_TOL);
122                final Integer maxiter = this.params.getTyped(BilinearLearnerParameters.BICONVEX_MAXITER);
123                while (true) {
124                        final Matrix newcurrent = super.updateU(current, neww, uLossWeighted, weightedLambda);
125                        final double sumchange = CFMatrixUtils.absSum(current.minus(newcurrent));
126                        final double total = CFMatrixUtils.absSum(current);
127                        final double ratio = sumchange / total;
128                        current = newcurrent;
129                        if (ratio < biconvextol || iter >= maxiter) {
130                                logger.debug("U tolerance reached after iteration: " + iter);
131                                break;
132                        }
133                        iter++;
134                }
135                return current;
136        }
137}