#include <ulmaux.h>
#include <ulmblas.h>
#include <assert.h>
#include <stdlib.h>
#include <stdbool.h>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include <string.h>
/*
 * Reference implementation of the general matrix-matrix product
 *
 *     C <- beta*C + alpha*A*B
 *
 * with A an m x k, B a k x n and C an m x n matrix. Element (i,j) of a
 * matrix X is X[i*incRowX + j*incColX]. Increments may be negative; the
 * pointer then still addresses element (0,0).
 *
 * Conventions (as in the BLAS):
 *   - if m==0 or n==0, nothing is done;
 *   - if (alpha==0 or k==0) and beta==1, C is not accessed;
 *   - if beta==0, C is overwritten with zeros first, so NaN/Inf values in
 *     C on entry do not propagate into the result;
 *   - if alpha==0, A and B are not accessed.
 *
 * This is a plain j-l-i loop nest meant to be obviously correct, not fast.
 */
void
dgemm_ref(size_t m, size_t n, size_t k,
          double alpha,
          const double *A, ptrdiff_t incRowA, ptrdiff_t incColA,
          const double *B, ptrdiff_t incRowB, ptrdiff_t incColB,
          double beta,
          double *C, ptrdiff_t incRowC, ptrdiff_t incColC)
{
    if (m==0 || n==0 || ((alpha==0 || k==0) && beta==1)) {
        return;
    }

    // All offsets are formed in ptrdiff_t. A size_t index times a negative
    // increment would convert the increment to size_t. The result would be
    // a wrapped-around offset, i.e. out-of-range pointer arithmetic (UB).
    if (beta!=1) {
        for (size_t i=0; i<m; ++i) {
            for (size_t j=0; j<n; ++j) {
                double *c = &C[(ptrdiff_t)i*incRowC + (ptrdiff_t)j*incColC];
                if (beta==0) {
                    *c = 0;     // assign, do not scale: 0*NaN is NaN
                } else {
                    *c *= beta;
                }
            }
        }
    }
    if (alpha==0) {
        return;
    }
    for (size_t j=0; j<n; ++j) {
        for (size_t l=0; l<k; ++l) {
            for (size_t i=0; i<m; ++i) {
                C[(ptrdiff_t)i*incRowC + (ptrdiff_t)j*incColC]
                    += alpha*A[(ptrdiff_t)i*incRowA + (ptrdiff_t)l*incColA]
                    *B[(ptrdiff_t)l*incRowB + (ptrdiff_t)j*incColB];
            }
        }
    }
}
/*
 * Scaled error of a computed dgemm result CSol against the reference CRef:
 *
 *   err = ||CSol - CRef||
 *         / (eps * (max(m,n,k)*|alpha|*||A||*||B|| + max(m,n)*|beta|*||C0||))
 *
 * where ||.|| is the norm computed by dgenrm_inf() and eps is DBL_EPSILON.
 * C0 is the content of C before the product was applied.
 *
 * Side effect: CSol is overwritten via dgeaxpy(..., -1, CRef, ..., CSol, ...),
 * i.e. it holds CSol - CRef on return.
 *
 * Returns 0 when the difference is exactly zero. This also avoids 0/0 when
 * alpha and beta are both zero. Returns +Inf when the difference is nonzero
 * but the denominator is zero.
 */
