// Blocked GEMM with AVX Micro Kernel

#include <float.h>
#include <math.h>
#include <stddef.h>
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <sys/times.h>
#include <unistd.h>

//
//  Fill the m x n matrix A with the consecutive values 1, 2, 3, ...
//  Values are assigned row by row, i.e. element (i,j) receives i*n + j + 1.
//  Element (i,j) is stored at A[i*incRowA + j*incColA].
//
void
initGeMatrix(size_t m, size_t n,
             double *A,
             ptrdiff_t incRowA, ptrdiff_t incColA)
{
    size_t counter = 0;

    for (size_t row=0; row<m; ++row) {
        for (size_t col=0; col<n; ++col) {
            // counter runs through 1..m*n in row-wise order
            ++counter;
            A[row*incRowA+col*incColA] = counter;
        }
    }
}

//
//  Print the m x n matrix A to stdout, one matrix row per output line,
//  followed by an empty line.  Element (i,j) is read from
//  A[i*incRowA + j*incColA].
//
void
printGeMatrix(size_t m, size_t n,
              const double *A,
              ptrdiff_t incRowA, ptrdiff_t incColA)
{
    size_t row = 0;

    while (row<m) {
        for (size_t col=0; col<n; ++col) {
            const double value = A[row*incRowA+col*incColA];
            printf("%11.4lf ", value);
        }
        putchar('\n');
        ++row;
    }
    putchar('\n');
}

//
//  Reference implementation of  C <- beta*C + alpha*A*B  using plain loops.
//
//  A is m x k, B is k x n and C is m x n.  Element (i,j) of a matrix X is
//  located at X[i*incRowX + j*incColX].
//
//  For beta==0 the matrix C is overwritten with zeros instead of being
//  scaled, so its previous content (including NaN/Inf) is discarded.
//  For alpha==0 the product is skipped entirely.
//
void
dgemm_ref(size_t m, size_t n, size_t k,
          double alpha,
          const double *A, ptrdiff_t incRowA, ptrdiff_t incColA,
          const double *B, ptrdiff_t incRowB, ptrdiff_t incColB,
          double beta,
          double *C, ptrdiff_t incRowC, ptrdiff_t incColC)
{
    // Step 1:  C <- beta*C
    if (beta==0) {
        for (size_t row=0; row<m; ++row) {
            for (size_t col=0; col<n; ++col) {
                C[row*incRowC+col*incColC] = 0;
            }
        }
    } else if (beta!=1) {
        for (size_t row=0; row<m; ++row) {
            for (size_t col=0; col<n; ++col) {
                C[row*incRowC+col*incColC] *= beta;
            }
        }
    }

    // Step 2:  C <- C + alpha*A*B
    if (alpha==0) {
        return;
    }
    for (size_t row=0; row<m; ++row) {
        for (size_t col=0; col<n; ++col) {
            double *c = &C[row*incRowC+col*incColC];

            for (size_t p=0; p<k; ++p) {
                *c += alpha*A[row*incRowA+p*incColA]
                           *B[p*incRowB+col*incColB];
            }
        }
    }
}


//
//  Blocking parameters.  Each one can be overridden on the compiler command
//  line (e.g. -DDGEMM_MC=128).
//
//  DGEMM_MR x DGEMM_NR is the size of the tile of C that is updated by a
//  single call of the micro kernel.  The assembly micro kernel is only
//  compiled for MR==4 and NR==8; any other combination selects the portable
//  C micro kernel.
//

// Number of rows of a micro tile (height of a packed panel of A).
#ifndef DGEMM_MR
#define DGEMM_MR    4
#endif

// Number of columns of a micro tile (width of a packed panel of B).
#ifndef DGEMM_NR
#define DGEMM_NR    8
#endif

//
//  DGEMM_MC, DGEMM_NC and DGEMM_KC are the block sizes used by dgemm_frame
//  for the m, n and k dimension respectively.  Choosing MC as a multiple of
//  MR and NC as a multiple of NR avoids zero-padded panels at the boundary
//  between two blocks.
//

// Maximal number of rows of a packed block of A.
#ifndef DGEMM_MC
#define DGEMM_MC    256
#endif

// Maximal number of columns of a packed block of B.
#ifndef DGEMM_NC
#define DGEMM_NC    256
#endif

// Maximal inner dimension (columns of A / rows of B) of a packed block.
#ifndef DGEMM_KC
#define DGEMM_KC    512
#endif

//
//  Allocate 'size' bytes whose start address is a multiple of 'alignment'.
//
//  alignment must be a power of two.  Values smaller than sizeof(void *) are
//  raised to sizeof(void *) so that the bookkeeping slot in front of the
//  returned block is itself properly aligned.
//
//  The pointer obtained from malloc() is stored in the void* slot directly
//  in front of the returned address; free_aligned() reads it back from
//  there.  The returned block must therefore be released with free_aligned()
//  and never with free().
//
//  Returns NULL if alignment is not a power of two, if the padded size would
//  overflow size_t, or if malloc() fails.
//
void *malloc_aligned(size_t size, size_t alignment)
{
    if (alignment==0 || (alignment & (alignment-1))!=0) {
        return NULL;
    }
    if (alignment<sizeof(void *)) {
        alignment = sizeof(void *);
    }

    // Room for the bookkeeping slot plus the worst-case alignment shift.
    const size_t overhead = alignment + sizeof(void *);

    if (size>SIZE_MAX-overhead) {
        return NULL;
    }

    void *raw = malloc(size + overhead);

    if (raw==NULL) {
        return NULL;
    }

    // First multiple of 'alignment' that leaves space for the slot.
    uintptr_t start = (uintptr_t)raw + sizeof(void *) + alignment - 1;
    start &= ~(uintptr_t)(alignment-1);

    void **slot = (void **)start - 1;
    *slot       = raw;
    return (void *)start;
}

//
//  Release a block obtained from malloc_aligned().
//
//  The pointer originally returned by malloc() is read from the void* slot
//  directly in front of 'ptr' and passed to free().  Like free(), a NULL
//  argument is accepted and ignored.
//
void
free_aligned(void *ptr)
{
    if (ptr==NULL) {
        return;
    }
    free(*((void**)ptr-1));
}

