001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.ml.linear.experiments.sinabill; 031 032import gov.sandia.cognition.math.matrix.Matrix; 033import gov.sandia.cognition.math.matrix.MatrixFactory; 034 035import java.io.File; 036import java.io.IOException; 037import java.util.ArrayList; 038import java.util.List; 039 040import org.openimaj.io.IOUtils; 041import org.openimaj.math.matrix.CFMatrixUtils; 042import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator; 043import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator.Mode; 044import org.openimaj.ml.linear.evaluation.BilinearEvaluator; 045import org.openimaj.ml.linear.evaluation.RootMeanSumLossEvaluator; 046import org.openimaj.ml.linear.learner.BilinearLearnerParameters; 047import org.openimaj.ml.linear.learner.BilinearSparseOnlineLearner; 048import org.openimaj.ml.linear.learner.init.SingleValueInitStrat; 049import org.openimaj.ml.linear.learner.init.SparseZerosInitStrategy; 050import org.openimaj.util.pair.Pair; 051 052 053public class BillAustrianDampeningExperiments extends BilinearExperiment{ 054 055 public static void main(String[] args) throws Exception { 056 BilinearExperiment exp = new BillAustrianDampeningExperiments(); 057 exp.performExperiment(); 058 } 059 060 @Override 061 public void performExperiment() throws IOException { 062 BilinearLearnerParameters params = new BilinearLearnerParameters(); 063 params.put(BilinearLearnerParameters.ETA0_U, 0.02); 064 params.put(BilinearLearnerParameters.ETA0_W, 0.02); 065 params.put(BilinearLearnerParameters.LAMBDA, 0.001); 066 params.put(BilinearLearnerParameters.BICONVEX_TOL, 0.01); 067 params.put(BilinearLearnerParameters.BICONVEX_MAXITER, 10); 068 params.put(BilinearLearnerParameters.BIAS, true); 069 params.put(BilinearLearnerParameters.ETA0_BIAS, 0.5); 070 params.put(BilinearLearnerParameters.WINITSTRAT, new SingleValueInitStrat(0.1)); 071 params.put(BilinearLearnerParameters.UINITSTRAT, new SparseZerosInitStrategy()); 072// params.put(BilinearLearnerParameters.DAMPENING, 0.1); 073 BillMatlabFileDataGenerator bmfdg = new BillMatlabFileDataGenerator( 074 new File(MATLAB_DATA()), 075 98, 076 true 077 ); 078 prepareExperimentLog(params); 079 int foldNumber = 5; 080 logger.debug("Starting dampening experiments"); 081 logger.debug("Fold: " + foldNumber); 082 bmfdg.setFold(foldNumber, Mode.TEST); 083 List<Pair<Matrix>> testpairs = new ArrayList<Pair<Matrix>>(); 084 while(true){ 085 Pair<Matrix> next = bmfdg.generate(); 086 if(next == null) break; 087 testpairs.add(next); 088 } 089 double dampening = 0d; 090 double dampeningIncr = 0.0001d; 091 double dampeningMax = 0.02d; 092 logger.debug( 093 String.format( 094 "Beggining dampening experiments: min=%2.5f,max=%2.5f,incr=%2.5f", 095 dampening, 096 dampeningMax, 097 dampeningIncr 098 099 )); 100 while(dampening < dampeningMax){ 101 params.put(BilinearLearnerParameters.DAMPENING, dampening); 102 BilinearSparseOnlineLearner learner = new BilinearSparseOnlineLearner(params); 103 learner.reinitParams(); 104 105 logger.debug("Dampening is now: " + dampening); 106 logger.debug("...training"); 107 bmfdg.setFold(foldNumber, Mode.TRAINING); 108 int j = 0; 109 while(true){ 110 Pair<Matrix> next = bmfdg.generate(); 111 if(next == null) break; 112 logger.debug("...trying item "+j++); 113 learner.process(next.firstObject(), next.secondObject()); 114 Matrix u = learner.getU(); 115 Matrix w = learner.getW(); 116 Matrix bias = MatrixFactory.getDenseDefault().copyMatrix(learner.getBias()); 117 BilinearEvaluator eval = new RootMeanSumLossEvaluator(); 118 eval.setLearner(learner); 119 double loss = eval.evaluate(testpairs); 120 logger.debug(String.format("Saving learner, Fold %d, Item %d",foldNumber, j)); 121 File learnerOut = new File(FOLD_ROOT(foldNumber),String.format("learner_%d_dampening=%2.5f",j,dampening)); 122 IOUtils.writeBinary(learnerOut, learner); 123 logger.debug("W row sparcity: " + CFMatrixUtils.rowSparsity(w)); 124 logger.debug(String.format("W range: %2.5f -> %2.5f",CFMatrixUtils.min(w), CFMatrixUtils.max(w))); 125 logger.debug("U row sparcity: " + CFMatrixUtils.rowSparsity(u)); 126 logger.debug(String.format("U range: %2.5f -> %2.5f",CFMatrixUtils.min(u), CFMatrixUtils.max(u))); 127 Boolean biasMode = learner.getParams().getTyped(BilinearLearnerParameters.BIAS); 128 if(biasMode){ 129 logger.debug("Bias: " + CFMatrixUtils.diag(bias)); 130 } 131 logger.debug(String.format("... loss: %f",loss)); 132 } 133 134 dampening+=dampeningIncr; 135 } 136 } 137 138} 139 140