Simple Cache Optimization for GEMM

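Function `dgemm_simple_blk` is supposed to compute the matrix product $C \leftarrow \beta C + \alpha A B$ block-wise (compare the notes from the lecture). $A$ is split into blocks $A_{i,l}$ of at most `DGEMM_BLK_M` × `DGEMM_BLK_K` entries. $B$ is split into blocks $B_{l,j}$ of at most `DGEMM_BLK_K` × `DGEMM_BLK_N` entries, and $C$ into the matching blocks $C_{i,j}$. Every block of $C$ is then updated as

$$C_{i,j} \leftarrow \beta\, C_{i,j} + \alpha \sum_{l} A_{i,l}\, B_{l,j}.$$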

Exercise

Implement function dgemm_simple_blk:

#include <float.h>
#include <math.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/times.h>
#include <unistd.h>

/*
 * Fills the m x n matrix A with 1, 2, ..., m*n counted row by row, i.e.
 * A(i,j) = i*n + j + 1.  The values do not depend on the storage layout,
 * which is given only by the strides incRow/incCol.
 */
void
initMatrix(size_t m, size_t n,
     double *A,
     ptrdiff_t incRow, ptrdiff_t incCol)
{
    size_t value = 1;
    for (size_t row=0; row<m; ++row) {
        double *rowStart = &A[row*incRow];
        for (size_t col=0; col<n; ++col) {
            rowStart[col*incCol] = value++;
        }
    }
}

/*
 * Writes the m x n matrix A to stdout: one matrix row per output line, every
 * entry in a 10-character field with 3 decimals, and a blank line before and
 * after the matrix.
 */
void
printMatrix(size_t m, size_t n,
            const double *A,
            ptrdiff_t incRow, ptrdiff_t incCol)
{
    printf("\n");
    for (size_t row=0; row<m; ++row) {
        const double *a = &A[row*incRow];
        for (size_t col=0; col<n; ++col) {
            printf("%10.3lf", a[col*incCol]);
        }
        printf("\n");
    }
    printf("\n");
}

//-- BLAS Level 1 --------------------------------------------------------------

/*
 * Returns the zero-based index of the first entry of the n-vector x whose
 * absolute value is largest.  Returns 0 when n is 0 or 1 (x is not read for
 * n==0).
 */
size_t
idamax(size_t n, const double *x, ptrdiff_t incX)
{
    size_t best    = 0;
    double bestAbs = (n>0) ? fabs(x[0]) : 0;

    for (size_t i=1; i<n; ++i) {
        double candidate = fabs(x[i*incX]);
        if (candidate > bestAbs) {
            best    = i;
            bestAbs = candidate;
        }
    }
    return best;
}


/*
 * Copies the n-vector x into the n-vector y:  y <- x.
 */
void
dcopy(size_t n,
      const double *x, ptrdiff_t incX,
      double *y, ptrdiff_t incY)
{
    for (size_t i=0; i<n; ++i) {
        const double *src = &x[i*incX];
        double       *dst = &y[i*incY];
        *dst = *src;
    }
}

/*
 * Exchanges the contents of the n-vectors x and y element by element.
 */
void
dswap(size_t n,
      double *x, ptrdiff_t incX,
      double *y, ptrdiff_t incY)
{
    for (size_t i=0; i<n; ++i) {
        double *xi    = &x[i*incX];
        double *yi    = &y[i*incY];
        double  saved = *xi;
        *xi = *yi;
        *yi = saved;
    }
}

/*
 * Vector update  y <- y + alpha*x  for n-vectors x and y.
 * When alpha is zero neither x nor y is touched.
 */
void
daxpy(size_t n, double alpha,
      const double *x, ptrdiff_t incX,
      double *y, ptrdiff_t incY)
{
    if (alpha!=0) {
        for (size_t i=0; i<n; ++i) {
            double *yi = &y[i*incY];
            *yi += alpha*x[i*incX];
        }
    }
}

/*
 * Scales the n-vector x in place:  x <- alpha*x.
 * alpha==1 leaves x untouched.  alpha==0 stores zeros instead of
 * multiplying, so NaN/Inf entries of x do not survive.
 */
void
dscal(size_t n, double alpha,
      double *x, ptrdiff_t incX)
{
    if (alpha==1) {
        return;
    }
    int zeroFill = (alpha==0);
    for (size_t i=0; i<n; ++i) {
        double *xi = &x[i*incX];
        if (zeroFill) {
            *xi = 0;
        } else {
            *xi *= alpha;
        }
    }
}

/*
 * Returns the dot product  x^T y  of two n-vectors, accumulated from the
 * first to the last element (0 for n==0).
 */
double
ddot(size_t n,
     const double *x, ptrdiff_t incX,
     const double *y, ptrdiff_t incY)
{
    double acc = 0;
    size_t i   = 0;
    while (i<n) {
        acc += x[i*incX]*y[i*incY];
        ++i;
    }
    return acc;
}

//-- BLAS Level 2 --------------------------------------------------------------

/*
 * Rank-1 update of the m x n matrix A:  A <- A + alpha*x*y^T.
 * Returns immediately when alpha is zero or A is empty.  The update is split
 * into daxpy calls: one per column when incRowA<incColA, one per row
 * otherwise, so daxpy always steps along the smaller of A's strides.
 */
void
dger(size_t m, size_t n, double alpha,
     const double *x, ptrdiff_t incX,
     const double *y, ptrdiff_t incY,
     double *A, ptrdiff_t incRowA, ptrdiff_t incColA)
{
    if (alpha==0 || m==0 || n==0) {
        return;
    }
    if (incRowA>=incColA) {
        // row-wise: A(i,:) += (alpha*x(i)) * y^T
        for (size_t i=0; i<m; ++i) {
            double scale = alpha*x[i*incX];
            daxpy(n, scale, y, incY, &A[i*incRowA], incColA);
        }
        return;
    }
    // column-wise: A(:,j) += (alpha*y(j)) * x
    for (size_t j=0; j<n; ++j) {
        double scale = alpha*y[j*incY];
        daxpy(m, scale, x, incX, &A[j*incColA], incRowA);
    }
}

//-- BLAS Level 3 Auxiliary Functions ------------------------------------------

/*
 * Copies the m x n matrix X into Y:  Y <- X.
 * The copy is done one column at a time when X is column-major
 * (incRowX<incColX), otherwise one row at a time.
 */
void
dgecopy(size_t m, size_t n,
        const double *X, ptrdiff_t incRowX, ptrdiff_t incColX,
        double *Y, ptrdiff_t incRowY, ptrdiff_t incColY)
{
    int byColumn = incRowX<incColX;

    size_t    lines   = byColumn ? n : m;
    size_t    len     = byColumn ? m : n;
    ptrdiff_t stepX   = byColumn ? incColX : incRowX;
    ptrdiff_t stepY   = byColumn ? incColY : incRowY;
    ptrdiff_t strideX = byColumn ? incRowX : incColX;
    ptrdiff_t strideY = byColumn ? incRowY : incColY;

    for (size_t p=0; p<lines; ++p) {
        dcopy(len, &X[p*stepX], strideX, &Y[p*stepY], strideY);
    }
}

/*
 * Scales the m x n matrix X in place:  X <- alpha*X  (with dscal's
 * conventions for alpha==1 and alpha==0).  Works column by column when X is
 * column-major (incRowX<incColX), otherwise row by row.
 */
void
dgescal(size_t m, size_t n,
        double alpha,
        double *X, ptrdiff_t incRowX, ptrdiff_t incColX)
{
    int byColumn = incRowX<incColX;

    size_t    lines  = byColumn ? n : m;
    size_t    len    = byColumn ? m : n;
    ptrdiff_t step   = byColumn ? incColX : incRowX;
    ptrdiff_t stride = byColumn ? incRowX : incColX;

    for (size_t p=0; p<lines; ++p) {
        dscal(len, alpha, &X[p*step], stride);
    }
}

/*
 * Matrix update  Y <- Y + alpha*X  for m x n matrices X and Y.
 * Returns immediately when alpha is zero or the matrices are empty.  Works
 * column by column when X is column-major (incRowX<incColX), otherwise row
 * by row.
 */
void
dgeaxpy(size_t m, size_t n, double alpha,
        const double *X, ptrdiff_t incRowX, ptrdiff_t incColX,
        double *Y, ptrdiff_t incRowY, ptrdiff_t incColY)
{
    if (alpha==0 || m==0 || n==0) {
        return;
    }

    int byColumn = incRowX<incColX;

    size_t    lines   = byColumn ? n : m;
    size_t    len     = byColumn ? m : n;
    ptrdiff_t stepX   = byColumn ? incColX : incRowX;
    ptrdiff_t stepY   = byColumn ? incColY : incRowY;
    ptrdiff_t strideX = byColumn ? incRowX : incColX;
    ptrdiff_t strideY = byColumn ? incRowY : incColY;

    for (size_t p=0; p<lines; ++p) {
        daxpy(len, alpha, &X[p*stepX], strideX, &Y[p*stepY], strideY);
    }
}

//-- BLAS Level 3 --------------------------------------------------------------

/*
 * C <- beta*C + alpha*A*B   (A: m x k, B: k x n, C: m x n).
 * Loop order j-i-l: for every column j and row i of C the complete sum over
 * l is accumulated into C(i,j) before the next entry is visited.
 * C is scaled by beta first (dgescal); A and B are not read if alpha==0.
 */
void
dgemm_jil(size_t m, size_t n, size_t k,
          double alpha,
          const double *A, ptrdiff_t incRowA, ptrdiff_t incColA,
          const double *B, ptrdiff_t incRowB, ptrdiff_t incColB,
          double beta,
          double *C, ptrdiff_t incRowC, ptrdiff_t incColC)
{
    dgescal(m, n, beta, C, incRowC, incColC);
    if (alpha!=0) {
        for (size_t j=0; j<n; ++j) {
            for (size_t i=0; i<m; ++i) {
                double *cij = &C[i*incRowC+j*incColC];
                for (size_t l=0; l<k; ++l) {
                    *cij += alpha*A[i*incRowA+l*incColA]*B[l*incRowB+j*incColB];
                }
            }
        }
    }
}

/*
 * C <- beta*C + alpha*A*B   (A: m x k, B: k x n, C: m x n).
 * Loop order j-l-i: column j of C receives the scaled columns l of A one
 * after the other (B(l,j) fixed in the innermost loop over i).
 * C is scaled by beta first (dgescal); A and B are not read if alpha==0.
 */
void
dgemm_jli(size_t m, size_t n, size_t k,
          double alpha,
          const double *A, ptrdiff_t incRowA, ptrdiff_t incColA,
          const double *B, ptrdiff_t incRowB, ptrdiff_t incColB,
          double beta,
          double *C, ptrdiff_t incRowC, ptrdiff_t incColC)
{
    dgescal(m, n, beta, C, incRowC, incColC);
    if (alpha==0) {
        return;
    }
    for (size_t j=0; j<n; ++j) {
        double *Cj = &C[j*incColC];
        for (size_t l=0; l<k; ++l) {
            const double *Al  = &A[l*incColA];
            const double *blj = &B[l*incRowB+j*incColB];
            for (size_t i=0; i<m; ++i) {
                Cj[i*incRowC] += alpha*Al[i*incRowA]*(*blj);
            }
        }
    }
}

/*
 * C <- beta*C + alpha*A*B   (A: m x k, B: k x n, C: m x n).
 * Loop order i-j-l: each C(i,j) is the dot product of row i of A with
 * column j of B, visited row by row.
 * C is scaled by beta first (dgescal); A and B are not read if alpha==0.
 */
void
dgemm_ijl(size_t m, size_t n, size_t k,
          double alpha,
          const double *A, ptrdiff_t incRowA, ptrdiff_t incColA,
          const double *B, ptrdiff_t incRowB, ptrdiff_t incColB,
          double beta,
          double *C, ptrdiff_t incRowC, ptrdiff_t incColC)
{
    dgescal(m, n, beta, C, incRowC, incColC);
    if (alpha==0) {
        return;
    }
    for (size_t i=0; i<m; ++i) {
        const double *Ai = &A[i*incRowA];
        for (size_t j=0; j<n; ++j) {
            const double *Bj  = &B[j*incColB];
            double       *cij = &C[i*incRowC+j*incColC];
            for (size_t l=0; l<k; ++l) {
                *cij += alpha*Ai[l*incColA]*Bj[l*incRowB];
            }
        }
    }
}

/*
 * C <- beta*C + alpha*A*B   (A: m x k, B: k x n, C: m x n).
 * Loop order i-l-j: row i of C receives the rows l of B scaled by A(i,l)
 * (A(i,l) fixed in the innermost loop over j).
 * C is scaled by beta first (dgescal); A and B are not read if alpha==0.
 */
void
dgemm_ilj(size_t m, size_t n, size_t k,
          double alpha,
          const double *A, ptrdiff_t incRowA, ptrdiff_t incColA,
          const double *B, ptrdiff_t incRowB, ptrdiff_t incColB,
          double beta,
          double *C, ptrdiff_t incRowC, ptrdiff_t incColC)
{
    dgescal(m, n, beta, C, incRowC, incColC);
    if (alpha!=0) {
        for (size_t i=0; i<m; ++i) {
            double *Ci = &C[i*incRowC];
            for (size_t l=0; l<k; ++l) {
                const double *ail = &A[i*incRowA+l*incColA];
                const double *Bl  = &B[l*incRowB];
                for (size_t j=0; j<n; ++j) {
                    Ci[j*incColC] += alpha*(*ail)*Bl[j*incColB];
                }
            }
        }
    }
}

/*
 * C <- beta*C + alpha*A*B   (A: m x k, B: k x n, C: m x n).
 * Loop order l-j-i: A*B is built as a sum of k rank-1 updates
 * (column l of A times row l of B), each applied column by column.
 * C is scaled by beta first (dgescal); A and B are not read if alpha==0.
 */
void
dgemm_lji(size_t m, size_t n, size_t k,
          double alpha,
          const double *A, ptrdiff_t incRowA, ptrdiff_t incColA,
          const double *B, ptrdiff_t incRowB, ptrdiff_t incColB,
          double beta,
          double *C, ptrdiff_t incRowC, ptrdiff_t incColC)
{
    dgescal(m, n, beta, C, incRowC, incColC);
    if (alpha==0) {
        return;
    }
    for (size_t l=0; l<k; ++l) {
        const double *Al = &A[l*incColA];
        const double *Bl = &B[l*incRowB];
        for (size_t j=0; j<n; ++j) {
            double *Cj = &C[j*incColC];
            for (size_t i=0; i<m; ++i) {
                Cj[i*incRowC] += alpha*Al[i*incRowA]*Bl[j*incColB];
            }
        }
    }
}

/*
 * C <- beta*C + alpha*A*B   (A: m x k, B: k x n, C: m x n).
 * Loop order l-i-j: A*B is built as a sum of k rank-1 updates
 * (column l of A times row l of B), each applied row by row.
 * C is scaled by beta first (dgescal); A and B are not read if alpha==0.
 */
void
dgemm_lij(size_t m, size_t n, size_t k,
          double alpha,
          const double *A, ptrdiff_t incRowA, ptrdiff_t incColA,
          const double *B, ptrdiff_t incRowB, ptrdiff_t incColB,
          double beta,
          double *C, ptrdiff_t incRowC, ptrdiff_t incColC)
{
    dgescal(m, n, beta, C, incRowC, incColC);
    if (alpha!=0) {
        for (size_t l=0; l<k; ++l) {
            const double *Al = &A[l*incColA];
            const double *Bl = &B[l*incRowB];
            for (size_t i=0; i<m; ++i) {
                double *Ci = &C[i*incRowC];
                for (size_t j=0; j<n; ++j) {
                    Ci[j*incColC] += alpha*Al[i*incRowA]*Bl[j*incColB];
                }
            }
        }
    }
}

// Block sizes for the blocked GEMM (dgemm_simple_blk).  By their names,
// presumably M = rows of an A/C block, N = columns of a B/C block and
// K = inner (shared) dimension of a block product -- confirm against the
// implementation.  Each is only defined if not already set, so it can be
// overridden at compile time, e.g. -DDGEMM_BLK_K=256.  Values must be > 0.
#ifndef DGEMM_BLK_M
#define DGEMM_BLK_M 16
#endif

#ifndef DGEMM_BLK_N
#define DGEMM_BLK_N 16
#endif

#ifndef DGEMM_BLK_K
#define DGEMM_BLK_K 128
#endif

void
dgemm_simple_blk(size_t m, size_t n, size_t k,
                 double alpha,
                 const double *A, ptrdiff_t incRowA, ptrdiff_t incColA,
                 const double *B, ptrdiff_t incRowB, ptrdiff_t incColB,
                 double beta,
                 double *C, ptrdiff_t incRowC, ptrdiff_t incColC)
{
    // .... YOUR CODE HERE ....
}

//-- Function for benchmarking and testing -------------------------------------

/*
 * Returns the elapsed real time, in seconds, reported by times().
 * Only the difference between two calls is meaningful.  The resolution is
 * one clock tick, i.e. 1/sysconf(_SC_CLK_TCK) seconds.
 */
double
walltime()
{
    static double secondsPerTick = 0.0;
    struct tms    unused;

    if (secondsPerTick==0.0) {
        secondsPerTick = 1.0 / ((double) sysconf(_SC_CLK_TCK));
    }
    clock_t ticks = times(&unused);
    return ((double) ticks) * secondsPerTick;
}

/*
 * Fills the m x n matrix A with pseudo-random values from rand(), mapped
 * linearly from [0, RAND_MAX] to roughly [-100, 100].  Entries are generated
 * column by column, top to bottom.
 */
void
randGeMatrix(size_t m, size_t n, double *A, ptrdiff_t incRowA, ptrdiff_t incColA)
{
    const double halfRange = RAND_MAX/2;    // integer division, on purpose
    for (size_t j=0; j<n; ++j) {
        double *col = &A[j*incColA];
        for (size_t i=0; i<m; ++i) {
            double r = (double)rand();
            col[i*incRowA] = (r - halfRange)*200/RAND_MAX;
        }
    }
}


/*
 * Straightforward reference implementation of
 *     C <- beta*C + alpha*A*B,   A: m x k, B: k x n, C: m x n,
 * used to validate the optimized variants.
 *
 * beta==1 leaves C unscaled.  beta==0 overwrites C with zeros instead of
 * multiplying, as in reference BLAS and in this file's dscal.  So C need not
 * hold valid numbers on input: NaN/Inf entries do not propagate.  A and B
 * are not read when alpha==0.
 */
void
dgemm_ref(size_t m, size_t n, size_t k,
          double alpha,
          const double *A, ptrdiff_t incRowA, ptrdiff_t incColA,
          const double *B, ptrdiff_t incRowB, ptrdiff_t incColB,
          double beta,
          double *C, ptrdiff_t incRowC, ptrdiff_t incColC)
{
    if (beta!=1) {
        for (size_t i=0; i<m; ++i) {
            for (size_t j=0; j<n; ++j) {
                double *cij = &C[i*incRowC+j*incColC];
                // 0*NaN would be NaN: assign zero explicitly for beta==0.
                *cij = (beta==0) ? 0 : *cij*beta;
            }
        }
    }
    if (alpha==0) {
        return;
    }
    for (size_t i=0; i<m; ++i) {
        for (size_t j=0; j<n; ++j) {
            for (size_t l=0; l<k; ++l) {
                C[i*incRowC+j*incColC] += alpha*A[i*incRowA+l*incColA]
                                               *B[l*incRowB+j*incColB];
            }
        }
    }
}

/*
 * Relative residual of an LU factorization stored in packed form.
 *
 * A (m x n) is the original matrix.  With k = min(m,n), LU holds the factors
 * in place:
 *   - the strict lower part of its leading k columns is L, whose unit
 *     diagonal is implied;
 *   - the upper part of its leading k rows is U.
 * P (m entries, stride incP) presumably records the row interchanges done by
 * the factorization; here row i is swapped with row P[i*incP] for
 * i = m-1, ..., 0.
 *
 * LU is OVERWRITTEN: it receives L*U with those row swaps applied.  The
 * function then returns
 *     sum_ij |A(i,j) - LU(i,j)|  /  (k * sum_ij |A(i,j)|).
 * If the denominator is zero (empty or all-zero A), the unnormalized sum is
 * returned instead of NaN/Inf.  Returns NAN if the temporary buffers for L
 * and U cannot be allocated.
 */
double
err_lu(size_t m, size_t n,
       const double *A, ptrdiff_t incRowA, ptrdiff_t incColA,
       double *LU, ptrdiff_t incRowLU, ptrdiff_t incColLU,
       size_t *P, ptrdiff_t incP)
{
    size_t  k = (m<n) ? m : n;
    double *L = malloc(m*k*sizeof *L);      // m x k, column-major
    double *U = malloc(k*n*sizeof *U);      // k x n, column-major

    // malloc(0) may legitimately return NULL; only a NULL for a non-empty
    // buffer is an allocation failure.
    if ((m*k>0 && !L) || (k*n>0 && !U)) {
        free(L);
        free(U);
        return NAN;
    }

    // Unpack the unit lower triangular L and upper triangular U from LU.
    for (size_t j=0; j<k; ++j) {
        for (size_t i=0; i<m; ++i) {
            L[i+j*m] = (i>j)  ? LU[i*incRowLU+j*incColLU]
                     : (i==j) ? 1
                              : 0;
        }
    }
    for (size_t j=0; j<n; ++j) {
        for (size_t i=0; i<k; ++i) {
            U[i+j*k] = (i<=j) ? LU[i*incRowLU+j*incColLU]
                              : 0;
        }
    }

    double aNrm1 = 0;   // sum of |A(i,j)| over all entries
    for (size_t j=0; j<n; ++j) {
        for (size_t i=0; i<m; ++i) {
            aNrm1 += fabs(A[i*incRowA+j*incColA]);
        }
    }

    // LU <- L*U
    dgemm_ref(m, n, k, 1.0,
              L, 1, m,
              U, 1, k,
              0.0,
              LU, incRowLU, incColLU);

    // Apply the recorded row swaps, last one first.
    for (size_t r=m; r>0; --r) {
        size_t p = P[(r-1)*incP];
        if (r-1 != p) {
            dswap(n,
                  &LU[(r-1)*incRowLU], incColLU,
                  &LU[p*incRowLU], incColLU);
        }
    }

    double err = 0;
    for (size_t j=0; j<n; ++j) {
        for (size_t i=0; i<m; ++i) {
            err += fabs(A[i*incRowA+j*incColA]-LU[i*incRowLU+j*incColLU]);
        }
    }

    free(L);
    free(U);

    double denom = aNrm1*k;
    if (denom>0) {
        err /= denom;
    }
    return err;
}

// Minimum / maximum of two values.  These are function-like macros: the
// selected argument is evaluated twice, so never pass expressions with side
// effects (e.g. MAX(i++, j)).
#define MIN(X,Y)   ((X)<(Y) ? (X) : (Y))
#define MAX(X,Y)   ((X)>(Y) ? (X) : (Y))

/*
 * Matrix 1-norm of the m x n matrix A: the largest sum of absolute values
 * over all columns.  Returns 0 for an empty matrix.
 */
double
dgenrm1(size_t m, size_t n, const double *A, ptrdiff_t incRowA, ptrdiff_t incColA)
{
    double maxColSum = 0;
    size_t j         = 0;

    while (j<n) {
        const double *col    = &A[j*incColA];
        double        colSum = 0;
        for (size_t i=0; i<m; ++i) {
            colSum += fabs(col[i*incRowA]);
        }
        if (colSum>maxColSum) {
            maxColSum = colSum;
        }
        ++j;
    }
    return maxColSum;
}

/*
 * Relative error of a computed GEMM result.
 *
 * C0 is the reference solution and C the solution under test.  Both are
 * m x n, each with its own strides.  A, alpha and beta are the operands that
 * produced them; B is passed too.  On return, C is OVERWRITTEN with the
 * difference C - C0.
 *
 * Returns
 *     ||C - C0||_1 / (max(m,n,k) * ||A||_1 * ||B||_1 * ||C0||_1),
 * where ||A||_1 is scaled by |alpha| and ||C0||_1 by |beta| when those
 * factors exceed 1 (dgenrm1 gives the 1-norms).  Returns 0 when the two
 * results agree exactly, even if a norm in the denominator is zero.
 */
double
err_dgemm(size_t m, size_t n, size_t k,
          double alpha,
          const double *A, ptrdiff_t incRowA, ptrdiff_t incColA,
          const double *B, ptrdiff_t incRowB, ptrdiff_t incColB,
          double beta,
          const double *C0, ptrdiff_t incRowC0, ptrdiff_t incColC0,
          double *C, ptrdiff_t incRowC, ptrdiff_t incColC)
{
    double normA = dgenrm1(m, k, A, incRowA, incColA);
    double normB = dgenrm1(k, n, B, incRowB, incColB);
    // Normalize with the reference solution, read with its own strides.
    // Using the tested result would let a wildly wrong (huge) C inflate the
    // denominator and hide its own error.
    double normC = dgenrm1(m, n, C0, incRowC0, incColC0);
    size_t mn    = (m>n)  ? m  : n;
    size_t mnk   = (mn>k) ? mn : k;

    normA = MAX(normA, fabs(alpha)*normA);
    normC = MAX(normC, fabs(beta)*normC);

    dgeaxpy(m, n, -1.0, C0, incRowC0, incColC0, C, incRowC, incColC);
    double normD = dgenrm1(m, n, C, incRowC, incColC);

    // Exact agreement: avoid 0/0 when a norm in the denominator is zero.
    if (normD==0) {
        return 0;
    }
    return normD/(mnk*normA*normB*normC);
}

//------------------------------------------------------------------------------

// Benchmark configuration.  Every setting is only defined if not already
// set, so it can be overridden at compile time with -D<NAME>=<value>.
//
// Problem sizes: m, n and k start at MIN_M/MIN_N/MIN_K and all three grow
// together by INC_M/INC_N/INC_K per step.  The benchmark stops as soon as any
// of them exceeds its MAX_*.
#ifndef MIN_N
#define MIN_N 100
#endif

#ifndef MAX_N
#define MAX_N 4000
#endif

#ifndef INC_N
#define INC_N 100
#endif

#ifndef MIN_M
#define MIN_M 100
#endif

#ifndef MAX_M
#define MAX_M 4000
#endif

#ifndef INC_M
#define INC_M 100
#endif

#ifndef MIN_K
#define MIN_K 100
#endif

#ifndef MAX_K
#define MAX_K 4000
#endif

#ifndef INC_K
#define INC_K 100
#endif

// Scalars alpha and beta passed to every dgemm call.
#ifndef ALPHA
#define ALPHA 1
#endif

#ifndef BETA
#define BETA 1
#endif

// Storage order of the test matrices: 1 selects row-major, any other value
// column-major.
#ifndef ROWMAJOR_A
#define ROWMAJOR_A 0
#endif

#ifndef ROWMAJOR_B
#define ROWMAJOR_B 0
#endif

#ifndef ROWMAJOR_C
#define ROWMAJOR_C 0
#endif

// Static buffers sized for the largest problem; smaller problems use a prefix.
// A_, B_, C_ hold the random input data.  With the default MAX_* of 4000,
// each array holds 16e6 doubles (128 MB), so all five together need ~640 MB.
double A_[MAX_M*MAX_K];
double B_[MAX_K*MAX_N];
double C_[MAX_M*MAX_N];

double C0[MAX_M*MAX_N];     // reference solution
double C1[MAX_M*MAX_N];     // tested solution

int
main()
{
    randGeMatrix(MAX_M, MAX_K, A_, 1, MAX_M);
    randGeMatrix(MAX_K, MAX_N, B_, 1, MAX_K);
    randGeMatrix(MAX_N, MAX_M, C_, 1, MAX_M);

    printf("#%9s %9s %9s", "m", "n", "k");
    printf(" %12s %12s %17s", "t", "MFLOPS", "Residual Error");
    printf(" %12s %12s %17s", "t", "MFLOPS", "Residual Error");
    printf("\n");

    for (size_t m=MIN_M, n=MIN_N, k=MIN_K; n<=MAX_N && m<=MAX_M && k<=MAX_K;
         m+=INC_M, n+=INC_N, k+=INC_K)
    {
        double t, dt, err;
        size_t runs  = 1;
        double ops   = 2.0*m/1000*n/1000*k;

        ptrdiff_t incRowA = (ROWMAJOR_A==1) ? k : 1;
        ptrdiff_t incColA = (ROWMAJOR_A==1) ? 1 : m;

        ptrdiff_t incRowB = (ROWMAJOR_B==1) ? n : 1;
        ptrdiff_t incColB = (ROWMAJOR_B==1) ? 1 : k;

        ptrdiff_t incRowC = (ROWMAJOR_C==1) ? n : 1;
        ptrdiff_t incColC = (ROWMAJOR_C==1) ? 1 : m;

        printf(" %9zu %9zu %9td", m, n, k);

        // compute reference solution
        dgecopy(m, n, C_, 1, MAX_M, C0, incRowC, incColC);
        dgemm_ref(m, n, k,
                  ALPHA,
                  A_, incRowA, incRowA,
                  B_, incRowB, incColB,
                  BETA,
                  C0, incRowC, incColC);

        // benchmark dgemm_jli
        t    = 0;
        runs = 0;
        do {
            dgecopy(m, n, C_, 1, MAX_M, C1, incRowC, incColC);
            dt = walltime();
            dgemm_jli(m, n, k,
                      ALPHA,
                      A_, incRowA, incRowA,
                      B_, incRowB, incColB,
                      BETA,
                      C1, incRowC, incColC);
            dt = walltime() - dt;
            t += dt;
            ++runs;
        } while (t<0.3);
        t /= runs;

        err = err_dgemm(m, n, k,
                        ALPHA,
                        A_, incRowA, incColA,
                        B_, incRowB, incColB,
                        BETA,
                        C0, incRowC, incColC,
                        C1, incRowC, incColC);

        printf(" %12.2e %12.2lf %12.2e %4s", t, ops/t,
               err, (err<DBL_EPSILON) ? "PASS" : "FAIL");

        // benchmark dgemm_simple_blk
        t    = 0;
        runs = 0;
        do {
            dgecopy(m, n, C_, 1, MAX_M, C1, incRowC, incColC);
            dt = walltime();
            dgemm_simple_blk(m, n, k,
                             ALPHA,
                             A_, incRowA, incRowA,
                             B_, incRowB, incColB,
                             BETA,
                             C1, incRowC, incColC);
            dt = walltime() - dt;
            t += dt;
            ++runs;
        } while (t<0.3);
        t /= runs;

        err = err_dgemm(m, n, k,
                        ALPHA,
                        A_, incRowA, incColA,
                        B_, incRowB, incColB,
                        BETA,
                        C0, incRowC, incColC,
                        C1, incRowC, incColC);

        printf(" %12.2e %12.2lf %12.2e %4s", t, ops/t,
               err, (err<DBL_EPSILON) ? "PASS" : "FAIL");

        printf("\n");
    }

    return 0;
}