//
//  Pack the m x k matrix A into the buffer p as a sequence of row panels.
//
//  Each panel holds DGEMM_MR consecutive rows of A and is stored column by
//  column:  element (i0, l) of panel i1 ends up at
//  p[i1*DGEMM_MR*k + l*DGEMM_MR + i0].  If m is not a multiple of DGEMM_MR
//  the missing rows of the last panel are filled with zeros, so p must hold
//  ceil(m/DGEMM_MR)*DGEMM_MR*k elements.
//
void
dgepack_A(size_t m, size_t k,
          const double *A, ptrdiff_t incRowA, ptrdiff_t incColA,
          double *p)
{
    const size_t numPanels = (m+DGEMM_MR-1)/DGEMM_MR;

    for (size_t col=0; col<k; ++col) {
        for (size_t panel=0; panel<numPanels; ++panel) {
            // destination of column 'col' inside panel 'panel'
            double *dst = &p[panel*DGEMM_MR*k + col*DGEMM_MR];
            size_t  row = panel*DGEMM_MR;

            for (size_t r=0; r<DGEMM_MR; ++r, ++row) {
                if (row<m) {
                    dst[r] = A[row*incRowA + col*incColA];
                } else {
                    dst[r] = 0;     // zero padding below the last row
                }
            }
        }
    }
}

//
//  Pack the k x n matrix B into the buffer p as a sequence of column panels.
//
//  Each panel holds DGEMM_NR consecutive columns of B and is stored row by
//  row:  element (l, j0) of panel j1 ends up at
//  p[j1*DGEMM_NR*k + l*DGEMM_NR + j0].  If n is not a multiple of DGEMM_NR
//  the missing columns of the last panel are filled with zeros, so p must
//  hold ceil(n/DGEMM_NR)*DGEMM_NR*k elements.
//
void
dgepack_B(size_t k, size_t n,
          const double *B, ptrdiff_t incRowB, ptrdiff_t incColB,
          double *p)
{
    const size_t numPanels = (n+DGEMM_NR-1)/DGEMM_NR;

    for (size_t panel=0; panel<numPanels; ++panel) {
        for (size_t lane=0; lane<DGEMM_NR; ++lane) {
            const size_t col = panel*DGEMM_NR + lane;
            // first slot of column 'lane' inside the panel; stride DGEMM_NR
            double *dst      = &p[panel*DGEMM_NR*k + lane];

            if (col<n) {
                for (size_t row=0; row<k; ++row) {
                    dst[row*DGEMM_NR] = B[row*incRowB + col*incColB];
                }
            } else {
                // zero padding right of the last column
                for (size_t row=0; row<k; ++row) {
                    dst[row*DGEMM_NR] = 0;
                }
            }
        }
    }
}

#if DGEMM_MR==4 && DGEMM_NR==8

