#include <ulmaux.h>
#include <ulmblas.h>
#include <assert.h>
#include <stdlib.h>
#include <stdbool.h>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include <string.h>
/* Name of the triangular-solve routine that check_dtrsm() invokes.
   Define DTRSM before this point (e.g. on the compiler command line) to
   check a different implementation with the same argument list. */
#ifndef DTRSM
#define DTRSM dtrsm
#endif
/* Default seed handed to srand() when no seed is given on the command line. */
#ifndef SEED_RAND
#define SEED_RAND 0
#endif
/* A test case passes when the scaled error returned by dtrsm_err() is
   strictly below this threshold. */
#ifndef TOL_ERR
#define TOL_ERR 1
#endif
/*
 * Reference strided vector copy: y[i*incY] = x[i*incX] for i = 0, ..., n-1.
 *
 *  n     number of elements to copy (nothing happens for n == 0)
 *  x     source vector; element i is read from x[i*incX]
 *  incX  stride between consecutive source elements (may be negative)
 *  y     destination vector; element i is written to y[i*incY]
 *  incY  stride between consecutive destination elements (may be negative)
 */
void
dcopy_ref(size_t n,
          const double *x, ptrdiff_t incX,
          double *y, ptrdiff_t incY)
{
    for (size_t i=0; i<n; ++i) {
        // Compute the offsets as signed values.  Multiplying the size_t
        // index directly with a negative ptrdiff_t stride converts the
        // stride to unsigned and yields a huge wrapped-around offset.
        const ptrdiff_t k = (ptrdiff_t)i;

        y[k*incY] = x[k*incX];
    }
}
/*
 * Reference scaling of an m x n matrix: calls dscal() once per column.
 *
 *  m, n     number of rows / columns of X
 *  alpha    scaling factor handed to dscal()
 *  X        matrix; element (i,j) is located at X[i*incRowX + j*incColX]
 *  incRowX  stride between consecutive rows (passed to dscal() as the
 *           vector stride)
 *  incColX  stride between consecutive columns (may be negative)
 *
 * NOTE(review): dscal() is declared outside this file; it is presumably an
 * in-place "x <- alpha*x" on a strided vector of length m -- confirm.
 */
void
dgescal_ref(size_t m, size_t n, double alpha,
            double *X, ptrdiff_t incRowX, ptrdiff_t incColX)
{
    for (size_t j=0; j<n; ++j) {
        // Signed column offset: size_t * negative ptrdiff_t would be
        // evaluated in unsigned arithmetic and wrap around.
        const ptrdiff_t colOffset = (ptrdiff_t)j*incColX;

        dscal(m, alpha, &X[colOffset], incRowX);
    }
}
/*
 * Reference general matrix-matrix product
 *
 *      C <- beta*C + alpha*A*B
 *
 * with A of size m x k, B of size k x n and C of size m x n.  Element (i,j)
 * of a matrix M is located at M[i*incRowM + j*incColM]; all strides may be
 * negative.
 *
 * Special cases:
 *  - m == 0 or n == 0: nothing is done.
 *  - beta == 1 and (alpha == 0 or k == 0): C is left untouched.
 *  - beta == 0: C is overwritten with zeros before accumulating, so its
 *    previous content (even NaN) does not propagate.
 *  - alpha == 0: only the scaling of C is performed; A and B are not read.
 */