double
dgemm_err(size_t m, size_t n, size_t k,
          double alpha,
          const double *A, ptrdiff_t incRowA, ptrdiff_t incColA,
          const double *B, ptrdiff_t incRowB, ptrdiff_t incColB,
          const double *C0, ptrdiff_t incRowC0, ptrdiff_t incColC0,
          double beta,
          const double *CRef, ptrdiff_t incRowCRef, ptrdiff_t incColCRef,
          double *CSol, ptrdiff_t incRowCSol, ptrdiff_t incColCSol)
{
    // An operand scaled by a zero scalar does not enter the bound.
    // Its norm is therefore not computed and the operand is not read.
    // C0 must be read with its OWN strides (not those of CRef/CSol).
    double nrmC0 = beta==0 ? 0 : dgenrm_inf(m, n, C0, incRowC0, incColC0);
    double nrmA = alpha==0 ? 0 : dgenrm_inf(m, k, A, incRowA, incColA);
    double nrmB = alpha==0 ? 0 : dgenrm_inf(k, n, B, incRowB, incColB);

    size_t maxMN = m<n ? n : m;
    size_t maxMNK = maxMN<k ? k : maxMN;

    // CSol <- CSol - CRef
    dgeaxpy(m, n, -1,
            CRef, incRowCRef, incColCRef,
            CSol, incRowCSol, incColCSol);
    double nrmDiff = dgenrm_inf(m, n, CSol, incRowCSol, incColCSol);
    if (nrmDiff==0) {
        return 0;
    }
    return nrmDiff /
        (DBL_EPSILON*(maxMNK*fabs(alpha)*nrmA*nrmB + maxMN*fabs(beta)*nrmC0));
}
/*
 * Parameter values enumerated by get_next_param() for check_dgemm().
 * Each list can be predefined (e.g. with -D on the compiler command line)
 * to override the default given here.
 */
/* Values tried for each of m, n and k. */
#ifndef CHECK_DIM
#define CHECK_DIM 5, 33, 23, 0, 1
#endif
/* Base increments combined into the row/column strides of A, B and C. */
#ifndef CHECK_INC
#define CHECK_INC 1, -2
#endif
/* Values tried for each of alpha and beta. */
#ifndef CHECK_SCALAR
#define CHECK_SCALAR -1.5, 0, 1
#endif
/*
 * Compute the row and column increments of a rows x cols matrix from the
 * base increments incR and incC.
 *
 * Column-major layout: incR is the row increment and rows*|incR*incC| is
 * the column increment. Row-major layout: incC is the column increment and
 * cols*|incR*incC| is the row increment.
 */
static void
get_next_param_strides(bool colMajor, size_t rows, size_t cols,
                       ptrdiff_t incR, ptrdiff_t incC,
                       ptrdiff_t *incRow, ptrdiff_t *incCol)
{
    if (colMajor) {
        *incRow = incR;
        *incCol = rows*ptrdiff_abs(incR*incC);
    } else {
        *incRow = cols*ptrdiff_abs(incR*incC);
        *incCol = incC;
    }
}

/*
 * Stateful iterator over the parameter space exercised by check_dgemm().
 *
 * Each call writes one parameter combination to the output arguments. It
 * then advances its internal (static) state like an odometer. The order of
 * the wheels, fastest first, is:
 *   1. the storage orders of A, B and C;
 *   2. the base increments (from CHECK_INC) of A, B and C;
 *   3. m, n and k (from CHECK_DIM);
 *   4. alpha and beta (from CHECK_SCALAR).
 *
 * Pass reset=true to restart at the first combination.
 * Returns true if further combinations follow the one just written.
 * Returns false if it was the last one; the state has then wrapped back to
 * the first combination.
 */
