/****************************************************************************** * Compilation: javac MatrixMultiplication.java * Execution: java MatrixMultiplication * * 8 different ways to multiply two dense n-by-n matrices. * Illustrates importance of row-major vs. column-major ordering. * * % java MatrixMultiplication 512 * Generating input: 0.043 seconds * Order ijk: 0.45 seconds * Order ikj: 0.096 seconds * Order jik: 0.303 seconds * Order jki: 0.65 seconds * Order kij: 0.102 seconds * Order kji: 0.646 seconds * Order jik JAMA optimized: 0.139 seconds * Order ikj pure row: 0.041 seconds * * These timings are on a Fujitsu CX2570 M2 server with dual, * 14-core 2.4GHz Intel Xeon E5 2680 v4 processors with 384GB RAM * running the Springdale distribution of Linux. * ******************************************************************************/ public class MatrixMultiplication { public static void show(double[][] a) { int n = a.length; for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { StdOut.printf("%6.4f ", a[i][j]); } StdOut.println(); } StdOut.println(); } public static void main(String[] args) { int n = Integer.parseInt(args[0]); long start, stop; double elapsed; // generate input start = System.currentTimeMillis(); double[][] A = new double[n][n]; double[][] B = new double[n][n]; double[][] C; for (int i = 0; i < n; i++) for (int j = 0; j < n; j++) A[i][j] = StdRandom.uniformDouble(0.0, 1.0); for (int i = 0; i < n; i++) for (int j = 0; j < n; j++) B[i][j] = StdRandom.uniformDouble(0.0, 1.0); stop = System.currentTimeMillis(); elapsed = (stop - start) / 1000.0; StdOut.println("Generating input: " + elapsed + " seconds"); // order 1: ijk = dot product version C = new double[n][n]; start = System.currentTimeMillis(); for (int i = 0; i < n; i++) for (int j = 0; j < n; j++) for (int k = 0; k < n; k++) C[i][j] += A[i][k] * B[k][j]; stop = System.currentTimeMillis(); elapsed = (stop - start) / 1000.0; StdOut.println("Order ijk: " + elapsed + " seconds"); if (n < 10) show(C); // order 2: ikj C = new double[n][n]; start = System.currentTimeMillis(); for (int i = 0; i < n; i++) for (int k = 0; k < n; k++) for (int j = 0; j < n; j++) C[i][j] += A[i][k] * B[k][j]; stop = System.currentTimeMillis(); elapsed = (stop - start) / 1000.0; StdOut.println("Order ikj: " + elapsed + " seconds"); if (n < 10) show(C); // order 3: jik C = new double[n][n]; start = System.currentTimeMillis(); for (int j = 0; j < n; j++) for (int i = 0; i < n; i++) for (int k = 0; k < n; k++) C[i][j] += A[i][k] * B[k][j]; stop = System.currentTimeMillis(); elapsed = (stop - start) / 1000.0; StdOut.println("Order jik: " + elapsed + " seconds"); if (n < 10) show(C); // order 4: jki = GAXPY version C = new double[n][n]; start = System.currentTimeMillis(); for (int j = 0; j < n; j++) for (int k = 0; k < n; k++) for (int i = 0; i < n; i++) C[i][j] += A[i][k] * B[k][j]; stop = System.currentTimeMillis(); elapsed = (stop - start) / 1000.0; StdOut.println("Order jki: " + elapsed + " seconds"); if (n < 10) show(C); // order 5: kij C = new double[n][n]; start = System.currentTimeMillis(); for (int k = 0; k < n; k++) for (int i = 0; i < n; i++) for (int j = 0; j < n; j++) C[i][j] += A[i][k] * B[k][j]; stop = System.currentTimeMillis(); elapsed = (stop - start) / 1000.0; StdOut.println("Order kij: " + elapsed + " seconds"); if (n < 10) show(C); // order 6: kji = outer product version C = new double[n][n]; start = System.currentTimeMillis(); for (int k = 0; k < n; k++) for (int j = 0; j < n; j++) for (int i = 0; i < n; i++) C[i][j] += A[i][k] * B[k][j]; stop = System.currentTimeMillis(); elapsed = (stop - start) / 1000.0; StdOut.println("Order kji: " + elapsed + " seconds"); if (n < 10) show(C); // order 7: jik optimized ala JAMA C = new double[n][n]; start = System.currentTimeMillis(); double[] bcolj = new double[n]; for (int j = 0; j < n; j++) { for (int k = 0; k < n; k++) bcolj[k] = B[k][j]; for (int i = 0; i < n; i++) { double[] arowi = A[i]; double sum = 0.0; for (int k = 0; k < n; k++) { sum += arowi[k] * bcolj[k]; } C[i][j] = sum; } } stop = System.currentTimeMillis(); elapsed = (stop - start) / 1000.0; StdOut.println("Order jik JAMA optimized: " + elapsed + " seconds"); if (n < 10) show(C); // order 8: ikj pure row C = new double[n][n]; start = System.currentTimeMillis(); for (int i = 0; i < n; i++) { double[] arowi = A[i]; double[] crowi = C[i]; for (int k = 0; k < n; k++) { double[] browk = B[k]; double aik = arowi[k]; for (int j = 0; j < n; j++) { crowi[j] += aik * browk[j]; } } } stop = System.currentTimeMillis(); elapsed = (stop - start) / 1000.0; StdOut.println("Order ikj pure row: " + elapsed + " seconds"); if (n < 10) show(C); } }