1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package org.openimaj.ml.benchmark;
31
32 import java.util.Random;
33
34 import org.openimaj.math.matrix.CFMatrixUtils;
35 import org.openimaj.math.matrix.MeanVector;
36 import org.openimaj.time.Timer;
37
38 import no.uib.cipr.matrix.sparse.FlexCompRowMatrix;
39 import gov.sandia.cognition.math.matrix.mtj.SparseColumnMatrix;
40 import gov.sandia.cognition.math.matrix.mtj.SparseMatrix;
41 import gov.sandia.cognition.math.matrix.mtj.SparseMatrixFactoryMTJ;
42 import gov.sandia.cognition.math.matrix.mtj.SparseRowMatrix;
43
44
45
46
47
48 public class CFMatrixMultiplyBenchmark {
49
50 public static void main(String[] args) {
51 SparseMatrix a = SparseMatrixFactoryMTJ.INSTANCE.copyMatrix(SparseMatrixFactoryMTJ.INSTANCE.createWrapper(new FlexCompRowMatrix(4, 1118)));
52 CFMatrixUtils.plusInplace(a, 1);
53 SparseRowMatrix xtrow = CFMatrixUtils.randomSparseRow(1118,22917,0d,1d,1 - 0.9998818947086253, new Random(1));
54 SparseColumnMatrix xtcol = CFMatrixUtils.randomSparseCol(1118,22917,0d,1d,1 - 0.9998818947086253, new Random(1));
55
56 System.out.println("xtrow sparsity: " + CFMatrixUtils.sparsity(xtrow));
57 System.out.println("xtcol sparsity: " + CFMatrixUtils.sparsity(xtcol));
58 System.out.println("Equal: " + CFMatrixUtils.fastsparsedot(a,xtcol).equals(a.times(xtcol), 0));
59 MeanVector mv = new MeanVector();
60 System.out.println("doing: a . xtcol");
61 for (int i = 0; i < 10; i++) {
62 Timer t = Timer.timer();
63 CFMatrixUtils.fastsparsedot(a,xtcol);
64 mv.update(new double[]{t.duration()});
65 System.out.println("time: " + mv.vec()[0]);
66 }
67
68
69 mv.reset();
70 System.out.println("doing: a . xtcol");
71 for (int i = 0; i < 10; i++) {
72 Timer t = Timer.timer();
73 a.times(xtcol);
74 mv.update(new double[]{t.duration()});
75 System.out.println("time: " + mv.vec()[0]);
76 }
77 mv.reset();
78 System.out.println("doing: a . xtrow");
79 for (int i = 0; i < 10; i++) {
80 Timer t = Timer.timer();
81 a.times(xtrow);
82 mv.update(new double[]{t.duration()});
83 System.out.println("time: " + mv.vec()[0]);
84 }
85 }
86 }