bool
get_next_param(bool reset,
               size_t *m, size_t *n, size_t *k,
               double *alpha,
               ptrdiff_t *incRowA, ptrdiff_t *incColA,
               ptrdiff_t *incRowB, ptrdiff_t *incColB,
               double *beta,
               ptrdiff_t *incRowC, ptrdiff_t *incColC)
{
    static size_t dim[] = {CHECK_DIM};
    static ptrdiff_t inc[] = {CHECK_INC};
    static double scalar[] = {CHECK_SCALAR};

    // Iterator state; static objects start out zero/false.
    static bool colMajorA, colMajorB, colMajorC;
    static size_t idx_incRowA, idx_incColA;
    static size_t idx_incRowB, idx_incColB;
    static size_t idx_incRowC, idx_incColC;
    static size_t idx_m, idx_n, idx_k;
    static size_t idx_alpha, idx_beta;

    const size_t numDim = sizeof(dim)/sizeof(size_t);
    const size_t numInc = sizeof(inc)/sizeof(ptrdiff_t);
    const size_t numScalar = sizeof(scalar)/sizeof(double);

    // Binary wheels (layout flags), fastest first ...
    bool *flag[] = {&colMajorA, &colMajorB, &colMajorC};
    const size_t numFlag = sizeof(flag)/sizeof(flag[0]);

    // ... followed by the index wheels and their number of positions.
    size_t *wheel[] = {
        &idx_incRowA, &idx_incColA,
        &idx_incRowB, &idx_incColB,
        &idx_incRowC, &idx_incColC,
        &idx_m, &idx_n, &idx_k,
        &idx_alpha, &idx_beta,
    };
    const size_t wheelSize[] = {
        numInc, numInc,
        numInc, numInc,
        numInc, numInc,
        numDim, numDim, numDim,
        numScalar, numScalar,
    };
    const size_t numWheel = sizeof(wheel)/sizeof(wheel[0]);

    if (reset) {
        for (size_t f=0; f<numFlag; ++f) {
            *flag[f] = false;
        }
        for (size_t w=0; w<numWheel; ++w) {
            *wheel[w] = 0;
        }
    }

    // Emit the current combination.
    *m = dim[idx_m];
    *n = dim[idx_n];
    *k = dim[idx_k];
    *alpha = scalar[idx_alpha];
    *beta = scalar[idx_beta];
    get_next_param_strides(colMajorA, dim[idx_m], dim[idx_k],
                           inc[idx_incRowA], inc[idx_incColA],
                           incRowA, incColA);
    get_next_param_strides(colMajorB, dim[idx_k], dim[idx_n],
                           inc[idx_incRowB], inc[idx_incColB],
                           incRowB, incColB);
    get_next_param_strides(colMajorC, dim[idx_m], dim[idx_n],
                           inc[idx_incRowC], inc[idx_incColC],
                           incRowC, incColC);

    // Advance to the next combination.
    for (size_t f=0; f<numFlag; ++f) {
        if (!*flag[f]) {
            *flag[f] = true;
            return true;
        }
        *flag[f] = false;
    }
    for (size_t w=0; w<numWheel; ++w) {
        if (++*wheel[w] < wheelSize[w]) {
            return true;
        }
        *wheel[w] = 0;
    }
    return false;
}
/* Seed passed to srand() before the test/benchmark matrices are filled. */
#ifndef SEED_RAND
#define SEED_RAND 0
#endif
/* A result passes when the scaled error from dgemm_err() is below this. */
#ifndef TOL_ERR
#define TOL_ERR 2
#endif
/*
 * Return the address of element (0,0) of a rows x cols matrix stored in buf
 * with increments incRow/incCol.
 *
 * For a negative increment the origin is moved towards the end of buf. This
 * keeps every element X[i*incRow + j*incCol] inside a buffer of
 * 1 + rows*|incRow| + cols*|incCol| elements. The shift is computed in
 * ptrdiff_t, so a negative increment never wraps around as a size_t offset.
 */
static double *
check_dgemm_origin(double *buf, size_t rows, size_t cols,
                   ptrdiff_t incRow, ptrdiff_t incCol)
{
    double *X = buf;
    if (incRow<0) {
        X -= (ptrdiff_t)rows*incRow;
    }
    if (incCol<0) {
        X -= (ptrdiff_t)cols*incCol;
    }
    return X;
}

/*
 * Correctness test of dgemm().
 *
 * For every parameter combination produced by get_next_param():
 *   - randDGeMatrix() fills the operands;
 *   - the product is computed by dgemm_ref() and by dgemm();
 *   - the scaled error from dgemm_err() must be below TOL_ERR.
 *
 * One line is printed per test. At the first failure the operands are
 * printed and the run stops. All buffers are released in every case.
 *
 * Returns 0 if all tests passed, 1 otherwise.
 */
