/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   *  Redistributions of source code must retain the above copyright notice,
 *      this list of conditions and the following disclaimer.
 *
 *   *  Redistributions in binary form must reproduce the above copyright notice,
 *      this list of conditions and the following disclaimer in the documentation
 *      and/or other materials provided with the distribution.
 *
 *   *  Neither the name of the University of Southampton nor the names of its
 *      contributors may be used to endorse or promote products derived from this
 *      software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package org.openimaj.math.matrix;

import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixEntry;
import gov.sandia.cognition.math.matrix.mtj.DenseMatrixFactoryMTJ;

import org.openimaj.util.function.Operation;
import org.openimaj.util.parallel.Parallel;

/**
 * Perform a multithreaded matrix multiplication.
 * <p>
 * Every element of the product is computed as an independent dot product, and
 * the flat range of element indices is handed to
 * {@link Parallel#forIndex(int, int, int, Operation)}. The output buffer is
 * kept between calls and reused whenever the next product has the same shape.
 * Each call still returns a fresh copy, so a returned matrix is never
 * overwritten by a later call.
 * <p>
 * The operands of the multiplication in progress are stored in instance
 * fields. A single instance must therefore not be used by several threads
 * calling {@code times} concurrently.
 *
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 */
public class ThreadedMatrixMulti {
	/** Output buffer; reused across calls while the product shape is unchanged. */
	private double[][] answer;
	/** Left-hand operand of the multiplication in progress; null between calls. */
	private double[][] a;
	/** Right-hand operand of the multiplication in progress; null between calls. */
	private double[][] b;
	/** Number of columns of the product currently being computed. */
	private int answerCols;
	/** Number of rows of the product currently being computed. */
	private int answerRows;

	/**
	 * Construct a multiplier without a pre-allocated output buffer. A buffer is
	 * allocated by the first call to {@code times}.
	 */
	public ThreadedMatrixMulti() {

	}

	/**
	 * Construct a multiplier with a pre-allocated output buffer. The buffer is
	 * reused by {@code times} if the product has exactly this shape, and
	 * replaced otherwise.
	 *
	 * @param numRows
	 *            number of rows of the expected product
	 * @param numCols
	 *            number of columns of the expected product
	 */
	public ThreadedMatrixMulti(int numRows, int numCols) {
		// The allocation must be stored; previously it was discarded, which made
		// this constructor equivalent to the no-argument one.
		this.answer = this.newAnswer(numRows, numCols);
	}

	/**
	 * Computes one element of the product. The flat index passed to
	 * {@link #perform(Integer)} is decoded in row-major order, i.e.
	 * {@code index = row * answerCols + col}.
	 */
	class MultiplicationOperation implements Operation<Integer> {
		@Override
		public void perform(Integer object) {
			final int rowi = object / answerCols;
			final int coli = object - (rowi * answerCols);
			final double[] row = a[rowi];
			double dot = 0;
			for (int i = 0; i < row.length; i++) {
				dot += row[i] * b[i][coli];
			}
			ThreadedMatrixMulti.this.setAnswerElement(rowi, coli, dot);
		}
	}

	/**
	 * Multiply two matrices, {@code a * b}.
	 *
	 * @param a
	 *            the left-hand matrix
	 * @param b
	 *            the right-hand matrix
	 * @return a new matrix holding the product
	 * @throws IllegalArgumentException
	 *             if {@code b} has no rows, if the number of columns of
	 *             {@code a} differs from the number of rows of {@code b}, or if
	 *             the product has more than {@link Integer#MAX_VALUE} elements
	 */
	public Matrix times(Matrix a, Matrix b) {
		final double[][] ad = fromMatrix(a);
		final double[][] bd = fromMatrix(b);
		return this.times(ad, bd);
	}

