001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.linear.learner;
031
032import gov.sandia.cognition.math.matrix.Matrix;
033import gov.sandia.cognition.math.matrix.mtj.SparseMatrixFactoryMTJ;
034
035import java.util.HashMap;
036import java.util.Map;
037import java.util.Map.Entry;
038
039import org.openimaj.util.pair.IndependentPair;
040import org.openimaj.util.pair.Pair;
041
042import com.google.common.collect.BiMap;
043import com.google.common.collect.HashBiMap;
044
045/**
046 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
047 * 
048 */
049public class IncrementalBilinearSparseOnlineLearner
050                implements
051                OnlineLearner<Map<String, Map<String, Double>>, Map<String, Double>>
052{
053        static class IncrementalBilinearSparseOnlineLearnerParams extends BilinearLearnerParameters {
054
055                /**
056         *
057         */
058                private static final long serialVersionUID = -1847045895118918210L;
059
060        }
061
062        private BiMap<String, Integer> vocabulary;
063        private BiMap<String, Integer> users;
064        private BiMap<String, Integer> values;
065        private BilinearSparseOnlineLearner bilinearLearner;
066        private BilinearLearnerParameters params;
067
068        /**
069         * Instantiates with the default params
070         */
071        public IncrementalBilinearSparseOnlineLearner() {
072                init(new IncrementalBilinearSparseOnlineLearnerParams());
073        }
074
075        /**
076         * @param params
077         */
078        public IncrementalBilinearSparseOnlineLearner(BilinearLearnerParameters params) {
079                init(params);
080        }
081
082        /**
083         *
084         */
085        public void reinitParams() {
086                init(this.params);
087
088        }
089
090        private void init(BilinearLearnerParameters params) {
091                vocabulary = HashBiMap.create();
092                users = HashBiMap.create();
093                values = HashBiMap.create();
094                this.params = params;
095                bilinearLearner = new BilinearSparseOnlineLearner(params);
096        }
097
098        /**
099         * @return the current parameters
100         */
101        public BilinearLearnerParameters getParams() {
102                return this.params;
103        }
104
105        @Override
106        public void process(Map<String, Map<String, Double>> x, Map<String, Double> y) {
107                updateUserValues(x, y);
108                final Matrix yMat = constructYMatrix(y);
109                final Matrix xMat = constructXMatrix(x);
110
111                this.bilinearLearner.process(xMat, yMat);
112        }
113
114        /**
115         * Update the incremental learner and underlying weight matricies to reflect
116         * potentially novel users, words and values to learn against
117         * 
118         * @param x
119         * @param y
120         */
121        public void updateUserValues(Map<String, Map<String, Double>> x, Map<String, Double> y) {
122                updateUserWords(x);
123                updateValues(y);
124        }
125
126        private void updateValues(Map<String, Double> y) {
127                for (final String value : y.keySet()) {
128                        if (!values.containsKey(value)) {
129                                values.put(value, values.size());
130                        }
131                }
132        }
133
134        private Matrix constructYMatrix(Map<String, Double> y) {
135                final Matrix mat = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(1, values.size());
136                for (final Entry<String, Double> ent : y.entrySet()) {
137                        mat.setElement(0, values.get(ent.getKey()), ent.getValue());
138                }
139                return mat;
140        }
141
142        private Map<String, Double> constructYMap(Matrix y) {
143                final Map<String, Double> ret = new HashMap<String, Double>();
144                for (final String key : values.keySet()) {
145                        final Integer index = values.get(key);
146                        final double yvalue = y.getElement(0, index);
147                        ret.put(key, yvalue);
148                }
149                return ret;
150        }
151
152        private Matrix constructXMatrix(Map<String, Map<String, Double>> x) {
153                final Matrix mat = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(vocabulary.size(), users.size());
154                for (final Entry<String, Map<String, Double>> userwords : x.entrySet()) {
155                        final int userindex = this.users.get(userwords.getKey());
156                        for (final Entry<String, Double> ent : userwords.getValue().entrySet()) {
157                                mat.setElement(vocabulary.get(ent.getKey()), userindex, ent.getValue());
158                        }
159                }
160                return mat;
161        }
162
163        private void updateUserWords(Map<String, Map<String, Double>> x) {
164                int newUsers = 0;
165                int newWords = 0;
166                for (final Entry<String, Map<String, Double>> userWords : x.entrySet()) {
167                        final String user = userWords.getKey();
168                        if (!users.containsKey(user)) {
169                                users.put(user, users.size());
170                                newUsers++;
171                        }
172                        newWords += updateWords(userWords.getValue());
173                }
174
175                this.bilinearLearner.addU(newUsers);
176                this.bilinearLearner.addW(newWords);
177        }
178
179        private int updateWords(Map<String, Double> value) {
180                int newWords = 0;
181                for (final String word : value.keySet()) {
182                        if (!vocabulary.containsKey(word)) {
183                                vocabulary.put(word, vocabulary.size());
184                                newWords++;
185                        }
186                }
187                return newWords;
188        }
189
190        /**
191         * Construct a learner with the desired number of users and words. If users
192         * and words beyond those in the current model are asked for their
193         * parameters are set to 0
194         * 
195         * @param nusers
196         * @param nwords
197         * @return a new {@link BilinearSparseOnlineLearner}
198         */
199        public BilinearSparseOnlineLearner getBilinearLearner(int nusers, int nwords) {
200                final BilinearSparseOnlineLearner ret = this.bilinearLearner.clone();
201
202                final Matrix u = ret.getU();
203                final Matrix w = ret.getW();
204                final Matrix newu = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(nusers, u.getNumColumns());
205                final Matrix neww = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(nwords, w.getNumColumns());
206
207                newu.setSubMatrix(0, 0, u);
208                neww.setSubMatrix(0, 0, w);
209
210                ret.setU(newu);
211                ret.setW(neww);
212                return ret;
213        }
214
215        /**
216         * @return the underlying {@link BilinearSparseOnlineLearner} with the
217         *         current number of users and words
218         */
219        public BilinearSparseOnlineLearner getBilinearLearner() {
220                return this.bilinearLearner.clone();
221        }
222
223        /**
224         * Given a sparse pair of user/words and value construct a pair of matricies
225         * using the current mappings of words and users to matrix rows.
226         * 
227         * Note: if users or words which have not yet be
228         * 
229         * @param xy
230         * @param nfeatures
231         *            the number of words total in the returned X matrix
232         * @param nusers
233         *            the number of users total in the returned X matrix
234         * @param ntasks
235         *            the number of tasks in the returned Y matrix
236         * @return the matrix pair representing the X and Y input
237         */
238        public Pair<Matrix> asMatrixPair(
239                        IndependentPair<Map<String, Map<String, Double>>, Map<String, Double>> xy,
240                        int nfeatures, int nusers, int ntasks)
241        {
242                final Matrix y = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(1, ntasks);
243                final Matrix x = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(nfeatures, nusers);
244                final Map<String, Double> ymap = xy.secondObject();
245                final Map<String, Map<String, Double>> userFeatureMap = xy.firstObject();
246                for (final Entry<String, Double> yent : ymap.entrySet()) {
247                        y.setElement(0, this.values.get(yent.getKey()), yent.getValue());
248                }
249                for (final Entry<String, Map<String, Double>> xent : userFeatureMap.entrySet()) {
250                        final int userind = this.users.get(xent.getKey());
251                        for (final Entry<String, Double> fent : xent.getValue().entrySet()) {
252                                x.setElement(this.vocabulary.get(fent.getKey()), userind, fent.getValue());
253                        }
254                }
255                return new Pair<Matrix>(x, y);
256        }
257
258        @Override
259        public Map<String, Double> predict(Map<String, Map<String, Double>> x) {
260                final Matrix xMat = constructXMatrix(x);
261                final Matrix yMat = this.bilinearLearner.predict(xMat);
262                return this.constructYMap(yMat);
263        }
264
265        /**
266         * @return the vocabulary
267         */
268        public BiMap<String, Integer> getVocabulary() {
269                return vocabulary;
270        }
271
272        /**
273         * @param in
274         * @return calls {@link #asMatrixPair(IndependentPair, int, int, int)} with
275         *         the current number of words, users and value
276         */
277        public Pair<Matrix> asMatrixPair(IndependentPair<Map<String, Map<String, Double>>, Map<String, Double>> in) {
278                return this.asMatrixPair(in, this.vocabulary.size(), this.users.size(), this.values.size());
279        }
280
281        /**
282         * @param x
283         * @param y
284         * @return calls {@link #asMatrixPair(IndependentPair, int, int, int)} with
285         *         the current number of words, users and value
286         */
287        public Pair<Matrix> asMatrixPair(Map<String, Map<String, Double>> x, Map<String, Double> y) {
288                return this.asMatrixPair(IndependentPair.pair(x, y), this.vocabulary.size(), this.users.size(),
289                                this.values.size());
290        }
291
292        /**
293         * @return the current map of dependent values to indexes
294         */
295        public BiMap<String, Integer> getDependantValues() {
296                return this.values;
297        }
298
299        public BiMap<String, Integer> getUsers() {
300                return this.users;
301        }
302}