int
check_dgemm(void)
{
    size_t m, n, k;
    ptrdiff_t incRowA, incColA;
    ptrdiff_t incRowB, incColB;
    ptrdiff_t incRowC, incColC;
    double alpha, beta;
    size_t count = 0;
    bool more, pass = false, reset = true;

    srand(SEED_RAND);
    printf("%8s %8s %8s %8s %8s %8s %8s %8s %8s %8s %8s %8s %8s\n",
           "m", "n", "k",
           "alpha",
           "incRowA", "incColA", "incRowB", "incColB",
           "beta",
           "incRowC", "incColC",
           "err", "res");
    do {
        more = get_next_param(reset,
                              &m, &n, &k,
                              &alpha,
                              &incRowA, &incColA,
                              &incRowB, &incColB,
                              &beta,
                              &incRowC, &incColC);
        reset = false;

        size_t bufsize_A = 1 + m*ptrdiff_abs(incRowA) + k*ptrdiff_abs(incColA);
        size_t bufsize_B = 1 + k*ptrdiff_abs(incRowB) + n*ptrdiff_abs(incColB);
        size_t bufsize_C = 1 + m*ptrdiff_abs(incRowC) + n*ptrdiff_abs(incColC);

        double *buf_A = malloc(bufsize_A*sizeof(double));
        double *buf_B = malloc(bufsize_B*sizeof(double));
        double *buf_C0 = malloc(bufsize_C*sizeof(double));
        double *buf_CRef = malloc(bufsize_C*sizeof(double));
        double *buf_CSol = malloc(bufsize_C*sizeof(double));
        if (!buf_A || !buf_B || !buf_C0 || !buf_CRef || !buf_CSol) {
            abort();
        }

        double *A = check_dgemm_origin(buf_A, m, k, incRowA, incColA);
        double *B = check_dgemm_origin(buf_B, k, n, incRowB, incColB);
        double *C0 = check_dgemm_origin(buf_C0, m, n, incRowC, incColC);
        double *CRef = check_dgemm_origin(buf_CRef, m, n, incRowC, incColC);
        double *CSol = check_dgemm_origin(buf_CSol, m, n, incRowC, incColC);

        randDGeMatrix(m, k, alpha==0, A, incRowA, incColA);
        randDGeMatrix(k, n, alpha==0, B, incRowB, incColB);
        randDGeMatrix(m, n, beta==0, C0, incRowC, incColC);

        dgecopy(m, n, C0, incRowC, incColC, CRef, incRowC, incColC);
        dgemm_ref(m, n, k,
                  alpha,
                  A, incRowA, incColA,
                  B, incRowB, incColB,
                  beta,
                  CRef, incRowC, incColC);

        dgecopy(m, n, C0, incRowC, incColC, CSol, incRowC, incColC);
        dgemm(m, n, k,
              alpha,
              A, incRowA, incColA,
              B, incRowB, incColB,
              beta,
              CSol, incRowC, incColC);

        // Note: dgemm_err() overwrites CSol with CSol - CRef.
        double err = dgemm_err(m, n, k,
                               alpha,
                               A, incRowA, incColA,
                               B, incRowB, incColB,
                               C0, incRowC, incColC,
                               beta,
                               CRef, incRowC, incColC,
                               CSol, incRowC, incColC);
        pass = err<TOL_ERR;
        printf("%8zu %8zu %8zu %8.2lf %8td %8td %8td %8td %8.2lf %8td %8td "
               "%7.1e %8s\n",
               m, n, k,
               alpha,
               incRowA, incColA,
               incRowB, incColB,
               beta,
               incRowC, incColC,
               err, pass ? "PASS" : "FAILED");
        if (!pass) {
            printf("alpha = %8.2lf, beta = %8.2lf\n", alpha, beta);
            printf("C0=\n");
            printDGeMatrix(m, n, C0, incRowC, incColC);
            printf("A=\n");
            printDGeMatrix(m, k, A, incRowA, incColA);
            printf("B=\n");
            printDGeMatrix(k, n, B, incRowB, incColB);
            printf("CRef=\n");
            printDGeMatrix(m, n, CRef, incRowC, incColC);
            printf("CSol - CRef=\n");
            printfDGeMatrix("%e ", m, n, CSol, incRowC, incColC);
        }

        // Release the buffers on every path (previously leaked on failure).
        free(buf_A);
        free(buf_B);
        free(buf_C0);
        free(buf_CRef);
        free(buf_CSol);

        if (pass) {
            ++count;
        }
    } while (pass && more);

    if (pass) {
        printf("Passed all %zu tests.\n", count);
    }
    return !pass;
}
#ifndef COLMAJOR
#define COLMAJOR 1
#endif
#ifndef MAX_M
#define MAX_M 1000
#endif
#ifndef MAX_N
#define MAX_N 1000
#endif
#ifndef MAX_K
#define MAX_K 1000
#endif
#ifndef ALPHA
#define ALPHA 1
#endif
#ifndef BETA
#define BETA 1
#endif
/*
 * Benchmark dgemm() against dgemm_ref() for the square problem sizes
 * m = n = k = 100, 200, ... while they stay within MAX_M, MAX_N and MAX_K.
 *
 * argv holds only the options (main passes the arguments after "bench"):
 *   colmajorX | rowmajorX    storage order of matrix X (X = A, B or C)
 *   alpha <value>            scalar alpha (default ALPHA)
 *   beta <value>             scalar beta  (default BETA)
 * Unrecognized arguments are ignored.
 *
 * Each routine is timed with walltime() over at least 3 runs and at least
 * 0.1 s in total. The average time and MFLOPS are printed, followed by the
 * scaled error of dgemm() with respect to dgemm_ref().
 *
 * Returns 0 if every problem size passed the error check and nonzero
 * otherwise. Returns 1 for a missing or non-numeric alpha/beta value.
 */