void
dgemm_ref(size_t m, size_t n, size_t k,
          double alpha,
          const double *A, ptrdiff_t incRowA, ptrdiff_t incColA,
          const double *B, ptrdiff_t incRowB, ptrdiff_t incColB,
          double beta,
          double *C, ptrdiff_t incRowC, ptrdiff_t incColC)
{
    if (m==0 || n==0 || ((alpha==0 || k==0) && beta==1)) {
        return;
    }

    // Signed loop bounds and indices: mixing size_t indices with negative
    // ptrdiff_t strides would evaluate the offsets in unsigned arithmetic.
    const ptrdiff_t numRows  = (ptrdiff_t)m;
    const ptrdiff_t numCols  = (ptrdiff_t)n;
    const ptrdiff_t numInner = (ptrdiff_t)k;

    if (beta!=1) {
        for (ptrdiff_t i=0; i<numRows; ++i) {
            for (ptrdiff_t j=0; j<numCols; ++j) {
                double *c = &C[i*incRowC + j*incColC];

                if (beta!=0) {
                    *c *= beta;
                } else {
                    // Assign instead of multiplying so that NaN/Inf in C
                    // are cleared.
                    *c = 0;
                }
            }
        }
    }

    if (alpha==0) {
        return;
    }

    for (ptrdiff_t j=0; j<numCols; ++j) {
        for (ptrdiff_t l=0; l<numInner; ++l) {
            for (ptrdiff_t i=0; i<numRows; ++i) {
                C[i*incRowC + j*incColC] += alpha*A[i*incRowA + l*incColA]
                                                 *B[l*incRowB + j*incColB];
            }
        }
    }
}
/*
 * Reference triangular matrix-matrix product
 *
 *      Y <- alpha*T*Y
 *
 * where T is the m x m triangular matrix stored in A_ and Y is m x n.
 *
 *  lower  if true the part below the diagonal of A_ is used, otherwise the
 *         part above the diagonal; the opposite part is treated as zero and
 *         never read
 *  unit   if true the diagonal of T is taken to be 1 and the diagonal of A_
 *         is never read
 *  alpha  scaling factor
 *  A_     triangular matrix, element (i,j) at A_[i*incRowA + j*incColA]
 *  Y      on entry the right-hand factor, on exit the product; element (i,j)
 *         at Y[i*incRowY + j*incColY]
 *
 * The product is formed by expanding T into a dense column-major matrix,
 * copying Y into a dense column-major buffer and calling dgemm_ref().
 * Aborts the program if the temporary buffers cannot be allocated.
 */
void
dtrmm_ref(size_t m, size_t n, bool lower, bool unit, double alpha,
          const double *A_, ptrdiff_t incRowA, ptrdiff_t incColA,
          double *Y, ptrdiff_t incRowY, ptrdiff_t incColY)
{
    // Empty problem: Y has no elements, so there is nothing to update.
    // Returning here also avoids malloc(0), which is allowed to return NULL
    // and would otherwise be misreported as an allocation failure.
    if (m==0 || n==0) {
        return;
    }

    // Refuse sizes for which m*m or m*n doubles overflow size_t.
    const size_t maxElems = ((size_t)-1)/sizeof(double);
    if (m>maxElems/m || n>maxElems/m) {
        abort();
    }

    double *A = malloc(m*m*sizeof *A);
    double *X = malloc(m*n*sizeof *X);
    if (!A || !X) {
        abort();
    }

    // Signed dimensions so that offsets with negative strides are computed
    // in signed arithmetic.
    const ptrdiff_t numRows = (ptrdiff_t)m;
    const ptrdiff_t numCols = (ptrdiff_t)n;
    const ptrdiff_t ldA = numRows;
    const ptrdiff_t ldX = numRows;

    // Expand the referenced triangle of A_ into the dense matrix A.
    for (ptrdiff_t i=0; i<numRows; ++i) {
        for (ptrdiff_t j=0; j<numRows; ++j) {
            double value = 0;

            if (i==j) {
                value = unit ? 1 : A_[i*incRowA + j*incColA];
            } else if (lower ? j<i : j>i) {
                value = A_[i*incRowA + j*incColA];
            }
            A[i + j*ldA] = value;
        }
    }

    // Y is both input and output, so multiply from a copy.
    for (ptrdiff_t i=0; i<numRows; ++i) {
        for (ptrdiff_t j=0; j<numCols; ++j) {
            X[i + j*ldX] = Y[i*incRowY + j*incColY];
        }
    }

    // beta == 0: Y is overwritten, its old content does not enter the result.
    dgemm_ref(m, n, m, alpha, A, 1, ldA, X, 1, ldX, 0, Y, incRowY, incColY);

    free(A);
    free(X);
}
/*
 * Scaled error of a computed triangular-solve result.
 *
 *  m, n    dimensions of B, X0 and X (the triangular matrix is m x m)
 *  lower   if true the lower triangle of A_ is used, otherwise the upper one
 *  unit    if true the diagonal is taken to be 1 instead of being read
 *  alpha   scaling factor that was used for the solve
 *  A_      triangular matrix, element (i,j) at A_[i*incRowA + j*incColA]
 *  B       right-hand side
 *  X0      expected solution
 *  X       computed solution; MODIFIED by the dgeaxpy() call below
 *
 * Returns
 *
 *      nrmD / (nrmA * nrmB * max(m,n) * DBL_EPSILON)
 *
 * where nrmD is dgenrm_inf() of X after the dgeaxpy() update, nrmA is the
 * norm of the referenced triangle (at least 1, additionally multiplied by
 * |alpha| if |alpha| > 1) and nrmB is the norm of B (at least 1).
 * Returns 0 for empty problems or when nrmD is exactly 0, and NaN when nrmD
 * is NaN.  Aborts if the scratch matrix cannot be allocated.
 *
 * NOTE(review): dgeaxpy() and dgenrm_inf() are declared outside this file;
 * presumably dgeaxpy(m, n, -1, X0, ..., X, ...) computes X <- X - X0 and
 * dgenrm_inf() is the infinity norm -- confirm.
 */
