MatrixMultiplication.java


Below is the syntax highlighted version of MatrixMultiplication.java from §9.5 Numerical Solutions to Differential Equations.


/******************************************************************************
 *  Compilation:  javac MatrixMultiplication.java
 *  Execution:    java MatrixMultiplication n
 *
 *  8 different ways to multiply two dense n-by-n matrices.
 *  Illustrates importance of row-major vs. column-major ordering.
 *
 *  % java MatrixMultiplication 512
 *  Generating input:  0.043 seconds
 *  Order ijk:   0.45 seconds
 *  Order ikj:   0.096 seconds
 *  Order jik:   0.303 seconds
 *  Order jki:   0.65 seconds
 *  Order kij:   0.102 seconds
 *  Order kji:   0.646 seconds
 *  Order jik JAMA optimized:   0.139 seconds
 *  Order ikj pure row:   0.041 seconds
 *
 *  These timings are on a Fujitsu CX2570 M2 server with dual,
 *  14-core 2.4GHz Intel Xeon E5 2680 v4 processors with 384GB RAM
 *  running the Springdale distribution of Linux.
 *
 ******************************************************************************/

/**
 * Times eight loop orderings for multiplying two dense n-by-n matrices
 * whose entries are drawn from {@code StdRandom.uniformDouble(0.0, 1.0)},
 * printing the elapsed time of each ordering to standard output.
 */
public class MatrixMultiplication {

    /**
     * Prints a matrix to standard output, one row per line, each entry
     * formatted with {@code %6.4f} and followed by a space; a blank line
     * is printed after the last row. Every row is printed to its own
     * length, so the matrix does not have to be square.
     *
     * @param a the matrix to print
     */
    public static void show(double[][] a) {
        for (double[] row : a) {
            for (double entry : row) {
                StdOut.printf("%6.4f ", entry);
            }
            StdOut.println();
        }
        StdOut.println();
    }

    // Prints label followed by the time elapsed since startNanos, in seconds.
    private static void report(String label, long startNanos) {
        // System.nanoTime() is monotonic, unlike the wall clock; truncate to
        // whole milliseconds so the output keeps millisecond resolution.
        long elapsedMillis = (System.nanoTime() - startNanos) / 1000000L;
        StdOut.println(label + (elapsedMillis / 1000.0) + " seconds");
    }

    /**
     * Generates two random n-by-n matrices and multiplies them with eight
     * different loop orderings, printing the time taken by each. When
     * {@code n < 10} the product is also printed after every ordering.
     *
     * @param args the command-line arguments; {@code args[0]} is the matrix
     *             dimension n (any further arguments are ignored)
     * @throws IllegalArgumentException if no argument is given or n is negative
     * @throws NumberFormatException if {@code args[0]} is not an integer
     */
    public static void main(String[] args) {
        if (args.length < 1) {
            throw new IllegalArgumentException("usage: java MatrixMultiplication n");
        }
        int n = Integer.parseInt(args[0]);
        if (n < 0) {
            throw new IllegalArgumentException("n must be nonnegative: " + n);
        }
        long start;


        // generate input
        start = System.nanoTime();

        double[][] a = new double[n][n];
        double[][] b = new double[n][n];
        double[][] c;

        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                a[i][j] = StdRandom.uniformDouble(0.0, 1.0);

        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                b[i][j] = StdRandom.uniformDouble(0.0, 1.0);

        report("Generating input:  ", start);

        // order 1: ijk = dot product version
        c = new double[n][n];
        start = System.nanoTime();
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                for (int k = 0; k < n; k++)
                    c[i][j] += a[i][k] * b[k][j];
        report("Order ijk:   ", start);
        if (n < 10) show(c);

        // order 2: ikj
        c = new double[n][n];
        start = System.nanoTime();
        for (int i = 0; i < n; i++)
            for (int k = 0; k < n; k++)
                for (int j = 0; j < n; j++)
                    c[i][j] += a[i][k] * b[k][j];
        report("Order ikj:   ", start);
        if (n < 10) show(c);

        // order 3: jik
        c = new double[n][n];
        start = System.nanoTime();
        for (int j = 0; j < n; j++)
            for (int i = 0; i < n; i++)
                for (int k = 0; k < n; k++)
                    c[i][j] += a[i][k] * b[k][j];
        report("Order jik:   ", start);
        if (n < 10) show(c);

        // order 4: jki = GAXPY version
        c = new double[n][n];
        start = System.nanoTime();
        for (int j = 0; j < n; j++)
            for (int k = 0; k < n; k++)
                for (int i = 0; i < n; i++)
                    c[i][j] += a[i][k] * b[k][j];
        report("Order jki:   ", start);
        if (n < 10) show(c);

        // order 5: kij
        c = new double[n][n];
        start = System.nanoTime();
        for (int k = 0; k < n; k++)
            for (int i = 0; i < n; i++)
                for (int j = 0; j < n; j++)
                    c[i][j] += a[i][k] * b[k][j];
        report("Order kij:   ", start);
        if (n < 10) show(c);

        // order 6: kji = outer product version
        c = new double[n][n];
        start = System.nanoTime();
        for (int k = 0; k < n; k++)
            for (int j = 0; j < n; j++)
                for (int i = 0; i < n; i++)
                    c[i][j] += a[i][k] * b[k][j];
        report("Order kji:   ", start);
        if (n < 10) show(c);


        // order 7: jik optimized ala JAMA
        c = new double[n][n];
        start = System.nanoTime();
        // copy column j of b into a 1D array so the inner loop reads two
        // plain arrays instead of indexing b[k][j] on every iteration
        double[] bcolj = new double[n];
        for (int j = 0; j < n; j++) {
            for (int k = 0; k < n; k++) bcolj[k] = b[k][j];
            for (int i = 0; i < n; i++) {
                double[] arowi = a[i];
                double sum = 0.0;
                for (int k = 0; k < n; k++) {
                    sum += arowi[k] * bcolj[k];
                }
                c[i][j] = sum;
            }
        }
        report("Order jik JAMA optimized:   ", start);
        if (n < 10) show(c);

        // order 8: ikj pure row
        c = new double[n][n];
        start = System.nanoTime();
        for (int i = 0; i < n; i++) {
            // hoist the row references and a[i][k] out of the inner loop
            double[] arowi = a[i];
            double[] crowi = c[i];
            for (int k = 0; k < n; k++) {
                double[] browk = b[k];
                double aik = arowi[k];
                for (int j = 0; j < n; j++) {
                    crowi[j] += aik * browk[j];
                }
            }
        }
        report("Order ikj pure row:   ", start);
        if (n < 10) show(c);

    }

}


Copyright © 2000–2022, Robert Sedgewick and Kevin Wayne.
Last updated: Thu Aug 11 10:36:03 EDT 2022.