#include <stdlib.h>
#include <stdio.h>
#include <stddef.h>
#include <stdbool.h>
#include <math.h>
/*
 * Blocking parameters; each may be overridden on the command line,
 * e.g. -DDGEMM_MC=8.
 *
 *   DGEMM_MC x DGEMM_KC : largest block of A packed at once (pack_A buffer)
 *   DGEMM_KC x DGEMM_NC : largest block of B packed at once (pack_B buffer)
 *   DGEMM_MR x DGEMM_NR : tile size computed by one micro-kernel call
 */
#ifndef DGEMM_MC
#define DGEMM_MC 4
#endif
#ifndef DGEMM_NC
#define DGEMM_NC 6
#endif
#ifndef DGEMM_KC
#define DGEMM_KC 5
#endif
#ifndef DGEMM_MR
#define DGEMM_MR 2
#endif
#ifndef DGEMM_NR
#define DGEMM_NR 3
#endif
/*
 * pack_A/pack_B zero-pad the last panel up to a full DGEMM_MR rows /
 * DGEMM_NR columns, so DGEMM_MC*DGEMM_KC and DGEMM_KC*DGEMM_NC buffers are
 * only large enough when DGEMM_MC and DGEMM_NC are multiples of DGEMM_MR
 * and DGEMM_NR.  The #elif chain stops at the first failing check, so a
 * zero DGEMM_MR/DGEMM_NR is reported clearly and never reaches the '%'
 * tests (which would otherwise fail with a division by zero in #if).
 */
#if (DGEMM_MC <= 0 || DGEMM_NC <= 0 || DGEMM_KC <= 0 \
     || DGEMM_MR <= 0 || DGEMM_NR <= 0)
#error "DGEMM_MC, DGEMM_NC, DGEMM_KC, DGEMM_MR and DGEMM_NR must be positive."
#elif (DGEMM_MC % DGEMM_MR != 0)
#error "DGEMM_MC must be a multiple of DGEMM_MR."
#elif (DGEMM_NC % DGEMM_NR != 0)
#error "DGEMM_NC must be a multiple of DGEMM_NR."
#endif
/*
 * Copy the mc x kc block of A (strides incRowA/incColA) into A_ as a
 * sequence of horizontal panels, each DGEMM_MR rows high.  Inside a panel
 * the DGEMM_MR entries of one column are contiguous, one column after the
 * other.  When mc is not a multiple of DGEMM_MR, the missing rows of the
 * last panel are filled with zeros.
 *
 * A_ must provide room for ceil(mc/DGEMM_MR) * DGEMM_MR * kc doubles.
 */
void
pack_A(size_t mc, size_t kc,
       const double *A, ptrdiff_t incRowA, ptrdiff_t incColA,
       double *A_)
{
    size_t panels = mc / DGEMM_MR + (mc % DGEMM_MR != 0);

    for (size_t p = 0; p < panels; ++p) {
        size_t row0 = p*DGEMM_MR;
        /* Only the final panel can be short. */
        size_t rows = (mc - row0 < DGEMM_MR) ? mc - row0 : DGEMM_MR;
        double *panel = A_ + row0*kc;

        for (size_t col = 0; col < kc; ++col) {
            for (size_t r = 0; r < DGEMM_MR; ++r) {
                panel[col*DGEMM_MR + r] = (r < rows)
                    ? A[(row0 + r)*incRowA + col*incColA]
                    : 0;
            }
        }
    }
}
/*
 * Copy the kc x nc block of B (strides incRowB/incColB) into B_ as a
 * sequence of vertical panels, each DGEMM_NR columns wide.  Inside a panel
 * the DGEMM_NR entries of one row are contiguous, one row after the other.
 * When nc is not a multiple of DGEMM_NR, the missing columns of the last
 * panel are filled with zeros.
 *
 * B_ must provide room for ceil(nc/DGEMM_NR) * DGEMM_NR * kc doubles.
 */
