/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.linear.experiments.sinabill;

import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.mtj.DenseVectorFactoryMTJ;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;

import org.openimaj.io.IOUtils;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator.Mode;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.BilinearSparseOnlineLearner;
import org.openimaj.ml.linear.learner.init.SingleValueInitStrat;
import org.openimaj.ml.linear.learner.init.SparseZerosInitStrategy;
import org.openimaj.util.pair.Pair;

/**
 * Experiment that trains (or reloads from disk) a
 * {@link BilinearSparseOnlineLearner} on the Austrian Matlab word dataset for a
 * single fold, then prints the most highly weighted words of the learnt W
 * matrix for each task.
 */
public class AustrianWordExperiments extends BilinearExperiment {
	public static void main(String[] args) throws IOException {
		final AustrianWordExperiments exp = new AustrianWordExperiments();
		exp.performExperiment();
	}

	@Override
	public void performExperiment() throws IOException {
		// Configure the bilinear learner's hyperparameters and initialisation strategies
		final BilinearLearnerParameters params = new BilinearLearnerParameters();
		params.put(BilinearLearnerParameters.ETA0_U, 0.0002);
		params.put(BilinearLearnerParameters.ETA0_W, 0.002);
		params.put(BilinearLearnerParameters.LAMBDA, 0.001);
		params.put(BilinearLearnerParameters.BICONVEX_TOL, 0.05);
		params.put(BilinearLearnerParameters.BICONVEX_MAXITER, 5);
		params.put(BilinearLearnerParameters.BIAS, true);
		params.put(BilinearLearnerParameters.ETA0_BIAS, 0.05);
		params.put(BilinearLearnerParameters.WINITSTRAT, new SingleValueInitStrat(0.1));
		params.put(BilinearLearnerParameters.UINITSTRAT, new SparseZerosInitStrategy());
		final BillMatlabFileDataGenerator bmfdg = new BillMatlabFileDataGenerator(new File(MATLAB_DATA()), 98, true);
		prepareExperimentLog(params);
		final int fold = 0;
		// File foldParamFile = new
		// File(prepareExperimentRoot(),String.format("fold_%d_learner", fold));
		final File foldParamFile = new File(
				"/Users/ss/Dropbox/TrendMiner/deliverables/year2-18month/Austrian Data/streamingExperiments/experiment_1365684128359/fold_0_learner");
		logger.debug("Fold: " + fold);
		BilinearSparseOnlineLearner learner = new BilinearSparseOnlineLearner(params);
		learner.reinitParams();
		bmfdg.setFold(fold, Mode.TEST);

		logger.debug("...training");
		bmfdg.setFold(fold, Mode.TRAINING);
		int j = 0;
		// Train a new learner if no cached parameters exist for this fold;
		// otherwise reload the previously saved learner from disk.
		if (!foldParamFile.exists()) {
			while (true) {
				final Pair<Matrix> next = bmfdg.generate();
				if (next == null)
					break;
				logger.debug("...trying item " + j++);
				learner.process(next.firstObject(), next.secondObject());
			}
			System.out.println("Writing W and U to: " + foldParamFile);
			IOUtils.writeBinary(foldParamFile, learner);
		} else {
			learner = IOUtils.read(foldParamFile, BilinearSparseOnlineLearner.class);
		}

		// For each task (column of W), print the nwords most highly weighted vocabulary words
		final Matrix w = learner.getW();
		final int ncols = w.getNumColumns();
		final int nwords = 20;
		for (int c = 0; c < ncols; c++) {
			System.out.println("Top " + nwords + " words for: " + bmfdg.getTasks()[c]);
			final Vector col = w.getColumn(c);
			final double[] wordWeights = new DenseVectorFactoryMTJ().copyVector(col).getArray();
			final Integer[] integerRange = ArrayIndexComparator.integerRange(wordWeights);
			Arrays.sort(integerRange, new ArrayIndexComparator(wordWeights));
			// Indices are sorted ascending by weight, so walk backwards from the largest
			for (int i = wordWeights.length - 1; i >= wordWeights.length - nwords; i--) {
				System.out
						.printf("%s: %1.5f\n", bmfdg.getVocabulary().get(integerRange[i]), wordWeights[integerRange[i]]);
			}
		}
	}
}