int
bench_dgemm(int argc, char **argv)
{
    bool colmajorA = COLMAJOR;
    bool colmajorB = COLMAJOR;
    bool colmajorC = COLMAJOR;
    double alpha = ALPHA;
    double beta = BETA;

    for (int i=0; i<argc; ++i) {
        if (!strcmp(argv[i], "colmajorA")) {
            colmajorA = true;
        } else if (!strcmp(argv[i], "rowmajorA")) {
            colmajorA = false;
        } else if (!strcmp(argv[i], "colmajorB")) {
            colmajorB = true;
        } else if (!strcmp(argv[i], "rowmajorB")) {
            colmajorB = false;
        } else if (!strcmp(argv[i], "colmajorC")) {
            colmajorC = true;
        } else if (!strcmp(argv[i], "rowmajorC")) {
            colmajorC = false;
        } else if (!strcmp(argv[i], "alpha") || !strcmp(argv[i], "beta")) {
            // The value must not be parsed inside assert(): with NDEBUG
            // defined the sscanf() call is compiled out and the value is
            // silently ignored.
            double *value = !strcmp(argv[i], "alpha") ? &alpha : &beta;
            if (i+1>=argc || sscanf(argv[i+1], "%lf", value)!=1) {
                fprintf(stderr, "bench: '%s' requires a numeric value\n",
                        argv[i]);
                return 1;
            }
            ++i;    // the value argument has been consumed
        }
    }

    srand(SEED_RAND);
    printf("#colmajorA = %d\n", colmajorA);
    printf("#colmajorB = %d\n", colmajorB);
    printf("#colmajorC = %d\n", colmajorC);
    printf("#alpha = %lf\n", alpha);
    printf("#beta = %lf\n", beta);

    // Sizes are computed in size_t, so larger MAX_* values cannot overflow int.
    double *A = malloc((size_t)(MAX_M)*(MAX_K)*sizeof(double));
    double *B = malloc((size_t)(MAX_K)*(MAX_N)*sizeof(double));
    double *C0 = malloc((size_t)(MAX_M)*(MAX_N)*sizeof(double));
    double *CRef = malloc((size_t)(MAX_M)*(MAX_N)*sizeof(double));
    double *CSol = malloc((size_t)(MAX_M)*(MAX_N)*sizeof(double));
    if (!A || !B || !C0 || !CRef || !CSol) {
        abort();
    }

    printf("#%4s %4s %4s", "m", "n", "k");
    printf("%10s %10s ", "time ref", "mflops ref");
    printf("%10s %10s %7s ", "time 1", "mflops 1", "err");
    printf("\n");

    size_t count = 0, count_passed = 0;
    for (size_t m=100, n=100, k=100;
         m<=MAX_M && n<=MAX_N && k<=MAX_K;
         m+=100, n+=100, k+=100)
    {
        ptrdiff_t incRowA = colmajorA ? 1 : k;
        ptrdiff_t incColA = colmajorA ? m : 1;
        ptrdiff_t incRowB = colmajorB ? 1 : n;
        ptrdiff_t incColB = colmajorB ? k : 1;
        ptrdiff_t incRowC = colmajorC ? 1 : n;
        ptrdiff_t incColC = colmajorC ? m : 1;
        double mflop = 2.*m*k*n/1000000;

        randDGeMatrix(m, k, alpha==0, A, incRowA, incColA);
        randDGeMatrix(k, n, alpha==0, B, incRowB, incColB);
        randDGeMatrix(m, n, beta==0, C0, incRowC, incColC);

        printf(" %4zu %4zu %4zu ", m, n, k);

        // Reference implementation.
        {
            double t = 0;
            size_t runs = 0;
            while (t<0.1 || runs<3) {
                dgecopy(m, n, C0, incRowC, incColC, CRef, incRowC, incColC);
                double t0 = walltime();
                dgemm_ref(m, n, k,
                          alpha,
                          A, incRowA, incColA,
                          B, incRowB, incColB,
                          beta,
                          CRef, incRowC, incColC);
                t += walltime() - t0;
                ++runs;
            }
            t /= runs;
            printf("%10.2lf %10.2lf ", t, mflop/t);
        }

        // Implementation under test.
        {
            double t = 0;
            size_t runs = 0;
            while (t<0.1 || runs<3) {
                dgecopy(m, n, C0, incRowC, incColC, CSol, incRowC, incColC);
                double t0 = walltime();
                dgemm(m, n, k,
                      alpha,
                      A, incRowA, incColA,
                      B, incRowB, incColB,
                      beta,
                      CSol, incRowC, incColC);
                t += walltime() - t0;
                ++runs;
            }
            t /= runs;
            double err = dgemm_err(m, n, k,
                                   alpha,
                                   A, incRowA, incColA,
                                   B, incRowB, incColB,
                                   C0, incRowC, incColC,
                                   beta,
                                   CRef, incRowC, incColC,
                                   CSol, incRowC, incColC);
            bool pass = err<TOL_ERR;
            if (pass) {
                ++count_passed;
            }
            ++count;
            printf("%10.2lf %10.2lf %7.1e ", t, mflop/t, err);
        }
        printf("\n");
        fflush(stdout);
    }
    printf("Passed %zu of %zu tests.\n", count_passed, count);

    free(A);
    free(B);
    free(C0);
    free(CRef);
    free(CSol);
    return count!=count_passed;
}
/*
 * Entry point.
 *   <prog> check              run the correctness tests (check_dgemm)
 *   <prog> bench [options]    run the benchmark (bench_dgemm)
 * Exit status is 1 on a usage error. Otherwise it is the status returned
 * by the selected mode.
 */
int
main(int argc, char **argv)
{
    // argv[0] is a null pointer when argc==0; never hand that to "%s".
    const char *prog = argc>0 && argv[0] ? argv[0] : "dgemm";

    if (argc>=2 && !strcmp(argv[1], "check")) {
        return check_dgemm();
    }
    if (argc>=2 && !strcmp(argv[1], "bench")) {
        // bench_dgemm() only sees the options following "bench".
        return bench_dgemm(argc-2, argv+2);
    }

    // The option names listed here must match those parsed by bench_dgemm().
    fprintf(stderr, "usage:\n");
    fprintf(stderr, " %s check\n", prog);
    fprintf(stderr, " %s bench [colmajorA | rowmajorA]"
                    " [colmajorB | rowmajorB]"
                    " [colmajorC | rowmajorC]"
                    " [alpha <value>]"
                    " [beta <value>]"
                    "\n", prog);
    return 1;
}