void
pack_B(size_t kc, size_t nc,
       const double *B, ptrdiff_t incRowB, ptrdiff_t incColB,
       double *B_)
{
    size_t panels = nc / DGEMM_NR + (nc % DGEMM_NR != 0);

    for (size_t p = 0; p < panels; ++p) {
        size_t col0 = p*DGEMM_NR;
        /* Only the final panel can be narrow. */
        size_t cols = (nc - col0 < DGEMM_NR) ? nc - col0 : DGEMM_NR;
        double *panel = B_ + col0*kc;

        for (size_t row = 0; row < kc; ++row) {
            for (size_t c = 0; c < DGEMM_NR; ++c) {
                panel[row*DGEMM_NR + c] = (c < cols)
                    ? B[row*incRowB + (col0 + c)*incColB]
                    : 0;
            }
        }
    }
}
/*
 * Reference micro kernel.  A points to a packed panel of k columns with
 * DGEMM_MR entries each; B points to a packed panel of k rows with DGEMM_NR
 * entries each.  Computes
 *     C <- beta*C + alpha*(A*B)
 * for the DGEMM_MR x DGEMM_NR block C with strides incRowC/incColC.
 * If beta == 0, C is overwritten without being read, so NaN/Inf values
 * already present in C do not propagate into the result.
 */
void
dgemm_micro_ref(size_t k, double alpha,
                const double *A, const double *B,
                double beta,
                double *C, ptrdiff_t incRowC, ptrdiff_t incColC)
{
    double AB[DGEMM_MR*DGEMM_NR] = {0};

    /* Sum of k outer products: column l of A times row l of B. */
    for (size_t l = 0; l < k; ++l) {
        const double *a = A + l*DGEMM_MR;
        const double *b = B + l*DGEMM_NR;
        for (size_t i = 0; i < DGEMM_MR; ++i) {
            for (size_t j = 0; j < DGEMM_NR; ++j) {
                AB[i*DGEMM_NR + j] += a[i]*b[j];
            }
        }
    }

    /* Scale by alpha and merge into C. */
    for (size_t i = 0; i < DGEMM_MR; ++i) {
        for (size_t j = 0; j < DGEMM_NR; ++j) {
            double ab = AB[i*DGEMM_NR + j]*alpha;
            double *c = &C[i*incRowC + j*incColC];
            if (beta == 0) {
                *c = ab;
            } else {
                *c *= beta;
                *c += ab;
            }
        }
    }
}
/*
 * Scale the m x n matrix A (strides incRowA/incColA) in place by alpha.
 * For alpha == 0 the entries are set to 0 instead of being multiplied,
 * so NaN/Inf entries are cleared rather than propagated.
 */
void
ge_dscal(size_t m, size_t n,
         double alpha,
         double *A, ptrdiff_t incRowA, ptrdiff_t incColA)
{
    for (size_t col = 0; col < n; ++col) {
        for (size_t row = 0; row < m; ++row) {
            double *a = &A[row*incRowA + col*incColA];
            *a = (alpha != 0) ? *a * alpha : 0.0;
        }
    }
}
/*
 * B <- B + alpha*A for m x n matrices A and B with independent strides.
 * Nothing is read or written when the matrix is empty or alpha == 0, so in
 * that case NaN/Inf entries of A cannot leak into B.
 */
void
ge_daxpy(size_t m, size_t n, double alpha,
         const double *A, ptrdiff_t incRowA, ptrdiff_t incColA,
         double *B, ptrdiff_t incRowB, ptrdiff_t incColB)
{
    if (m != 0 && n != 0 && alpha != 0) {
        for (size_t col = 0; col < n; ++col) {
            for (size_t row = 0; row < m; ++row) {
                B[row*incRowB + col*incColB] += alpha*A[row*incRowA + col*incColA];
            }
        }
    }
}
/*
 * Macro kernel: C <- beta*C + alpha*A*B for the mc x nc block C, where A
 * and B are the packed buffers produced by pack_A (mc x kc) and pack_B
 * (kc x nc).  C is processed in DGEMM_MR x DGEMM_NR tiles:
 *   - full tiles are updated in place by the micro kernel;
 *   - partial tiles at the bottom/right edge are computed into a local
 *     buffer (beta = 0), and only their valid mr x nr part is merged into C.
 *     As a result, C is never accessed outside its mc x nc extent.
 */