double
dtrsm_err(size_t m, size_t n, bool lower, bool unit,
          double alpha,
          const double *A_, ptrdiff_t incRowA, ptrdiff_t incColA,
          const double *B, ptrdiff_t incRowB, ptrdiff_t incColB,
          const double *X0, ptrdiff_t incRowX0, ptrdiff_t incColX0,
          double *X, ptrdiff_t incRowX, ptrdiff_t incColX)
{
    if (m==0 || n==0) {
        return 0;
    }

    dgeaxpy(m, n, -1, X0, incRowX0, incColX0, X, incRowX, incColX);

    const double nrmD = dgenrm_inf(m, n, X, incRowX, incColX);
    if (isnan(nrmD)) {
        return nrmD;
    }
    if (nrmD==0) {
        return 0;
    }

    // Refuse sizes for which m*m doubles overflow size_t.
    if (m>((size_t)-1)/sizeof(double)/m) {
        abort();
    }
    double *A = malloc(m*m*sizeof *A);
    if (!A) {
        abort();
    }

    // Expand the referenced triangle into a dense column-major matrix so
    // that only the entries actually used by the solve enter the norm.
    // Signed indices keep offsets with negative strides in signed arithmetic.
    const ptrdiff_t dimA = (ptrdiff_t)m;
    const ptrdiff_t ldA = dimA;
    for (ptrdiff_t i=0; i<dimA; ++i) {
        for (ptrdiff_t j=0; j<dimA; ++j) {
            double value = 0;

            if (i==j) {
                value = unit ? 1 : A_[i*incRowA + j*incColA];
            } else if (lower ? j<i : j>i) {
                value = A_[i*incRowA + j*incColA];
            }
            A[i + j*ldA] = value;
        }
    }

    double nrmA = dgenrm_inf(m, m, A, 1, ldA);
    free(A);

    double nrmB = dgenrm_inf(m, n, B, incRowB, incColB);

    // Clamp the norms from below so that tiny operands do not inflate the
    // scaled error.
    if (nrmA<1) {
        nrmA = 1;
    }
    if (fabs(alpha)>1) {
        nrmA *= fabs(alpha);
    }
    if (nrmB<1) {
        nrmB = 1;
    }

    const size_t mn = m>n ? m : n;

    return nrmD / (nrmA*nrmB*mn*DBL_EPSILON);
}
/*
 * Rescales the n x n matrix A in place so that its referenced triangle has
 * a dominant diagonal.
 *
 *  n      dimension of A
 *  lower  if true the strictly lower part of each row is considered,
 *         otherwise the strictly upper part
 *  unit   selects how the row is adjusted (see below)
 *  A      matrix, element (i,j) at A[i*incRowA + j*incColA]; strides may be
 *         negative
 *
 * For every row i let s be the sum of the absolute values of the
 * off-diagonal entries in the selected triangle (s is replaced by 1 if it
 * is exactly 0).
 *  - unit == true:  those off-diagonal entries are multiplied by 1/s; the
 *    diagonal entry is not touched.
 *  - unit == false: 1 + 2*s is added to the diagonal entry; the
 *    off-diagonal entries are not touched.
 * Entries outside the selected triangle are never read or written.
 */