//
//  AVX micro kernel for a 4 x 8 tile (x86-64, GCC extended inline assembly):
//
//      C <- beta*C + alpha*A*B
//
//  kc_       number of columns of the A panel / rows of the B panel.
//  A         packed panel: kc_ columns of 4 consecutive doubles
//            (the loop advances the A pointer by 32 bytes per step).
//  B         packed panel: kc_ rows of 8 consecutive doubles
//            (the loop advances the B pointer by 64 bytes per step).
//            B is loaded with vmovapd and therefore has to be 32-byte
//            aligned.
//  C         4 x 8 tile; element (i,j) is at C[i*incRowC_ + j*incColC_].
//
//  NOTE: The loop loads the operands of the *next* step while finishing the
//        current one.  In the last step (and before the loop when kc_ is 0)
//        it therefore reads 4 doubles beyond the end of the A panel and
//        4 doubles beyond the end of the B panel.  The values are never used
//        but the memory has to be readable.
//
//  NOTE: C is always loaded and multiplied by beta, also for beta==0.
//        NaN or Inf stored in C are therefore not discarded by beta==0.
//
void
dgemm_micro(size_t kc_, double alpha,
            const double *A, const double *B,
            double beta,
            double *C, ptrdiff_t incRowC_, ptrdiff_t incColC_)
{
    // 64-bit copies that can be loaded with movq from memory operands
    int64_t kc      = kc_;
    int64_t incRowC = incRowC_;
    int64_t incColC = incColC_;

    // alpha and beta are broadcast from memory, so pass their addresses
    double *pAlpha  = &alpha;
    double *pBeta   = &beta;

//
//  Compute AB = A*B
//
    __asm__ volatile
    (
    // Register usage in the loop:
    //   rdi = remaining steps, rsi = A panel, rdx = B panel
    //   rcx = C, r8 = incRowC, r9 = incColC
    "movq      %0,           %%rdi    \n\t"  // kc
    "movq      %1,           %%rsi    \n\t"  // A
    "movq      %2,           %%rdx    \n\t"  // B
    "movq      %5,           %%rcx    \n\t"  // C
    "movq      %6,           %%r8     \n\t"  // incRowC
    "movq      %7,           %%r9     \n\t"  // incColC

    // Preload B(0,0:3) and broadcast A(0,0), ..., A(3,0) for the first step
    "vmovapd           0 * 32(%%rdx),         %%ymm4\n\t"

    "vbroadcastsd       0 * 8(%%rsi),         %%ymm0\n\t"
    "vbroadcastsd       1 * 8(%%rsi),         %%ymm1\n\t"
    "vbroadcastsd       2 * 8(%%rsi),         %%ymm2\n\t"
    "vbroadcastsd       3 * 8(%%rsi),         %%ymm3\n\t"

    // Clear the accumulators:
    //   ymm8..ymm11  hold rows 0..3 of AB, columns 0..3
    //   ymm12..ymm15 hold rows 0..3 of AB, columns 4..7
    "vxorpd                  %%ymm8,          %%ymm8,          %%ymm8\n\t"
    "vxorpd                  %%ymm9,          %%ymm9,          %%ymm9\n\t"
    "vxorpd                  %%ymm10,         %%ymm10,         %%ymm10\n\t"
    "vxorpd                  %%ymm11,         %%ymm11,         %%ymm11\n\t"
    "vxorpd                  %%ymm12,         %%ymm12,         %%ymm12\n\t"
    "vxorpd                  %%ymm13,         %%ymm13,         %%ymm13\n\t"
    "vxorpd                  %%ymm14,         %%ymm14,         %%ymm14\n\t"
    "vxorpd                  %%ymm15,         %%ymm15,         %%ymm15\n\t"

    // The loop condition is tested first, so kc==0 skips the loop body
    "jmp                     check%=\n\t"

    // One rank-1 update per pass; %= makes the labels unique per asm instance
    "loop%=:\n\t"

    // Second half of the current row of B: columns 4..7
    "vmovapd           1 * 32(%%rdx),         %%ymm5\n\t"

    // Columns 0..3: accumulate A(i,l)*B(l,0:3) into ymm8..ymm11
    "vmulpd                  %%ymm0,          %%ymm4,          %%ymm6\n\t"
    "vaddpd                  %%ymm6,          %%ymm8,          %%ymm8\n\t"
    "vmulpd                  %%ymm1,          %%ymm4,          %%ymm7\n\t"
    "vaddpd                  %%ymm7,          %%ymm9,          %%ymm9\n\t"
    "vmulpd                  %%ymm2,          %%ymm4,          %%ymm6\n\t"
    "vaddpd                  %%ymm6,          %%ymm10,         %%ymm10\n\t"
    "vmulpd                  %%ymm3,          %%ymm4,          %%ymm7\n\t"
    "vaddpd                  %%ymm7,          %%ymm11,         %%ymm11\n\t"

    // Read ahead: first half of the next row of B
    "vmovapd           2 * 32(%%rdx),         %%ymm4\n\t"

    // Columns 4..7: accumulate A(i,l)*B(l,4:7) into ymm12..ymm15.  Each
    // broadcast register is reloaded with the next column of A as soon as
    // it is no longer needed (read ahead).
    "vmulpd                  %%ymm0,          %%ymm5,          %%ymm6\n\t"
    "vaddpd                  %%ymm6,          %%ymm12,         %%ymm12\n\t"
    "vbroadcastsd       4 * 8(%%rsi),         %%ymm0\n\t"
    "vmulpd                  %%ymm1,          %%ymm5,          %%ymm7\n\t"
    "vaddpd                  %%ymm7,          %%ymm13,         %%ymm13\n\t"
    "vbroadcastsd       5 * 8(%%rsi),         %%ymm1\n\t"
    "vmulpd                  %%ymm2,          %%ymm5,          %%ymm6\n\t"
    "vaddpd                  %%ymm6,          %%ymm14,         %%ymm14\n\t"
    "vbroadcastsd       6 * 8(%%rsi),         %%ymm2\n\t"
    "vmulpd                  %%ymm3,          %%ymm5,          %%ymm7\n\t"
    "vaddpd                  %%ymm7,          %%ymm15,         %%ymm15\n\t"
    "vbroadcastsd       7 * 8(%%rsi),         %%ymm3\n\t"

    // Advance A by one column (4 doubles) and B by one row (8 doubles)
    "addq                    $32,            %%rsi\n\t"
    "addq                    $2*32,          %%rdx\n\t"
    "decq                    %%rdi\n\t"

    // Continue while the (signed) step counter is positive
    "check%=:\n\t"
    "testq                   %%rdi,           %%rdi\n\t"
    "jg                      loop%=\n\t"

    // rdi and rsi are free now: reuse them to broadcast alpha and beta
    "movq      %3,           %%rdi                  \n\t"  // alpha
    "movq      %4,           %%rsi                  \n\t"  // beta
    "vbroadcastsd           (%%rdi),          %%ymm6\n\t"
    "vbroadcastsd           (%%rsi),          %%ymm7\n\t"


    // AB <- alpha*AB
    "vmulpd                  %%ymm6,          %%ymm8,          %%ymm8\n\t"
    "vmulpd                  %%ymm6,          %%ymm9,          %%ymm9\n\t"
    "vmulpd                  %%ymm6,          %%ymm10,         %%ymm10\n\t"
    "vmulpd                  %%ymm6,          %%ymm11,         %%ymm11\n\t"
    "vmulpd                  %%ymm6,          %%ymm12,         %%ymm12\n\t"
    "vmulpd                  %%ymm6,          %%ymm13,         %%ymm13\n\t"
    "vmulpd                  %%ymm6,          %%ymm14,         %%ymm14\n\t"
    "vmulpd                  %%ymm6,          %%ymm15,         %%ymm15\n\t"

    // Convert the strides of C from elements to bytes
    "leaq                    (,%%r8,8),       %%r8\n\t"
    "leaq                    (,%%r9,8),       %%r9\n\t"

    // r10 = 2*incColC, r11 = 3*incColC (bytes), rdx = address of C(0,4)
    "leaq                    (,%%r9,2),       %%r10\n\t"
    "leaq                    (%%r10,%%r9),    %%r11\n\t"
    "leaq                    (%%rcx,%%r10,2), %%rdx\n\t"

    // Each of the four row updates below gathers the 8 elements of one row
    // of C pairwise into xmm0..xmm3, scales them by beta, adds the matching
    // accumulator halves and scatters the result back to C.
    "#\n\t"
    "#       Update C(0,:)\n\t"
    "#\n\t"
    "vmovlpd                 (%%rcx),         %%xmm0,          %%xmm0\n\t"
    "vmovhpd                 (%%rcx,%%r9),    %%xmm0,          %%xmm0\n\t"
    "vmovlpd                 (%%rcx,%%r10),   %%xmm1,          %%xmm1\n\t"
    "vmovhpd                 (%%rcx,%%r11),   %%xmm1,          %%xmm1\n\t"
    "vmovlpd                 (%%rdx),         %%xmm2,          %%xmm2\n\t"
    "vmovhpd                 (%%rdx,%%r9),    %%xmm2,          %%xmm2\n\t"
    "vmovlpd                 (%%rdx,%%r10),   %%xmm3,          %%xmm3\n\t"
    "vmovhpd                 (%%rdx,%%r11),   %%xmm3,          %%xmm3\n\t"

    "vmulpd                  %%xmm7,          %%xmm0,          %%xmm0\n\t"
    "vmulpd                  %%xmm7,          %%xmm1,          %%xmm1\n\t"
    "vmulpd                  %%xmm7,          %%xmm2,          %%xmm2\n\t"
    "vmulpd                  %%xmm7,          %%xmm3,          %%xmm3\n\t"

    // Upper 128 bits of the accumulators hold columns 2,3 and 6,7
    "vextractf128            $1,              %%ymm8,          %%xmm4\n\t"
    "vextractf128            $1,              %%ymm12,         %%xmm5\n\t"

    "vaddpd                  %%xmm0,          %%xmm8,          %%xmm0\n\t"
    "vaddpd                  %%xmm1,          %%xmm4,          %%xmm1\n\t"
    "vaddpd                  %%xmm2,          %%xmm12,         %%xmm2\n\t"
    "vaddpd                  %%xmm3,          %%xmm5,          %%xmm3\n\t"

    "vmovlpd                 %%xmm0,          (%%rcx)\n\t"
    "vmovhpd                 %%xmm0,          (%%rcx,%%r9)\n\t"
    "vmovlpd                 %%xmm1,          (%%rcx,%%r10)\n\t"
    "vmovhpd                 %%xmm1,          (%%rcx,%%r11)\n\t"
    "vmovlpd                 %%xmm2,          (%%rdx)\n\t"
    "vmovhpd                 %%xmm2,          (%%rdx,%%r9)\n\t"
    "vmovlpd                 %%xmm3,          (%%rdx,%%r10)\n\t"
    "vmovhpd                 %%xmm3,          (%%rdx,%%r11)\n\t"

    "#\n\t"
    "#       Update C(1,:)\n\t"
    "#\n\t"
    // Move both column pointers down by one row of C
    "addq                    %%r8,            %%rcx\n\t"
    "addq                    %%r8,            %%rdx\n\t"

    "vmovlpd                 (%%rcx),         %%xmm0,          %%xmm0\n\t"
    "vmovhpd                 (%%rcx,%%r9),    %%xmm0,          %%xmm0\n\t"
    "vmovlpd                 (%%rcx,%%r10),   %%xmm1,          %%xmm1\n\t"
    "vmovhpd                 (%%rcx,%%r11),   %%xmm1,          %%xmm1\n\t"
    "vmovlpd                 (%%rdx),         %%xmm2,          %%xmm2\n\t"
    "vmovhpd                 (%%rdx,%%r9),    %%xmm2,          %%xmm2\n\t"
    "vmovlpd                 (%%rdx,%%r10),   %%xmm3,          %%xmm3\n\t"
    "vmovhpd                 (%%rdx,%%r11),   %%xmm3,          %%xmm3\n\t"

    "vmulpd                  %%xmm7,          %%xmm0,          %%xmm0\n\t"
    "vmulpd                  %%xmm7,          %%xmm1,          %%xmm1\n\t"
    "vmulpd                  %%xmm7,          %%xmm2,          %%xmm2\n\t"
    "vmulpd                  %%xmm7,          %%xmm3,          %%xmm3\n\t"

    "vextractf128            $1,              %%ymm9,          %%xmm4\n\t"
    "vextractf128            $1,              %%ymm13,         %%xmm5\n\t"

    "vaddpd                  %%xmm0,          %%xmm9,          %%xmm0\n\t"
    "vaddpd                  %%xmm1,          %%xmm4,          %%xmm1\n\t"
    "vaddpd                  %%xmm2,          %%xmm13,         %%xmm2\n\t"
    "vaddpd                  %%xmm3,          %%xmm5,          %%xmm3\n\t"

    "vmovlpd                 %%xmm0,          (%%rcx)\n\t"
    "vmovhpd                 %%xmm0,          (%%rcx,%%r9)\n\t"
    "vmovlpd                 %%xmm1,          (%%rcx,%%r10)\n\t"
    "vmovhpd                 %%xmm1,          (%%rcx,%%r11)\n\t"
    "vmovlpd                 %%xmm2,          (%%rdx)\n\t"
    "vmovhpd                 %%xmm2,          (%%rdx,%%r9)\n\t"
    "vmovlpd                 %%xmm3,          (%%rdx,%%r10)\n\t"
    "vmovhpd                 %%xmm3,          (%%rdx,%%r11)\n\t"

    "#\n\t"
    "#       Update C(2,:)\n\t"
    "#\n\t"
    "addq                    %%r8,            %%rcx\n\t"
    "addq                    %%r8,            %%rdx\n\t"

    "vmovlpd                 (%%rcx),         %%xmm0,          %%xmm0\n\t"
    "vmovhpd                 (%%rcx,%%r9),    %%xmm0,          %%xmm0\n\t"
    "vmovlpd                 (%%rcx,%%r10),   %%xmm1,          %%xmm1\n\t"
    "vmovhpd                 (%%rcx,%%r11),   %%xmm1,          %%xmm1\n\t"
    "vmovlpd                 (%%rdx),         %%xmm2,          %%xmm2\n\t"
    "vmovhpd                 (%%rdx,%%r9),    %%xmm2,          %%xmm2\n\t"
    "vmovlpd                 (%%rdx,%%r10),   %%xmm3,          %%xmm3\n\t"
    "vmovhpd                 (%%rdx,%%r11),   %%xmm3,          %%xmm3\n\t"

    "vmulpd                  %%xmm7,          %%xmm0,          %%xmm0\n\t"
    "vmulpd                  %%xmm7,          %%xmm1,          %%xmm1\n\t"
    "vmulpd                  %%xmm7,          %%xmm2,          %%xmm2\n\t"
    "vmulpd                  %%xmm7,          %%xmm3,          %%xmm3\n\t"

    "vextractf128            $1,              %%ymm10,         %%xmm4\n\t"
    "vextractf128            $1,              %%ymm14,         %%xmm5\n\t"

    "vaddpd                  %%xmm0,          %%xmm10,         %%xmm0\n\t"
    "vaddpd                  %%xmm1,          %%xmm4,          %%xmm1\n\t"
    "vaddpd                  %%xmm2,          %%xmm14,         %%xmm2\n\t"
    "vaddpd                  %%xmm3,          %%xmm5,          %%xmm3\n\t"

    "vmovlpd                 %%xmm0,          (%%rcx)\n\t"
    "vmovhpd                 %%xmm0,          (%%rcx,%%r9)\n\t"
    "vmovlpd                 %%xmm1,          (%%rcx,%%r10)\n\t"
    "vmovhpd                 %%xmm1,          (%%rcx,%%r11)\n\t"
    "vmovlpd                 %%xmm2,          (%%rdx)\n\t"
    "vmovhpd                 %%xmm2,          (%%rdx,%%r9)\n\t"
    "vmovlpd                 %%xmm3,          (%%rdx,%%r10)\n\t"
    "vmovhpd                 %%xmm3,          (%%rdx,%%r11)\n\t"

    "#\n\t"
    "#       Update C(3,:)\n\t"
    "#\n\t"
    "addq                    %%r8,            %%rcx\n\t"
    "addq                    %%r8,            %%rdx\n\t"

    "vmovlpd                 (%%rcx),         %%xmm0,          %%xmm0\n\t"
    "vmovhpd                 (%%rcx,%%r9),    %%xmm0,          %%xmm0\n\t"
    "vmovlpd                 (%%rcx,%%r10),   %%xmm1,          %%xmm1\n\t"
    "vmovhpd                 (%%rcx,%%r11),   %%xmm1,          %%xmm1\n\t"
    "vmovlpd                 (%%rdx),         %%xmm2,          %%xmm2\n\t"
    "vmovhpd                 (%%rdx,%%r9),    %%xmm2,          %%xmm2\n\t"
    "vmovlpd                 (%%rdx,%%r10),   %%xmm3,          %%xmm3\n\t"
    "vmovhpd                 (%%rdx,%%r11),   %%xmm3,          %%xmm3\n\t"

    "vmulpd                  %%xmm7,          %%xmm0,          %%xmm0\n\t"
    "vmulpd                  %%xmm7,          %%xmm1,          %%xmm1\n\t"
    "vmulpd                  %%xmm7,          %%xmm2,          %%xmm2\n\t"
    "vmulpd                  %%xmm7,          %%xmm3,          %%xmm3\n\t"

    "vextractf128            $1,              %%ymm11,         %%xmm4\n\t"
    "vextractf128            $1,              %%ymm15,         %%xmm5\n\t"

    "vaddpd                  %%xmm0,          %%xmm11,         %%xmm0\n\t"
    "vaddpd                  %%xmm1,          %%xmm4,          %%xmm1\n\t"
    "vaddpd                  %%xmm2,          %%xmm15,         %%xmm2\n\t"
    "vaddpd                  %%xmm3,          %%xmm5,          %%xmm3\n\t"

    "vmovlpd                 %%xmm0,          (%%rcx)\n\t"
    "vmovhpd                 %%xmm0,          (%%rcx,%%r9)\n\t"
    "vmovlpd                 %%xmm1,          (%%rcx,%%r10)\n\t"
    "vmovhpd                 %%xmm1,          (%%rcx,%%r11)\n\t"
    "vmovlpd                 %%xmm2,          (%%rdx)\n\t"
    "vmovhpd                 %%xmm2,          (%%rdx,%%r9)\n\t"
    "vmovlpd                 %%xmm3,          (%%rdx,%%r10)\n\t"
    "vmovhpd                 %%xmm3,          (%%rdx,%%r11)\n\t"

    // All operands are passed in memory ("m") and loaded explicitly above.
    // C is written through a pointer that is not an output operand; the
    // "memory" clobber tells the compiler that memory is modified.
    : // output
    : // input
        "m" (kc),       // 0
        "m" (A),        // 1
        "m" (B),        // 2
        "m" (pAlpha),   // 3
        "m" (pBeta),    // 4
        "m" (C),        // 5
        "m" (incRowC),  // 6
        "m" (incColC)   // 7
    : // register clobber list
      // (rax and rbx are listed although the code above does not use them)
        "rax",  "rbx",  "rcx",  "rdx",    "rsi",   "rdi",
        "r8",   "r9",   "r10",  "r11",
        "xmm0", "xmm1", "xmm2",  "xmm3",  "xmm4",  "xmm5",  "xmm6",  "xmm7",
        "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15",
        "memory"
    );
}

