/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   *  Redistributions of source code must retain the above copyright notice,
 *      this list of conditions and the following disclaimer.
 *
 *   *  Redistributions in binary form must reproduce the above copyright notice,
 *      this list of conditions and the following disclaimer in the documentation
 *      and/or other materials provided with the distribution.
 *
 *   *  Neither the name of the University of Southampton nor the names of its
 *      contributors may be used to endorse or promote products derived from this
 *      software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.linear.experiments.sinabill;

import gov.sandia.cognition.math.matrix.Matrix;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import org.openimaj.io.IOUtils;
import org.openimaj.math.matrix.CFMatrixUtils;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator.Fold;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator.Mode;
import org.openimaj.ml.linear.evaluation.BilinearEvaluator;
import org.openimaj.ml.linear.evaluation.RootMeanSumLossEvaluator;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.BilinearSparseOnlineLearner;
import org.openimaj.ml.linear.learner.init.SparseZerosInitStrategy;
import org.openimaj.ml.linear.learner.loss.MatSquareLossFunction;
import org.openimaj.util.pair.Pair;

import com.google.common.primitives.Doubles;
import com.jmatio.io.MatFileWriter;
import com.jmatio.types.MLArray;

/**
 * Optimise the lambda regularisers and eta0 learning rates of a bilinear
 * learner with a line search.
 * 
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 */
public class LambdaSearchAustrian {

	private static final int NFOLDS = 1;
	// NOTE: these paths are hard-coded for the original author's machine;
	// point ROOT at the Austrian data and OUTPUT_ROOT at a writable
	// directory before running.
	private static final String ROOT = "/Users/ss/Experiments/bilinear/austrian/";
	private static final String OUTPUT_ROOT = "/Users/ss/Dropbox/TrendMiner/Collaboration/StreamingBilinear2014/experiments";
	private final Logger logger = LogManager.getLogger(getClass());

	/**
	 * Run the lambda/eta0 line-search experiment.
	 * 
	 * @param args
	 *            ignored
	 * @throws IOException
	 *             if the data cannot be read or results cannot be written
	 */
	public static void main(String[] args) throws IOException {
		final LambdaSearchAustrian exp = new LambdaSearchAustrian();
		exp.performExperiment();
	}

	private final long expStartTime = System.currentTimeMillis();

	/**
	 * For each fold: line-search the parameters on the training/validation
	 * split, then evaluate the best learner found on the test portion.
	 * 
	 * @throws IOException
	 *             if the data cannot be read
	 */
	public void performExperiment() throws IOException {
		final List<BillMatlabFileDataGenerator.Fold> folds = prepareFolds();
		final BillMatlabFileDataGenerator bmfdg = new BillMatlabFileDataGenerator(
				new File(dataFromRoot("normalised.mat")), "user_vsr_for_polls_SINA",
				new File(dataFromRoot("unnormalised.mat")),
				98, false,
				folds
				);
		prepareExperimentLog();
		final BilinearEvaluator eval = new RootMeanSumLossEvaluator();
		for (int i = 0; i < bmfdg.nFolds(); i++) {
			logger.info("Starting Fold: " + i);
			final BilinearSparseOnlineLearner best = lineSearchParams(i, bmfdg);
			logger.debug("Best params found! Starting test...");
			bmfdg.setFold(i, Mode.TEST);
			eval.setLearner(best);
			final double ev = eval.evaluate(bmfdg.generateAll());
			logger.debug("Test RMSE: " + ev);
		}
	}

	private BilinearSparseOnlineLearner lineSearchParams(int fold, BillMatlabFileDataGenerator source) {
		BilinearSparseOnlineLearner best = null;
		double bestScore = Double.MAX_VALUE;
		final BilinearEvaluator eval = new RootMeanSumLossEvaluator();
		int j = 0;
		final List<BilinearLearnerParameters> parameterLineSearch = parameterLineSearch();
		logger.info("Optimising params, searching: " + parameterLineSearch.size());
		for (final BilinearLearnerParameters next : parameterLineSearch) {
			logger.info(String.format("Optimising params %d/%d", j + 1, parameterLineSearch.size()));
			logger.debug("Current Params:\n" + next.toString());
			final BilinearSparseOnlineLearner learner = new BilinearSparseOnlineLearner(next);
			// Train the model with the new parameters
			source.setFold(fold, Mode.TRAINING);
			Pair<Matrix> pair = null;
			logger.debug("Training...");
			while ((pair = source.generate()) != null) {
				learner.process(pair.firstObject(), pair.secondObject());
			}
			logger.debug("Generating score of validation set");
			// Score the learner on the validation set
			source.setFold(fold, Mode.VALIDATION);
			eval.setLearner(learner);
			final double loss = eval.evaluate(source.generateAll());
			logger.debug("Total RMSE: " + loss);
			logger.debug("U sparsity: " + CFMatrixUtils.sparsity(learner.getU()));
			logger.debug("W sparsity: " + CFMatrixUtils.sparsity(learner.getW()));
			// Record the learner with the lowest validation loss seen so far
			if (loss < bestScore) {
				logger.info("New best score detected!");
				bestScore = loss;
				best = learner;
				logger.info("New Best Config:\n" + best.getParams());
				logger.info("New Best Loss: " + loss);
				saveFoldParameterLearner(fold, j, learner);
			}
			j++;
		}
		return best;
	}

	private void saveFoldParameterLearner(int fold, int j, BilinearSparseOnlineLearner learner) {
		// Save the learner state both as a serialised binary and as a .mat file
		final File learnerOut = new File(String.format("%s/fold_%d", currentOutputRoot(), fold), String.format(
				"learner_%d", j));
		final File learnerOutMat = new File(String.format("%s/fold_%d", currentOutputRoot(), fold), String.format(
				"learner_%d.mat", j));
		learnerOut.getParentFile().mkdirs();
		try {
			IOUtils.writeBinary(learnerOut, learner);
			final Collection<MLArray> data = new ArrayList<MLArray>();
			data.add(CFMatrixUtils.toMLArray("u", learner.getU()));
			data.add(CFMatrixUtils.toMLArray("w", learner.getW()));
			if (learner.getBias() != null) {
				data.add(CFMatrixUtils.toMLArray("b", learner.getBias()));
			}
			// MatFileWriter writes the file as a side effect of construction
			new MatFileWriter(learnerOutMat, data);
		} catch (final IOException e) {
			throw new RuntimeException(e);
		}
	}

	private List<BilinearLearnerParameters> parameterLineSearch() {
		final BilinearLearnerParameters params = prepareParams();
		final BilinearLearnerParametersLineSearch iter = new BilinearLearnerParametersLineSearch(params);

		iter.addIteration(BilinearLearnerParameters.ETA0_U, Doubles.asList(new double[] { 0.0001 }));
		iter.addIteration(BilinearLearnerParameters.ETA0_W, Doubles.asList(new double[] { 0.005 }));
		iter.addIteration(BilinearLearnerParameters.ETA0_BIAS, Doubles.asList(new double[] { 50 }));
		iter.addIteration(BilinearLearnerParameters.LAMBDA_U, Doubles.asList(new double[] { 0.00001 }));
		iter.addIteration(BilinearLearnerParameters.LAMBDA_W, Doubles.asList(new double[] { 0.00001 }));
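
		// Each addIteration call above supplies a single candidate value, so
		// the iterator below yields exactly one parameter configuration
		// (1 x 1 x 1 x 1 x 1). A wider sweep is obtained by passing more
		// candidates per parameter, e.g. (hypothetical values, not those of
		// the original experiment):
		// iter.addIteration(BilinearLearnerParameters.LAMBDA_U,
		//		Doubles.asList(new double[] { 1e-6, 1e-5, 1e-4 }));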

		final List<BilinearLearnerParameters> ret = new ArrayList<BilinearLearnerParameters>();
		for (final BilinearLearnerParameters param : iter) {
			ret.add(param);
		}
		return ret;
	}

	private List<Fold> prepareFolds() {
		final List<Fold> set_fold = new ArrayList<BillMatlabFileDataGenerator.Fold>();

		final int step = 5; // test size
		final int t_size = 48; // training size (including validation)
		final int v_size = 8; // validation size
		for (int i = 0; i < NFOLDS; i++) {
			final int total = i * step + t_size;
			final int[] training = new int[total - v_size];
			final int[] test = new int[step];
			final int[] validation = new int[v_size];
			int j = 0;
			int traini = 0;
			// Carve the validation items out of the middle of the training
			// range; the test items follow the training range.
			final int tt = (int) Math.round(total / 2.) - 1;
			for (; j < tt - v_size / 2; j++, traini++) {
				training[traini] = j;
			}
			for (int k = 0; k < validation.length; k++, j++) {
				validation[k] = j;
			}
			for (; j < total; j++, traini++) {
				training[traini] = j;
			}
			for (int k = 0; k < test.length; k++, j++) {
				test[k] = j;
			}
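			// For fold 0 (step = 5, t_size = 48, v_size = 8) this yields:
			// training = [0..18] and [27..47] (40 items),
			// validation = [19..26] (8 items), test = [48..52] (5 items).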
			final Fold foldi = new Fold(training, test, validation);
			set_fold.add(foldi);
		}
		return set_fold;
	}

	private BilinearLearnerParameters prepareParams() {
		final BilinearLearnerParameters params = new BilinearLearnerParameters();

		params.put(BilinearLearnerParameters.ETA0_U, null);
		params.put(BilinearLearnerParameters.ETA0_W, null);
		params.put(BilinearLearnerParameters.LAMBDA_U, null);
		params.put(BilinearLearnerParameters.LAMBDA_W, null);
		params.put(BilinearLearnerParameters.ETA0_BIAS, null);
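		// The five nulls above are placeholders: concrete values are
		// supplied by the line search built in parameterLineSearch().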

		params.put(BilinearLearnerParameters.BICONVEX_TOL, 0.01);
		params.put(BilinearLearnerParameters.BICONVEX_MAXITER, 10);
		params.put(BilinearLearnerParameters.BIAS, true);
		params.put(BilinearLearnerParameters.WINITSTRAT, new SparseZerosInitStrategy());
		params.put(BilinearLearnerParameters.UINITSTRAT, new SparseZerosInitStrategy());
		params.put(BilinearLearnerParameters.LOSS, new MatSquareLossFunction());
		return params;
	}

	/**
	 * Resolve a data file name against the experiment data root; e.g.
	 * {@code dataFromRoot("normalised.mat")} returns
	 * {@code <ROOT>/normalised.mat}.
	 * 
	 * @param data
	 *            the file name relative to the data root
	 * @return the path of the data file under ROOT
	 */
	public static String dataFromRoot(String data) {
		return String.format("%s/%s", ROOT, data);
	}

	protected void prepareExperimentLog() throws IOException {
		// Logging is configured through the standard log4j2 configuration on
		// the classpath; no programmatic appender setup is performed here.
	}

	/**
	 * Create (if necessary) and return the output directory for this
	 * experiment run.
	 * 
	 * @return the experiment output root
	 * @throws IOException
	 *             if the directory could not be created
	 */
	public File prepareExperimentRoot() throws IOException {
		final String experimentRoot = currentOutputRoot();
		final File expRoot = new File(experimentRoot);
		if (expRoot.exists() && expRoot.isDirectory())
			return expRoot;
		logger.debug("Experiment root: " + expRoot);
		if (!expRoot.mkdirs())
			throw new IOException("Couldn't prepare experiment output");
		return expRoot;
	}

	private String currentOutputRoot() {
		return String.format("%s/%s/%s", OUTPUT_ROOT, getExperimentSetName(), String.valueOf(currentExperimentTime()));
	}

	private long currentExperimentTime() {
		return expStartTime;
	}

	private String getExperimentSetName() {
		return "streamingBilinear/optimiselambda";
	}
}