	/**
	 * Multiply two matrices given as row-major arrays, {@code a * b}.
	 *
	 * @param a
	 *            the left-hand matrix; every row must have {@code b.length}
	 *            entries
	 * @param b
	 *            the right-hand matrix; must have at least one row and all rows
	 *            must have the same length
	 * @return a new matrix holding the product
	 * @throws IllegalArgumentException
	 *             if {@code b} is empty or jagged, if a row of {@code a} does
	 *             not have {@code b.length} entries, or if the product has more
	 *             than {@link Integer#MAX_VALUE} elements
	 */
	public Matrix times(double[][] a, double[][] b) {
		// Validate before touching any field so a rejected call leaves the
		// instance unchanged.
		final int cols = checkOperands(a, b);

		this.a = a;
		this.b = b;
		this.answerRows = a.length;
		this.answerCols = cols;

		if (!this.answerHasShape(this.answerRows, this.answerCols)) {
			this.answer = this.newAnswer(this.answerRows, this.answerCols);
		}

		try {
			Parallel.forIndex(0, this.answerRows * this.answerCols, 1, new MultiplicationOperation());
		} finally {
			// The operands are only needed while the product is computed.
			// Dropping them stops this instance from pinning large arrays in
			// memory. This relies on forIndex having finished all work when it
			// returns, which the copy of the result below relies on as well.
			this.a = null;
			this.b = null;
		}

		return DenseMatrixFactoryMTJ.INSTANCE.copyArray(this.answer);
	}

	/**
	 * Check that {@code a * b} is well defined and indexable.
	 *
	 * @return the number of columns of the product
	 * @throws IllegalArgumentException
	 *             if the operands are unsuitable (see
	 *             {@link #times(double[][], double[][])})
	 */
	private static int checkOperands(double[][] a, double[][] b) {
		if (b.length == 0) {
			throw new IllegalArgumentException("The right-hand matrix must have at least one row");
		}
		final int inner = b.length;
		final int cols = b[0].length;
		for (int r = 1; r < inner; r++) {
			if (b[r].length != cols) {
				throw new IllegalArgumentException("The right-hand matrix is not rectangular: row " + r + " has "
						+ b[r].length + " columns, expected " + cols);
			}
		}
		for (int r = 0; r < a.length; r++) {
			if (a[r].length != inner) {
				throw new IllegalArgumentException("Inner dimensions do not agree: row " + r
						+ " of the left-hand matrix has " + a[r].length + " columns but the right-hand matrix has "
						+ inner + " rows");
			}
		}
		// The elements are addressed by a single int index, so rows * cols must
		// not overflow; otherwise the index range would silently be wrong.
		if ((long) a.length * cols > Integer.MAX_VALUE) {
			throw new IllegalArgumentException("The product has " + a.length + " x " + cols
					+ " elements, more than can be indexed (" + Integer.MAX_VALUE + ")");
		}
		return cols;
	}

	/**
	 * @return true if the current output buffer exists and has exactly the
	 *         given shape, so that it can be reused
	 */
	private boolean answerHasShape(int rows, int cols) {
		if (this.answer == null || this.answer.length != rows) {
			return false;
		}
		// A buffer with no rows has no row to inspect; its width is irrelevant.
		return rows == 0 || this.answer[0].length == cols;
	}

	/**
	 * Copy a {@link Matrix} into a new row-major array. Entries not produced by
	 * the matrix's iterator stay 0.
	 */
	private static double[][] fromMatrix(Matrix a) {
		final double[][] ret = new double[a.getNumRows()][a.getNumColumns()];
		for (final MatrixEntry ds : a) {
			ret[ds.getRowIndex()][ds.getColumnIndex()] = ds.getValue();
		}
		return ret;
	}

	/**
	 * Store one element of the product in the output buffer.
	 *
	 * @param rowi
	 *            row index
	 * @param coli
	 *            column index
	 * @param ans
	 *            the value to store
	 */
	public void setAnswerElement(int rowi, int coli, double ans) {
		this.answer[rowi][coli] = ans;
	}

	/**
	 * @return a newly allocated, zero-filled {@code answerRows x answerCols}
	 *         array
	 */
	private double[][] newAnswer(int answerRows, int answerCols) {
		return new double[answerRows][answerCols];
	}
}