#else

//
//  Portable micro kernel for a DGEMM_MR x DGEMM_NR tile:
//
//      C <- beta*C + alpha*A*B
//
//  A is a packed panel of k columns with DGEMM_MR elements each, B a packed
//  panel of k rows with DGEMM_NR elements each.  Element (i,j) of the tile
//  of C is located at C[i*incRowC + j*incColC].
//
//  For beta==0 the tile is overwritten with zeros before the update, so its
//  previous content is discarded.
//
void
dgemm_micro(size_t k, double alpha,
            const double *A, const double *B,
            double beta,
            double *C, ptrdiff_t incRowC, ptrdiff_t incColC)
{
    // Local accumulator for A*B, stored column-wise
    double acc[DGEMM_MR*DGEMM_NR] = {0};

    // acc <- A*B as a sum of k rank-1 updates
    for (size_t l=0; l<k; ++l) {
        const double *a = &A[l*DGEMM_MR];
        const double *b = &B[l*DGEMM_NR];

        for (size_t i=0; i<DGEMM_MR; ++i) {
            for (size_t j=0; j<DGEMM_NR; ++j) {
                acc[i+j*DGEMM_MR] += a[i]*b[j];
            }
        }
    }

    // C <- beta*C
    if (beta==0) {
        for (size_t i=0; i<DGEMM_MR; ++i) {
            for (size_t j=0; j<DGEMM_NR; ++j) {
                C[i*incRowC+j*incColC] = 0;
            }
        }
    } else if (beta!=1) {
        for (size_t i=0; i<DGEMM_MR; ++i) {
            for (size_t j=0; j<DGEMM_NR; ++j) {
                C[i*incRowC+j*incColC] *= beta;
            }
        }
    }

    // C <- C + alpha*acc
    for (size_t i=0; i<DGEMM_MR; ++i) {
        for (size_t j=0; j<DGEMM_NR; ++j) {
            C[i*incRowC+j*incColC] += alpha*acc[i+j*DGEMM_MR];
        }
    }
}