void
makeDDiagDom(size_t n, bool lower, bool unit,
             double *A, ptrdiff_t incRowA, ptrdiff_t incColA)
{
    // Signed dimension/indices: size_t * negative ptrdiff_t stride would be
    // evaluated in unsigned arithmetic and wrap around.
    const ptrdiff_t dim = (ptrdiff_t)n;

    for (ptrdiff_t i=0; i<dim; ++i) {
        // Column range [first, last) of the strict triangle in row i.
        const ptrdiff_t first = lower ? 0 : i+1;
        const ptrdiff_t last  = lower ? i : dim;

        double sum = 0;
        for (ptrdiff_t j=first; j<last; ++j) {
            sum += fabs(A[i*incRowA + j*incColA]);
        }
        if (sum==0) {
            sum = 1;
        }

        if (unit) {
            const double scale = 1/sum;

            for (ptrdiff_t j=first; j<last; ++j) {
                A[i*incRowA + j*incColA] *= scale;
            }
        } else {
            // sum is a sum of absolute values, hence already non-negative.
            A[i*incRowA + i*incColA] += 1 + 2*sum;
        }
    }
}
/* Dimensions tried for both m and n. */
#ifndef CHECK_DIM
#define CHECK_DIM 5, 7, 33, 23, 0, 1
#endif
/* Element increments tried for the rows and columns of A and X. */
#ifndef CHECK_INC
#define CHECK_INC 1, -2
#endif
/* Values tried for alpha. */
#ifndef CHECK_SCALAR
#define CHECK_SCALAR 1, -1.5, 0
#endif
/*
 * Enumerates the test configurations for check_dtrsm().
 *
 * Each call stores the current configuration in the output arguments and
 * then advances an internal (static) counter to the next configuration.
 *
 *  reset             if true, enumeration restarts from the first
 *                    configuration before the outputs are written
 *  m, n              dimensions, taken from CHECK_DIM
 *  lower, unit       triangle / diagonal flags
 *  alpha             scalar, taken from CHECK_SCALAR
 *  incRowA, incColA  strides of the m x m matrix A
 *  incRowX, incColX  strides of the m x n matrix X
 *
 * For a "column major" configuration the row stride is an entry of
 * CHECK_INC and the column stride is the leading dimension scaled by the
 * absolute product of the two selected increments; for "row major" the
 * roles are swapped.
 *
 * Returns true if further configurations follow, and false when the
 * configuration just written was the last one (the counter has then wrapped
 * around to the first configuration).
 *
 * The counter works like an odometer with the digits, fastest first:
 * colMajorA, colMajorX, lower, unit, incRowA, incColA, incRowX, incColX,
 * m, n, alpha.
 */
bool
get_next_param(bool reset, size_t *m, size_t *n, bool *lower, bool *unit,
               double *alpha,
               ptrdiff_t *incRowA, ptrdiff_t *incColA,
               ptrdiff_t *incRowX, ptrdiff_t *incColX)
{
    static const size_t    dim[]    = {CHECK_DIM};
    static const ptrdiff_t inc[]    = {CHECK_INC};
    static const double    scalar[] = {CHECK_SCALAR};

    const size_t numDim    = sizeof(dim)/sizeof(dim[0]);
    const size_t numInc    = sizeof(inc)/sizeof(inc[0]);
    const size_t numScalar = sizeof(scalar)/sizeof(scalar[0]);

    static bool   colMajorA;
    static bool   colMajorX;
    static bool   lower_;
    static bool   unit_;
    static size_t idx_m;
    static size_t idx_n;
    static size_t idx_alpha;
    static size_t idx_incRowA;
    static size_t idx_incColA;
    static size_t idx_incRowX;
    static size_t idx_incColX;

    if (reset) {
        // Every digit starts at its first value.  The boolean digits below
        // step false -> true -> carry, so all of them (including lower_)
        // have to start at false; starting lower_ at true would skip the
        // upper/non-unit cases for the first set of dimensions, increments
        // and alpha.
        colMajorA   = false;
        colMajorX   = false;
        lower_      = false;
        unit_       = false;
        idx_m       = 0;
        idx_n       = 0;
        idx_alpha   = 0;
        idx_incRowA = 0;
        idx_incColA = 0;
        idx_incRowX = 0;
        idx_incColX = 0;
    }

    // Report the current configuration.
    *m     = dim[idx_m];
    *n     = dim[idx_n];
    *alpha = scalar[idx_alpha];
    *lower = lower_;
    *unit  = unit_;

    if (colMajorA) {
        *incRowA = inc[idx_incRowA];
        *incColA = dim[idx_m]*ptrdiff_abs(inc[idx_incRowA]*inc[idx_incColA]);
    } else {
        *incRowA = dim[idx_m]*ptrdiff_abs(inc[idx_incRowA]*inc[idx_incColA]);
        *incColA = inc[idx_incColA];
    }
    if (colMajorX) {
        *incRowX = inc[idx_incRowX];
        *incColX = dim[idx_m]*ptrdiff_abs(inc[idx_incRowX]*inc[idx_incColX]);
    } else {
        *incRowX = dim[idx_n]*ptrdiff_abs(inc[idx_incRowX]*inc[idx_incColX]);
        *incColX = inc[idx_incColX];
    }

    // Advance to the next configuration: bump the fastest digit that has
    // not reached its last value, wrapping all faster digits to their first.
    if (!colMajorA) {
        colMajorA = true;
        return true;
    }
    colMajorA = false;

    if (!colMajorX) {
        colMajorX = true;
        return true;
    }
    colMajorX = false;

    if (!lower_) {
        lower_ = true;
        return true;
    }
    lower_ = false;

    if (!unit_) {
        unit_ = true;
        return true;
    }
    unit_ = false;

    if (++idx_incRowA < numInc) {
        return true;
    }
    idx_incRowA = 0;

    if (++idx_incColA < numInc) {
        return true;
    }
    idx_incColA = 0;

    if (++idx_incRowX < numInc) {
        return true;
    }
    idx_incRowX = 0;

    if (++idx_incColX < numInc) {
        return true;
    }
    idx_incColX = 0;

    if (++idx_m < numDim) {
        return true;
    }
    idx_m = 0;

    if (++idx_n < numDim) {
        return true;
    }
    idx_n = 0;

    if (++idx_alpha < numScalar) {
        return true;
    }
    idx_alpha = 0;

    // Every digit wrapped: the configuration reported above was the last.
    return false;
}
/*
 * Runs DTRSM over every configuration produced by get_next_param() and
 * compares the result against a reference solution.
 *
 *  argc, argv  remaining command line arguments; the word "seed" followed
 *              by an unsigned integer overrides SEED_RAND
 *
 * For each configuration one line with the parameters, the scaled error
 * from dtrsm_err() and PASS/FAILED is printed.  On the first failure the
 * matrices involved are printed as well and the run stops.
 *
 * Returns 0 if all configurations passed and 1 if one failed or the seed
 * argument is missing/malformed.  Aborts if a buffer cannot be allocated.
 */