void
dgemm_macro(size_t mc, size_t nc, size_t kc,
            double alpha,
            const double *A, const double *B,
            double beta,
            double *C, ptrdiff_t incRowC, ptrdiff_t incColC)
{
    double AB[DGEMM_MR*DGEMM_NR];

    for (size_t i0 = 0; i0 < mc; i0 += DGEMM_MR) {
        size_t mr = (mc - i0 < DGEMM_MR) ? mc - i0 : DGEMM_MR;
        const double *Apanel = &A[i0*kc];

        for (size_t j0 = 0; j0 < nc; j0 += DGEMM_NR) {
            size_t nr = (nc - j0 < DGEMM_NR) ? nc - j0 : DGEMM_NR;
            const double *Bpanel = &B[j0*kc];
            double *Ctile = &C[i0*incRowC + j0*incColC];

            if (mr < DGEMM_MR || nr < DGEMM_NR) {
                /* Edge tile: AB is a column-major DGEMM_MR x DGEMM_NR scratch. */
                dgemm_micro_ref(kc, alpha, Apanel, Bpanel,
                                0,
                                AB, 1, DGEMM_MR);
                ge_dscal(mr, nr, beta, Ctile, incRowC, incColC);
                ge_daxpy(mr, nr, 1, AB, 1, DGEMM_MR,
                         Ctile, incRowC, incColC);
            } else {
                dgemm_micro_ref(kc, alpha, Apanel, Bpanel,
                                beta,
                                Ctile, incRowC, incColC);
            }
        }
    }
}
/*
 * Fill the m x n matrix A (strides incRowA/incColA).  Entries are numbered
 * 1, 2, 3, ... in row-major reading order, independent of the storage
 * strides.  If withNan is true, every entry is set to NaN instead.
 */
void
initDGeMatrix(size_t m, size_t n, bool withNan,
              double *A,
              ptrdiff_t incRowA, ptrdiff_t incColA)
{
    size_t value = 1;
    for (size_t row = 0; row < m; ++row) {
        for (size_t col = 0; col < n; ++col, ++value) {
            double *a = &A[row*incRowA + col*incColA];
            if (withNan) {
                *a = nan("");
            } else {
                *a = (double)value;
            }
        }
    }
}
/*
 * Print the m x n matrix A (strides incRowA/incColA) to stdout, one matrix
 * row per output line, followed by an empty line.
 */
void
printDGeMatrix(size_t m, size_t n,
               const double *A,
               ptrdiff_t incRowA, ptrdiff_t incColA)
{
    for (size_t row = 0; row < m; ++row) {
        size_t col = 0;
        while (col < n) {
            double value = A[row*incRowA + col*incColA];
            printf("%9.2lf ", value);
            ++col;
        }
        printf("\n");
    }
    /* Blank line separating consecutive matrices in the output. */
    printf("\n");
}
/*
 * Demonstrate dgemm_macro on the leading block of C <- beta*C + alpha*A*B.
 * The problem is clamped to one block: mc = min(m, DGEMM_MC),
 * nc = min(n, DGEMM_NC), kc = min(k, DGEMM_KC).  Only the top-left mc x nc
 * part of C is updated, using the leading mc x kc block of A and the
 * leading kc x nc block of B.  The operand blocks and the result are printed
 * to stdout.  Calls abort() if the packing buffers cannot be allocated.
 */