#endif


//
//  Scale the m x n matrix X in place:  X <- alpha*X.
//
//  Element (i,j) is located at X[i*incRowX + j*incColX].  Nothing is done
//  for alpha==1.  For alpha==0 the elements are overwritten with zeros
//  rather than multiplied, so NaN/Inf entries are discarded.
//
void
dgescal(size_t m, size_t n,
        double alpha,
        double *X, size_t incRowX, size_t incColX)
{
    if (alpha==1) {
        return;
    }

    const int clear = (alpha==0);

    for (size_t row=0; row<m; ++row) {
        for (size_t col=0; col<n; ++col) {
            double *x = &X[row*incRowX+col*incColX];

            if (clear) {
                *x = 0;
            } else {
                *x *= alpha;
            }
        }
    }
}

//
//  Matrix update  Y <- Y + alpha*X  for m x n matrices X and Y.
//
//  Element (i,j) of X is located at X[i*incRowX + j*incColX], element (i,j)
//  of Y at Y[i*incRowY + j*incColY].  For alpha==0 neither X nor Y is
//  accessed.
//
void
dgeaxpy(size_t m, size_t n,
        double alpha,
        const double *X, size_t incRowX, size_t incColX,
        double *Y, size_t incRowY, size_t incColY)
{
    if (alpha!=0) {
        for (size_t row=0; row<m; ++row) {
            for (size_t col=0; col<n; ++col) {
                const double x = X[row*incRowX+col*incColX];

                Y[row*incRowY+col*incColY] += alpha*x;
            }
        }
    }
}

