001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.math.matrix;
031
032import gov.sandia.cognition.math.matrix.Matrix;
033import gov.sandia.cognition.math.matrix.MatrixEntry;
034import gov.sandia.cognition.math.matrix.mtj.DenseMatrixFactoryMTJ;
035
036import org.openimaj.util.function.Operation;
037import org.openimaj.util.parallel.Parallel;
038
039/**
040 * Perform a multithreaded matrix multiplication
041 * 
042 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
043 * 
044 */
045public class ThreadedMatrixMulti {
046        private double[][] answer;
047        private double[][] a;
048        private double[][] b;
049        private int answerCols;
050        private int answerRows;
051
052        public ThreadedMatrixMulti() {
053
054        }
055
056        public ThreadedMatrixMulti(int numRows, int numCols) {
057                this.newAnswer(numRows, numCols);
058        }
059
060        class MultiplicationOperation implements Operation<Integer> {
061                @Override
062                public void perform(Integer object) {
063                        final int rowi = object / answerCols;
064                        final int coli = object - (rowi * answerCols);
065                        double dot = 0;
066                        for (int i = 0; i < a[rowi].length; i++) {
067                                dot += a[rowi][i] * b[i][coli];
068                        }
069                        ThreadedMatrixMulti.this.setAnswerElement(rowi, coli, dot);
070                }
071        }
072
073        public Matrix times(Matrix a, Matrix b) {
074                final double[][] ad = fromMatrix(a);
075                final double[][] bd = fromMatrix(b);
076                return this.times(ad, bd);
077        }
078
079        public Matrix times(double[][] a, double[][] b) {
080                this.a = a;
081                this.b = b;
082                this.answerCols = b[0].length;
083                this.answerRows = a.length;
084
085                if (this.answer != null) {
086                        if (!(this.answer[0].length == answerCols && this.answer.length == answerRows)) {
087                                this.answer = newAnswer(answerRows, answerCols);
088                        }
089                }
090                else {
091                        this.answer = newAnswer(answerRows, answerCols);
092                }
093
094                Parallel.forIndex(0, answerRows * answerCols, 1, new MultiplicationOperation());
095
096                return DenseMatrixFactoryMTJ.INSTANCE.copyArray(this.answer);
097        }
098
099        private static double[][] fromMatrix(Matrix a) {
100                final double[][] ret = new double[a.getNumRows()][a.getNumColumns()];
101                for (final MatrixEntry ds : a) {
102                        ret[ds.getRowIndex()][ds.getColumnIndex()] = ds.getValue();
103                }
104                return ret;
105        }
106
107        public void setAnswerElement(int rowi, int coli, double ans) {
108                this.answer[rowi][coli] = ans;
109        }
110
111        private double[][] newAnswer(int answerRows, int answerCols) {
112                return new double[answerRows][answerCols];
113        }
114
115        // public static void main(String[] args) {
116        // final int numRows = 1000;
117        // final int numColumns = 2000;
118        // final int repeat = 1;
119        // final DenseMatrix left =
120        // DenseMatrixFactoryMTJ.INSTANCE.createUniformRandom(numRows, numColumns,
121        // 0, 1,
122        // new Random(1));
123        // final DenseMatrix right =
124        // DenseMatrixFactoryMTJ.INSTANCE.createUniformRandom(numColumns, numRows,
125        // 0, 1,
126        // new Random(1));
127        // final double[][] leftd = fromMatrix(left);
128        // final double[][] rightd = fromMatrix(right);
129        // final ThreadedMatrixMulti tmm = new ThreadedMatrixMulti(numRows,
130        // numRows);
131        //
132        // final Timer t = Timer.timer();
133        // double correctDur = 0, threadDur = 0;
134        // for (int i = 0; i < repeat; i++) {
135        // t.start();
136        // final Matrix correct = left.times(right);
137        // correctDur += t.duration();
138        // t.start();
139        // final Matrix thread = tmm.times(leftd, rightd);
140        // threadDur += t.duration();
141        //
142        // }
143        // System.out.println("Correct took: " + correctDur / repeat);
144        // System.out.println("Threaded took: " + threadDur / repeat);
145        // }
146}