void
test_macro(size_t m, size_t n, size_t k,
           double alpha,
           const double *A, ptrdiff_t incRowA, ptrdiff_t incColA,
           const double *B, ptrdiff_t incRowB, ptrdiff_t incColB,
           double beta,
           double *C, ptrdiff_t incRowC, ptrdiff_t incColC)
{
    double *packedA = malloc(DGEMM_MC*DGEMM_KC*sizeof *packedA);
    double *packedB = malloc(DGEMM_KC*DGEMM_NC*sizeof *packedB);
    if (packedA == NULL || packedB == NULL) {
        abort();
    }

    size_t mc = m;
    size_t nc = n;
    size_t kc = k;
    if (mc > DGEMM_MC) {
        mc = DGEMM_MC;
    }
    if (nc > DGEMM_NC) {
        nc = DGEMM_NC;
    }
    if (kc > DGEMM_KC) {
        kc = DGEMM_KC;
    }

    printf("mc = %zu\n", mc);
    printf("nc = %zu\n", nc);
    printf("kc = %zu\n", kc);

    printf("Block from A =\n");
    printDGeMatrix(mc, kc, A, incRowA, incColA);
    printf("Block from B =\n");
    printDGeMatrix(kc, nc, B, incRowB, incColB);
    printf("Block from C =\n");
    printDGeMatrix(mc, nc, C, incRowC, incColC);

    pack_A(mc, kc, A, incRowA, incColA, packedA);
    pack_B(kc, nc, B, incRowB, incColB, packedB);
    dgemm_macro(mc, nc, kc, alpha, packedA, packedB, beta,
                C, incRowC, incColC);

    printf("C <- %5.2lf * C + %5.2lf * A * B\n", beta, alpha);
    printf("C =\n");
    printDGeMatrix(mc, nc, C, incRowC, incColC);

    free(packedA);
    free(packedB);
}
/*
 * Storage order used by main(): nonzero selects column-major, 0 selects
 * row-major.  Override with e.g. -DCOLMAJOR=0.
 */
#if !defined(COLMAJOR)
#define COLMAJOR 1
#endif
/*
 * Small driver for test_macro.  It:
 *   - builds A (m x k) and B (k x n) with entries 1, 2, 3, ... in row-major
 *     reading order;
 *   - fills C (m x n) with NaN;
 *   - computes C <- beta*C + alpha*A*B with alpha = 1 and beta = 0.
 * With beta == 0 the NaNs in C are overwritten, not propagated; presumably
 * this is the point of the demo.  Matrices are stored column-major when
 * COLMAJOR is nonzero, and row-major otherwise.
 */
int
main(void)
{
    size_t m = 2;
    size_t n = 4;
    size_t k = 4;

    /* Row/column strides for the selected storage order. */
    ptrdiff_t incRowA = COLMAJOR ? 1 : k;
    ptrdiff_t incColA = COLMAJOR ? m : 1;
    ptrdiff_t incRowB = COLMAJOR ? 1 : n;
    ptrdiff_t incColB = COLMAJOR ? k : 1;
    ptrdiff_t incRowC = COLMAJOR ? 1 : n;
    ptrdiff_t incColC = COLMAJOR ? m : 1;

    double *A = malloc(m*k*sizeof(double));
    double *B = malloc(k*n*sizeof(double));
    double *C = malloc(m*n*sizeof(double));
    if (!A || !B || !C) {
        abort();
    }

    initDGeMatrix(m, k, false, A, incRowA, incColA);
    initDGeMatrix(k, n, false, B, incRowB, incColB);
    initDGeMatrix(m, n, true, C, incRowC, incColC);

    double alpha = 1;
    double beta = 0;

    test_macro(m, n, k,
               alpha,
               A, incRowA, incColA,
               B, incRowB, incColB,
               beta,
               C, incRowC, incColC);

    free(A);
    free(B);
    free(C);
    return 0;
}