//
//  Macro kernel:  C <- beta*C + alpha*A*B  for one packed block.
//
//  A is an m x k block packed by dgepack_A (row panels of height DGEMM_MR),
//  B a k x n block packed by dgepack_B (column panels of width DGEMM_NR).
//  C is the m x n block that gets updated; element (i,j) is located at
//  C[i*incRowC + j*incColC].
//
//  Tiles of full size DGEMM_MR x DGEMM_NR are updated in place by the micro
//  kernel.  Tiles at the bottom or right border of the block can be smaller.
//  For those the micro kernel computes alpha*A*B into a local buffer of full
//  tile size and only the m_ x n_ part that belongs to the block is merged
//  into C.
//
void
dgemm_macro(size_t m, size_t n, size_t k, double alpha,
            const double *A, const double *B,
            double beta,
            double *C, ptrdiff_t incRowC, ptrdiff_t incColC)
{
    const size_t mb = (m+DGEMM_MR-1) / DGEMM_MR;
    const size_t nb = (n+DGEMM_NR-1) / DGEMM_NR;

    const size_t mr = m % DGEMM_MR;
    const size_t nr = n % DGEMM_NR;

    for (size_t i=0; i<mb; ++i) {
        // number of rows of C that belong to row panel i
        const size_t m_ = (i<mb-1 || mr==0) ? DGEMM_MR
                                            : mr;
        for (size_t j=0; j<nb; ++j) {
            // number of columns of C that belong to column panel j
            const size_t n_ = (j<nb-1 || nr==0) ? DGEMM_NR
                                                : nr;

            const double *Ai  = &A[i*DGEMM_MR*k];
            const double *Bj  = &B[j*k*DGEMM_NR];
            double       *Cij = &C[i*DGEMM_MR*incRowC+j*DGEMM_NR*incColC];

            if (m_==DGEMM_MR && n_==DGEMM_NR) {
                dgemm_micro(k, alpha,
                            Ai, Bj,
                            beta,
                            Cij, incRowC, incColC);
            } else {
                // The buffer is cleared for every border tile: the assembly
                // micro kernel multiplies the old content by beta even for
                // beta==0, so NaN/Inf garbage would end up in the result.
                double AB[DGEMM_MR*DGEMM_NR] = {0};

                dgemm_micro(k, alpha,
                            Ai, Bj,
                            0,
                            AB, 1, DGEMM_MR);
                // Only m_ x n_ elements of this tile are part of C.  Using
                // the full tile size here would write outside of the block.
                dgescal(m_, n_,
                        beta,
                        Cij, incRowC, incColC);
                dgeaxpy(m_, n_,
                        1,
                        AB, 1, DGEMM_MR,
                        Cij, incRowC, incColC);
            }
        }
    }
}