int
check_dtrsm(int argc, char **argv)
{
    unsigned seed_rand = SEED_RAND;

    for (int i=0; i<argc; ++i) {
        if (!strcmp(argv[i], "seed")) {
            // Validate explicitly: wrapping these checks in assert() would
            // remove both the bounds check and the sscanf() call itself
            // when compiled with NDEBUG.
            if (i+1>=argc || sscanf(argv[i+1], "%u", &seed_rand)!=1) {
                fprintf(stderr, "seed: expected an unsigned integer value\n");
                return 1;
            }
        }
    }
    srand(seed_rand);

    printf("%8s %8s %8s %8s %8s %8s %8s %8s %8s %8s %8s\n",
           "m", "n", "lower", "unit", "alpha", "incRowA", "incColA",
           "incRowX", "incColX", "err", "res");

    bool more;
    bool reset = true;
    bool pass = true;

    do {
        size_t m, n;
        double alpha;
        ptrdiff_t incRowA, incColA, incRowX, incColX;
        bool lower, unit;

        more = get_next_param(reset, &m, &n, &lower, &unit, &alpha,
                              &incRowA, &incColA,
                              &incRowX, &incColX);
        reset = false;

        // Generous upper bounds for the strided storage; always >= 1.
        size_t bufsize_A = 1 + m*ptrdiff_abs(incRowA) + m*ptrdiff_abs(incColA);
        size_t bufsize_X = 1 + m*ptrdiff_abs(incRowX) + n*ptrdiff_abs(incColX);

        double *buf_A  = malloc(bufsize_A*sizeof *buf_A);
        double *buf_X0 = malloc(bufsize_X*sizeof *buf_X0);
        double *buf_B  = malloc(bufsize_X*sizeof *buf_B);
        double *buf_X  = malloc(bufsize_X*sizeof *buf_X);

        if (!buf_A || !buf_X0 || !buf_B || !buf_X) {
            printf("malloc failed.\n");
            abort();
        }

        // NOTE(review): dfill_nan() is declared outside this file;
        // presumably it fills the buffer with NaN so that stray reads of
        // padding elements show up in the error -- confirm.
        dfill_nan(bufsize_A, buf_A);
        dfill_nan(bufsize_X, buf_X0);
        dfill_nan(bufsize_X, buf_B);
        dfill_nan(bufsize_X, buf_X);

        double *A  = buf_A;
        double *X0 = buf_X0;
        double *B  = buf_B;
        double *X  = buf_X;

        // With a negative stride the elements lie before element (0,0), so
        // move the base pointer into the buffer.  The shift is computed as
        // a signed value; m*incRowA with a size_t m would be evaluated in
        // unsigned arithmetic.
        const ptrdiff_t numRows = (ptrdiff_t)m;
        const ptrdiff_t numCols = (ptrdiff_t)n;

        if (incRowA<0) {
            A -= numRows*incRowA;
        }
        if (incColA<0) {
            A -= numRows*incColA;
        }
        if (incRowX<0) {
            X0 -= numRows*incRowX;
            B  -= numRows*incRowX;
            X  -= numRows*incRowX;
        }
        if (incColX<0) {
            X0 -= numCols*incColX;
            B  -= numCols*incColX;
            X  -= numCols*incColX;
        }

        randDGeMatrix(m, m, false, A, incRowA, incColA);
        makeDDiagDom(m, lower, unit, A, incRowA, incColA);

        if (alpha!=0) {
            // Pick the solution X0 first and derive the right-hand side
            // from it: B = (1/alpha)*T*X0.
            randDGeMatrix(m, n, false, X0, incRowX, incColX);
            dgecopy(m, n, X0, incRowX, incColX, B, incRowX, incColX);
            dtrmm_ref(m, n, lower, unit, 1/alpha, A, incRowA, incColA,
                      B, incRowX, incColX);
        } else {
            // alpha == 0: X0 is scaled by 0 (presumably giving the zero
            // matrix as expected solution) and B gets random content.
            dgescal_ref(m, n, 0, X0, incRowX, incColX);
            randDGeMatrix(m, n, true, B, incRowX, incColX);
        }

        // DTRSM works in place, so run it on a copy of B.
        dgecopy(m, n, B, incRowX, incColX, X, incRowX, incColX);
        DTRSM(m, n, lower, unit, alpha, A, incRowA, incColA,
              X, incRowX, incColX);

        // Note: dtrsm_err() modifies X (it is printed as "X - X0" below).
        double err = dtrsm_err(m, n, lower, unit,
                               alpha,
                               A, incRowA, incColA,
                               B, incRowX, incColX,
                               X0, incRowX, incColX,
                               X, incRowX, incColX);

        // A NaN error compares false and therefore counts as a failure.
        pass = err<TOL_ERR;

        if (!pass) {
            printf("A =\n");
            printDGeMatrix(m, m, A, incRowA, incColA);
            printf("B =\n");
            printDGeMatrix(m, n, B, incRowX, incColX);
            printf("X0 =\n");
            printDGeMatrix(m, n, X0, incRowX, incColX);
            printf("X - X0 =\n");
            printfDGeMatrix("%e ", m, n, X, incRowX, incColX);
        }

        printf("%8zu %8zu %8s %8s %8.2lf %8td %8td %8td %8td %7.1e %8s\n",
               m, n,
               lower ? "true" : "false",
               unit ? "true" : "false",
               alpha,
               incRowA, incColA,
               incRowX, incColX,
               err, pass ? "PASS" : "FAILED");

        free(buf_A);
        free(buf_X0);
        free(buf_B);
        free(buf_X);

        if (!pass) {
            break;
        }
    } while (more);

    return !pass;
}
/*
 * Entry point.  The first argument selects the mode:
 *
 *  check  runs check_dtrsm() with the remaining arguments and returns its
 *         result
 *  bench  is accepted, but nothing is run for it in this function
 *
 * Anything else (or no argument at all) prints the usage to stderr and
 * returns 1.
 */
int
main(int argc, char **argv)
{
    const bool wantCheck = argc>=2 && strcmp(argv[1], "check")==0;
    const bool wantBench = argc>=2 && strcmp(argv[1], "bench")==0;

    if (!wantCheck && !wantBench) {
        fprintf(stderr, "usage:\n");
        fprintf(stderr, " %s check [seed]\n", argv[0]);
        return 1;
    }

    if (wantCheck) {
        // Hand over only the arguments that follow the mode word.
        return check_dtrsm(argc-2, argv+2);
    }

    // "bench": same result as reaching the end of main().
    return 0;
}