/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 * * Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * * Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * * Neither the name of the University of Southampton nor the names of its
 * contributors may be used to endorse or promote products derived from this
 * software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.linear.experiments.sinabill;

import gov.sandia.cognition.math.matrix.Matrix;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import org.openimaj.io.IOUtils;
import org.openimaj.math.matrix.CFMatrixUtils;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator.Fold;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator.Mode;
import org.openimaj.ml.linear.evaluation.BilinearEvaluator;
import org.openimaj.ml.linear.evaluation.RootMeanSumLossEvaluator;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.BilinearSparseOnlineLearner;
import org.openimaj.ml.linear.learner.init.SparseZerosInitStrategy;
import org.openimaj.ml.linear.learner.loss.MatSquareLossFunction;
import org.openimaj.util.pair.Pair;

import com.google.common.primitives.Doubles;
import com.jmatio.io.MatFileWriter;
import com.jmatio.types.MLArray;

/**
 * Optimise the lambda regularisation parameters and the eta0 learning rates of
 * a bilinear learner with a line search.
 *
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 */
public class LambdaSearchAustrian {

	private static final int NFOLDS = 1;
	private static final String ROOT = "/Users/ss/Experiments/bilinear/austrian/";
	private static final String OUTPUT_ROOT = "/Users/ss/Dropbox/TrendMiner/Collaboration/StreamingBilinear2014/experiments";
	private final Logger logger = LogManager.getLogger(getClass());
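	// NOTE: ROOT and OUTPUT_ROOT are machine-specific absolute paths; point
	// them at a local copy of the experiment data and a writable output
	// directory before running.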
	/**
	 * @param args
	 *            ignored
	 * @throws IOException
	 */
	public static void main(String[] args) throws IOException {
		final LambdaSearchAustrian exp = new LambdaSearchAustrian();
		exp.performExperiment();
	}

	private final long expStartTime = System.currentTimeMillis();

	/**
	 * Run the experiment: line-search the parameters on each fold, then
	 * evaluate the best learner on that fold's test set.
	 *
	 * @throws IOException
	 */
	public void performExperiment() throws IOException {
		final List<BillMatlabFileDataGenerator.Fold> folds = prepareFolds();
		final BillMatlabFileDataGenerator bmfdg = new BillMatlabFileDataGenerator(
				new File(dataFromRoot("normalised.mat")), "user_vsr_for_polls_SINA",
				new File(dataFromRoot("unnormalised.mat")),
				98, false,
				folds
		);
		prepareExperimentLog();
		final BilinearEvaluator eval = new RootMeanSumLossEvaluator();
		for (int i = 0; i < bmfdg.nFolds(); i++) {
			logger.info("Starting Fold: " + i);
			final BilinearSparseOnlineLearner best = lineSearchParams(i, bmfdg);
			logger.debug("Best params found! Starting test...");
			bmfdg.setFold(i, Mode.TEST);
			eval.setLearner(best);
			final double ev = eval.evaluate(bmfdg.generateAll());
			logger.debug("Test RMSE: " + ev);
		}
	}

	private BilinearSparseOnlineLearner lineSearchParams(int fold, BillMatlabFileDataGenerator source) {
		BilinearSparseOnlineLearner best = null;
		double bestScore = Double.MAX_VALUE;
		final BilinearEvaluator eval = new RootMeanSumLossEvaluator();
		int j = 0;
		final List<BilinearLearnerParameters> parameterLineSearch = parameterLineSearch();
		logger.info("Optimising params, searching: " + parameterLineSearch.size());
		for (final BilinearLearnerParameters next : parameterLineSearch) {
			logger.info(String.format("Optimising params %d/%d", j + 1, parameterLineSearch.size()));
			logger.debug("Current Params:\n" + next.toString());
			final BilinearSparseOnlineLearner learner = new BilinearSparseOnlineLearner(next);
			// Train the model with the new parameters
			source.setFold(fold, Mode.TRAINING);
			Pair<Matrix> pair = null;
			logger.debug("Training...");
			while ((pair = source.generate()) != null) {
				learner.process(pair.firstObject(), pair.secondObject());
			}
			logger.debug("Generating score of validation set");
			// Score against the validation set
			source.setFold(fold, Mode.VALIDATION);
			eval.setLearner(learner);
			final double loss = eval.evaluate(source.generateAll());
			logger.debug("Total RMSE: " + loss);
			logger.debug("U sparsity: " + CFMatrixUtils.sparsity(learner.getU()));
			logger.debug("W sparsity: " + CFMatrixUtils.sparsity(learner.getW()));
			// Record the best learner so far
			if (loss < bestScore) {
				logger.info("New best score detected!");
				bestScore = loss;
				best = learner;
				logger.info("New Best Config:\n" + best.getParams());
				logger.info("New Best Loss: " + loss);
				saveFoldParameterLearner(fold, j, learner);
			}
			j++;
		}
		return best;
	}
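	/**
	 * Save the state of a learner for the given fold and line-search step,
	 * both as a serialised binary and as a MAT file containing the U, W and
	 * (when bias is enabled) b matrices.
	 */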
	private void saveFoldParameterLearner(int fold, int j, BilinearSparseOnlineLearner learner) {
		final File learnerOut = new File(String.format("%s/fold_%d", currentOutputRoot(), fold), String.format(
				"learner_%d", j));
		final File learnerOutMat = new File(String.format("%s/fold_%d", currentOutputRoot(), fold), String.format(
				"learner_%d.mat", j));
		learnerOut.getParentFile().mkdirs();
		try {
			IOUtils.writeBinary(learnerOut, learner);
			final Collection<MLArray> data = new ArrayList<MLArray>();
			data.add(CFMatrixUtils.toMLArray("u", learner.getU()));
			data.add(CFMatrixUtils.toMLArray("w", learner.getW()));
			if (learner.getBias() != null) {
				data.add(CFMatrixUtils.toMLArray("b", learner.getBias()));
			}
			// The MatFileWriter constructor writes the MAT file as a side effect
			new MatFileWriter(learnerOutMat, data);
		} catch (final IOException e) {
			throw new RuntimeException(e);
		}
	}

	private List<BilinearLearnerParameters> parameterLineSearch() {
		final BilinearLearnerParameters params = prepareParams();
		final BilinearLearnerParametersLineSearch iter = new BilinearLearnerParametersLineSearch(params);

		iter.addIteration(BilinearLearnerParameters.ETA0_U, Doubles.asList(new double[] { 0.0001 }));
		iter.addIteration(BilinearLearnerParameters.ETA0_W, Doubles.asList(new double[] { 0.005 }));
		iter.addIteration(BilinearLearnerParameters.ETA0_BIAS, Doubles.asList(new double[] { 50 }));
		iter.addIteration(BilinearLearnerParameters.LAMBDA_U, Doubles.asList(new double[] { 0.00001 }));
		iter.addIteration(BilinearLearnerParameters.LAMBDA_W, Doubles.asList(new double[] { 0.00001 }));

		final List<BilinearLearnerParameters> ret = new ArrayList<BilinearLearnerParameters>();
		for (final BilinearLearnerParameters param : iter) {
			ret.add(param);
		}
		return ret;
	}

	private List<Fold> prepareFolds() {
		final List<Fold> set_fold = new ArrayList<BillMatlabFileDataGenerator.Fold>();

		// [24/02/2014 16:58:23] .@bill:
		final int step = 5; // test size
		final int t_size = 48; // training size
		final int v_size = 8; // validation size
		for (int i = 0; i < NFOLDS; i++) {
			final int total = i * step + t_size;
			final int[] training = new int[total - v_size];
			final int[] test = new int[step];
			final int[] validation = new int[v_size];
			int j = 0;
			int traini = 0;
			// The validation block sits roughly centred within the first
			// `total` indices; the remaining indices up to `total` form the
			// training set and the following `step` indices form the test set.
			final int tt = (int) Math.round(total / 2.) - 1;
			for (; j < tt - v_size / 2; j++, traini++) {
				training[traini] = j;
			}
			for (int k = 0; k < validation.length; k++, j++) {
				validation[k] = j;
			}
			for (; j < total; j++, traini++) {
				training[traini] = j;
			}
			for (int k = 0; k < test.length; k++, j++) {
				test[k] = j;
			}
			final Fold foldi = new Fold(training, test, validation);
			set_fold.add(foldi);
		}
		// [24/02/2014 16:59:07] .@bill: set_fold{1,1}
		return set_fold;
	}

	private BilinearLearnerParameters prepareParams() {
		final BilinearLearnerParameters params = new BilinearLearnerParameters();

		// The line-searched parameters are deliberately left unset here; they
		// are filled in by parameterLineSearch()
		params.put(BilinearLearnerParameters.ETA0_U, null);
		params.put(BilinearLearnerParameters.ETA0_W, null);
		params.put(BilinearLearnerParameters.LAMBDA_U, null);
		params.put(BilinearLearnerParameters.LAMBDA_W, null);
		params.put(BilinearLearnerParameters.ETA0_BIAS, null);

		params.put(BilinearLearnerParameters.BICONVEX_TOL, 0.01);
		params.put(BilinearLearnerParameters.BICONVEX_MAXITER, 10);
		params.put(BilinearLearnerParameters.BIAS, true);
		params.put(BilinearLearnerParameters.WINITSTRAT, new SparseZerosInitStrategy());
		params.put(BilinearLearnerParameters.UINITSTRAT, new SparseZerosInitStrategy());
		params.put(BilinearLearnerParameters.LOSS, new MatSquareLossFunction());
		return params;
	}

	/**
	 * @param data
	 *            the name of a data file
	 * @return the full path of the named file under the data root
	 */
	public static String dataFromRoot(String data) {
		return String.format("%s/%s", ROOT, data);
	}
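	/**
	 * Set up logging for the experiment. The body below is the original log4j
	 * 1.x appender configuration, left commented out; it appears to predate
	 * the move to the log4j 2 API imported above.
	 */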
	protected void prepareExperimentLog() throws IOException {
		// final ConsoleAppender console = new ConsoleAppender(); // create appender
		// // configure the appender
		// final String PATTERN = "[%p->%C{1}] %m%n";
		// console.setLayout(new PatternLayout(PATTERN));
		// console.setThreshold(Level.INFO);
		// console.activateOptions();
		// // add appender to any Logger (here is root)
		// Logger.getRootLogger().addAppender(console);
		// final File expRoot = prepareExperimentRoot();

		// final File logFile = new File(expRoot, "log");
		// if (logFile.exists())
		// logFile.delete();
		// final String TIMED_PATTERN = "[%d{HH:mm:ss} %p->%C{1}] %m%n";
		// final FileAppender file = new FileAppender(new PatternLayout(TIMED_PATTERN), logFile.getAbsolutePath());
		// file.setThreshold(Level.DEBUG);
		// file.activateOptions();
		// Logger.getRootLogger().addAppender(file);
		// logger.info("Experiment root: " + expRoot);
	}

	/**
	 * @return the experiment output root directory, creating it if necessary
	 * @throws IOException
	 *             if the directory could not be created
	 */
	public File prepareExperimentRoot() throws IOException {
		final String experimentRoot = currentOutputRoot();
		final File expRoot = new File(experimentRoot);
		if (expRoot.exists() && expRoot.isDirectory())
			return expRoot;
		logger.debug("Experiment root: " + expRoot);
		if (!expRoot.mkdirs())
			throw new IOException("Couldn't prepare experiment output");
		return expRoot;
	}

	private String currentOutputRoot() {
		return String.format("%s/%s/%s", OUTPUT_ROOT, getExperimentSetName(), "" + currentExperimentTime());
	}

	private long currentExperimentTime() {
		return expStartTime;
	}

	private String getExperimentSetName() {
		return "streamingBilinear/optimiselambda";
	}
}