void
dgemm_frame(size_t m, size_t n, size_t k, double alpha,
            const double *A, ptrdiff_t incRowA, ptrdiff_t incColA,
            const double *B, ptrdiff_t incRowB, ptrdiff_t incColB,
            double beta,
            double *C, ptrdiff_t incRowC, ptrdiff_t incColC)
{
    if (k==0 || alpha==0) {
        dgescal(m, n, beta, C, incRowC, incColC);
        return;
    }

    size_t mb = (m+DGEMM_MC-1) / DGEMM_MC;
    size_t nb = (n+DGEMM_NC-1) / DGEMM_NC;
    size_t kb = (k+DGEMM_KC-1) / DGEMM_KC;

    size_t mr = m % DGEMM_MC;
    size_t nr = n % DGEMM_NC;
    size_t kr = k % DGEMM_KC;

    double      *Ap      = malloc_aligned(DGEMM_MC*DGEMM_KC*sizeof(*Ap), 32);
    double      *Bp      = malloc_aligned(DGEMM_KC*DGEMM_NC*sizeof(*Bp), 32);

    for (size_t j=0; j<nb; ++j) {
        size_t n_ = (j<nb-1 || nr==0) ? DGEMM_NC
                                      : nr;
        for (size_t l=0; l<kb; ++l) {
            size_t k_ = (l<kb-1 || kr==0) ? DGEMM_KC
                                          : kr;
            double beta_ = (l==0) ? beta
                                  : 1;

            dgepack_B(k_, n_,
                      &B[l*DGEMM_KC*incRowB+j*DGEMM_NC*incColB],
                      incRowB, incColB,
                      Bp);

            for (size_t i=0; i<mb; ++i) {
                size_t m_ = (i<mb-1 || mr==0) ? DGEMM_MC
                                              : mr;
                dgepack_A(m_, k_,
                          &A[i*DGEMM_MC*incRowA+l*DGEMM_KC*incColA],
                          incRowA, incColA,
                          Ap);
                dgemm_macro(m_, n_, k_,
                            alpha,
                            Ap, Bp,
                            beta_,
                            &C[i*DGEMM_MC*incRowC+j*DGEMM_NC*incColC],
                            incRowC, incColC);
            }
        }
    }

    free_aligned(Bp);
    free_aligned(Ap);
}

//-- Functions for benchmarking and testing ------------------------------------

//
//  Return the value of times() converted to seconds.
//
//  The length of a clock tick is queried once with sysconf(_SC_CLK_TCK) and
//  cached in a static variable.  Only differences between two return values
//  are meaningful.
//
double
walltime()
{
   static double secondsPerTick = 0.0;
   struct tms    cpuTimes;

   if (secondsPerTick==0.0) {
        secondsPerTick = 1.0 / ((double) sysconf(_SC_CLK_TCK));
   }

   clock_t ticks = times(&cpuTimes);

   return ((double) ticks) * secondsPerTick;
}

//
//  Fill the m x n matrix A with pseudo random values from rand(), mapped to
//  roughly the interval [-100, 100].  The matrix is traversed column by
//  column; element (i,j) is located at A[i*incRowA + j*incColA].
//
void
randGeMatrix(size_t m, size_t n, double *A, ptrdiff_t incRowA, ptrdiff_t incColA)
{
    for (size_t col=0; col<n; ++col) {
        for (size_t row=0; row<m; ++row) {
            const double sample = (double)rand();

            A[row*incRowA+col*incColA] = (sample-RAND_MAX/2)*200/RAND_MAX;
        }
    }
}

// Minimum / maximum of two values.  Function-like macros: each argument may
// be evaluated twice, so do not pass expressions with side effects.
#define MIN(X,Y)   ((X)<(Y) ? (X) : (Y))
#define MAX(X,Y)   ((X)>(Y) ? (X) : (Y))

//
//  Return the 1-norm of the m x n matrix A, i.e. the largest sum of absolute
//  values found in any column.  Element (i,j) is located at
//  A[i*incRowA + j*incColA].  An empty matrix has norm 0.
//
double
dgenrm1(size_t m, size_t n, const double *A, ptrdiff_t incRowA, ptrdiff_t incColA)
{
    double  maxColSum = 0;

    for (size_t col=0; col<n; ++col) {
        double colSum = 0;

        for (size_t row=0; row<m; ++row) {
            colSum += fabs(A[row*incRowA+col*incColA]);
        }
        maxColSum = (colSum>maxColSum) ? colSum : maxColSum;
    }
    return maxColSum;
}

//
//  Estimate the relative error of a computed GEMM result.
//
//  C0 is the reference solution and C the solution under test, both m x n,
//  for the operation  beta*C + alpha*A*B  with A being m x k and B being
//  k x n.  The function returns
//
//      ||C - C0||_1 / ( max(m,n,k) * normA * normB * normC )
//
//  where normA = max(1,|alpha|)*||A||_1, normB = ||B||_1 and
//  normC = max(1,|beta|)*||C0||_1.  normC is taken from the reference
//  solution so that the scaling does not depend on the result being tested.
//
//  Side effect: C is overwritten with the difference C - C0.
//
double
err_dgemm(size_t m, size_t n, size_t k,
          double alpha,
          const double *A, ptrdiff_t incRowA, ptrdiff_t incColA,
          const double *B, ptrdiff_t incRowB, ptrdiff_t incColB,
          double beta,
          const double *C0, ptrdiff_t incRowC0, ptrdiff_t incColC0,
          double *C, ptrdiff_t incRowC, ptrdiff_t incColC)
{
    double normA = dgenrm1(m, k, A, incRowA, incColA);
    double normB = dgenrm1(k, n, B, incRowB, incColB);
    // Matrix and strides have to belong together: C0 with the strides of C0.
    double normC = dgenrm1(m, n, C0, incRowC0, incColC0);
    double normD;
    size_t mn    = (m>n)  ? m  : n;
    size_t mnk   = (mn>k) ? mn : k;

    normA = MAX(normA, fabs(alpha)*normA);
    normC = MAX(normC, fabs(beta)*normC);

    // C <- C - C0
    dgeaxpy(m, n, -1.0, C0, incRowC0, incColC0, C, incRowC, incColC);
    normD = dgenrm1(m, n, C, incRowC, incColC);

    return normD/(mnk*normA*normB*normC);
}

//
//  Copy the n elements x[0], x[incX], x[2*incX], ... to
//  y[0], y[incY], y[2*incY], ... in ascending order.
//
void
dcopy(size_t n,
      const double *x, ptrdiff_t incX,
      double *y, ptrdiff_t incY)
{
    size_t pos = 0;

    while (pos<n) {
        y[pos*incY] = x[pos*incX];
        ++pos;
    }
}

//
//  Copy the m x n matrix X to Y.
//
//  Element (i,j) of X is located at X[i*incRowX + j*incColX], element (i,j)
//  of Y at Y[i*incRowY + j*incColY].  The strides of X decide the order of
//  traversal: X is copied column by column if its row stride is smaller
//  than its column stride, and row by row otherwise.
//
void
dgecopy(size_t m, size_t n,
        const double *X, ptrdiff_t incRowX, ptrdiff_t incColX,
        double *Y, ptrdiff_t incRowY, ptrdiff_t incColY)
{
    if (incRowX>=incColX) {
        // copy one row at a time
        for (size_t row=0; row<m; ++row) {
            dcopy(n, &X[row*incRowX], incColX, &Y[row*incRowY], incColY);
        }
        return;
    }

    // copy one column at a time
    for (size_t col=0; col<n; ++col) {
        dcopy(m, &X[col*incColX], incRowX, &Y[col*incColY], incRowY);
    }
}


//------------------------------------------------------------------------------

//
//  Benchmark configuration.  Every value can be overridden on the compiler
//  command line (e.g. -DMAX_N=2000).
//
//  The benchmark loop in main starts with (m, n, k) = (MIN_M, MIN_N, MIN_K),
//  increases the three dimensions by (INC_M, INC_N, INC_K) after each test
//  and stops as soon as one of them exceeds its MAX_ value.  The MAX_ values
//  also determine the size of the statically allocated test matrices.
//

// Range of n: number of columns of B and C
#ifndef MIN_N
#define MIN_N 100
#endif

#ifndef MAX_N
#define MAX_N 4000
#endif

#ifndef INC_N
#define INC_N 100
#endif

// Range of m: number of rows of A and C
#ifndef MIN_M
#define MIN_M 100
#endif

#ifndef MAX_M
#define MAX_M 4000
#endif

#ifndef INC_M
#define INC_M 100
#endif

// Range of k: number of columns of A and rows of B
#ifndef MIN_K
#define MIN_K 100
#endif

#ifndef MAX_K
#define MAX_K 4000
#endif

#ifndef INC_K
#define INC_K 100
#endif

// Scalars used for the tested operation  C <- BETA*C + ALPHA*A*B
#ifndef ALPHA
#define ALPHA 1
#endif

#ifndef BETA
#define BETA 1
#endif

// Storage order used for A, B and C in the tests:
// 1 selects row major, any other value selects column major.
#ifndef ROWMAJOR_A
#define ROWMAJOR_A 0
#endif

#ifndef ROWMAJOR_B
#define ROWMAJOR_B 0
#endif

#ifndef ROWMAJOR_C
#define ROWMAJOR_C 0
#endif

// Statically allocated test data, sized for the largest problem.
// A_, B_ and C_ are filled with random values once at program start;
// C_ is copied to C0 and C1 before every test.
double A_[MAX_M*MAX_K];
double B_[MAX_K*MAX_N];
double C_[MAX_M*MAX_N];

double C0[MAX_M*MAX_N];     // reference solution
double C1[MAX_M*MAX_N];     // tested solution


//
//  Benchmark and test driver.
//
//  For a sequence of problem sizes (m, n, k) the operation
//  C <- BETA*C + ALPHA*A*B  is computed with dgemm_ref (reference) and with
//  dgemm_frame (tested).  dgemm_frame is repeated until at least 0.3 seconds
//  have accumulated.  For each size one line is printed with the average
//  time per run, the resulting MFLOPS rate, the residual error reported by
//  err_dgemm and PASS/FAIL depending on whether the error is below
//  DBL_EPSILON.
//
int
main(void)
{
    randGeMatrix(MAX_M, MAX_K, A_, 1, MAX_M);
    randGeMatrix(MAX_K, MAX_N, B_, 1, MAX_K);
    // C_ is MAX_M x MAX_N with leading dimension MAX_M
    randGeMatrix(MAX_M, MAX_N, C_, 1, MAX_M);

    printf("#%9s %9s %9s", "m", "n", "k");
    printf(" %12s %12s %17s", "t", "MFLOPS", "Residual Error");
    printf("\n");

    for (size_t m=MIN_M, n=MIN_N, k=MIN_K; n<=MAX_N && m<=MAX_M && k<=MAX_K;
         m+=INC_M, n+=INC_N, k+=INC_K)
    {
        double t, dt, err;
        size_t runs  = 0;
        // 2*m*n*k floating point operations, counted in millions
        double ops   = 2.0*m/1000*n/1000*k;

        // strides of the densely stored m x k, k x n and m x n matrices
        ptrdiff_t incRowA = (ROWMAJOR_A==1) ? k : 1;
        ptrdiff_t incColA = (ROWMAJOR_A==1) ? 1 : m;

        ptrdiff_t incRowB = (ROWMAJOR_B==1) ? n : 1;
        ptrdiff_t incColB = (ROWMAJOR_B==1) ? 1 : k;

        ptrdiff_t incRowC = (ROWMAJOR_C==1) ? n : 1;
        ptrdiff_t incColC = (ROWMAJOR_C==1) ? 1 : m;

        // m, n and k are all of type size_t
        printf(" %9zu %9zu %9zu", m, n, k);

        // compute reference solution
        dgecopy(m, n, C_, 1, MAX_M, C0, incRowC, incColC);
        dgemm_ref(m, n, k,
                  ALPHA,
                  A_, incRowA, incColA,
                  B_, incRowB, incColB,
                  BETA,
                  C0, incRowC, incColC);

        // benchmark dgemm_frame
        t = 0;
        do {
            // restore C before each run; the copy is not part of the timing
            dgecopy(m, n, C_, 1, MAX_M, C1, incRowC, incColC);
            dt = walltime();
            dgemm_frame(m, n, k,
                        ALPHA,
                        A_, incRowA, incColA,
                        B_, incRowB, incColB,
                        BETA,
                        C1, incRowC, incColC);
            dt = walltime() - dt;
            t += dt;
            ++runs;
        } while (t<0.3);
        t /= runs;

        err = err_dgemm(m, n, k,
                        ALPHA,
                        A_, incRowA, incColA,
                        B_, incRowB, incColB,
                        BETA,
                        C0, incRowC, incColC,
                        C1, incRowC, incColC);

        printf(" %12.2e %12.2lf %12.2e %4s", t, ops/t,
               err, (err<DBL_EPSILON) ? "PASS" : "FAIL");

        printf("\n");
    }

    return 0;
}