#include <cstdio>
#include <cassert>
#include <chrono>
#include <complex>
#include <cmath>
#include <limits>
#include <random>
#include "common.h"
#define BLASINT int64_t
//------------------------------------------------------------------------------
// Prototypes of the Fortran-style BLAS triangular-solve routines (xTRSM) that
// the f77_trsm wrappers below call. They are declared with C linkage so the
// unmangled symbols strsm_/dtrsm_/ctrsm_/ztrsm_ resolve against the BLAS
// library the benchmark is linked with. Every argument, including scalars and
// flag characters, is passed by pointer. Integer arguments use BLASINT, which
// must match the integer width the linked library was built with.
extern "C" {
// Single-precision real.
void strsm_(const char *side, const char *uplo, const char *transa,
const char *diag,
const BLASINT *m, const BLASINT *n,
const float *alpha,
const float *a, const BLASINT *lda,
float *b, const BLASINT *ldb);
// Double-precision real.
void dtrsm_(const char *side, const char *uplo, const char *transa,
const char *diag,
const BLASINT *m, const BLASINT *n,
const double *alpha,
const double *a, const BLASINT *lda,
double *b, const BLASINT *ldb);
// Single-precision complex; declared with float pointers, the wrappers pass
// std::complex<float> data reinterpreted as float.
void ctrsm_(const char *side, const char *uplo, const char *transa,
const char *diag,
const BLASINT *m, const BLASINT *n,
const float *alpha,
const float *a, const BLASINT *lda,
float *b, const BLASINT *ldb);
// Double-precision complex; declared with double pointers, the wrappers pass
// std::complex<double> data reinterpreted as double.
void ztrsm_(const char *side, const char *uplo, const char *transa,
const char *diag,
const BLASINT *m, const BLASINT *n,
const double *alpha,
const double *a, const BLASINT *lda,
double *b, const BLASINT *ldb);
} // extern "C"
//------------------------------------------------------------------------------
template <typename Index>
void
f77_trsm(const char Side,
const char Uplo, const char TransA,
const char Diag,
const Index m, const Index n,
const float alpha,
const float *A, const Index ldA,
float *B, const Index ldB)
{
BLASINT M = m;
BLASINT N = n;
BLASINT LDA = ldA;
BLASINT LDB = ldB;
strsm_(&Side, &Uplo, &TransA, &Diag,
&M, &N,
&alpha,
A, &LDA,
B, &LDB);
}
template <typename Index>
void
f77_trsm(const char Side,
const char Uplo, const char TransA,
const char Diag,
const Index m, const Index n,
const double alpha,
const double *A, const Index ldA,
double *B, const Index ldB)
{
BLASINT M = m;
BLASINT N = n;
BLASINT LDA = ldA;
BLASINT LDB = ldB;
dtrsm_(&Side, &Uplo, &TransA, &Diag,
&M, &N,
&alpha,
A, &LDA,
B, &LDB);
}
template <typename Index>
void
f77_trsm(const char Side,
const char Uplo, const char TransA,
const char Diag,
const Index m, const Index n,
const std::complex<float> alpha,
const std::complex<float> *A, const Index ldA,
std::complex<float> *B, const Index ldB)
{
BLASINT M = m;
BLASINT N = n;
BLASINT LDA = ldA;
BLASINT LDB = ldB;
ctrsm_(&Side, &Uplo, &TransA, &Diag,
&M, &N,
(const float*)&alpha,
(const float*)A, &LDA,
(float*)B, &LDB);
}
template <typename Index>
void
f77_trsm(const char Side,
const char Uplo, const char TransA,
const char Diag,
const Index m, const Index n,
const std::complex<double> alpha,
const std::complex<double> *A, const Index ldA,
std::complex<double> *B, const Index ldB)
{
BLASINT M = m;
BLASINT N = n;
BLASINT LDA = ldA;
BLASINT LDB = ldB;
ztrsm_(&Side, &Uplo, &TransA, &Diag,
&M, &N,
(const double*)&alpha,
(const double*)A, &LDA,
(double*)B, &LDB);
}
//------------------------------------------------------------------------------
// Benchmark driver.
//
// For every problem size (m, n) in the configured range it builds a lower,
// unit-diagonal triangular matrix A (m x m) and two identical right-hand
// sides B1, B2 (m x n), then runs the triangular solve once through
// flens::blas::sm on B1 and once through the Fortran BLAS trsm wrapper on B2.
// One line per size is printed: m, n, time and MFLOPS for each
// implementation, and asumDiff(B1, B2) as the residual column.
//
// TYPE_A, TYPE_B, TYPE_ALPHA, ALPHA, MIN_M/MIN_N, MAX_M/MAX_N, INC_M/INC_N and
// BLAS_LIB are preprocessor symbols that are not defined in this file
// (presumably supplied by the build flags or common.h -- confirm there).
int
main()
{
    using TrMatrixA = flens::TrMatrix<flens::FullStorage<TYPE_A> >;
    using GeMatrixB = flens::GeMatrix<flens::FullStorage<TYPE_B> >;

    TYPE_ALPHA alpha = ALPHA;

    const std::size_t min_m = MIN_M;
    const std::size_t min_n = MIN_N;

    const std::size_t max_m = MAX_M;
    const std::size_t max_n = MAX_N;

    const std::size_t inc_m = INC_M;
    const std::size_t inc_n = INC_N;

    // Header line. Only m and n are printed per row, so the header names
    // exactly those two columns and has the same width as the row prefix
    // (" %5zu %5zu ") below.
    std::printf("#%5s %5s ", "m", "n");
    std::printf("%20s %9s", "FLENS/ulmBLAS: t", "MFLOPS");
    std::printf("%20s %9s %9s", BLAS_LIB ": t", "MFLOPS", "Residual");
    std::printf("\n");

    WallTime<double> walltime;

    // m and n advance together; the sweep stops as soon as either one
    // exceeds its maximum.
    for (std::size_t m=min_m, n=min_n;
         m<=max_m && n<=max_n;
         m += inc_m, n += inc_n)
    {
        TrMatrixA A(m, flens::Lower, flens::Unit);
        GeMatrixB B1(m, n);
        GeMatrixB B2(m, n);

        fill(A);
        fill(B1);
        // Both implementations must start from the same right-hand side.
        B2 = B1;

        // Timed run 1: FLENS/ulmBLAS.
        walltime.tic();
        flens::blas::sm(flens::Left, flens::NoTrans, alpha, A, B1);
        double t1 = walltime.toc();

        // Timed run 2: Fortran BLAS with the matching flags
        // (Left, Lower, NoTrans, Unit).
        walltime.tic();
        f77_trsm('L', 'L', 'N', 'U',
                 B2.numRows(), B2.numCols(),
                 alpha,
                 A.data(), A.leadingDimension(),
                 B2.data(), B2.leadingDimension());
        double t2 = walltime.toc();

        double res = asumDiff(B1, B2);

        // Operation count used for the MFLOPS figures: the two terms
        // (presumably multiplications and additions) add up to n*m*m.
        double mflops = (n * m * (m + 1.0 ) / 2.0)
                      + (n * m * (m - 1.0 ) / 2.0);
        mflops /= 1000000;

        // m and n are std::size_t, so they need the 'z' length modifier.
        std::printf(" %5zu %5zu ", m, n);
        std::printf("%20.4lf %9.2lf", t1, mflops/t1);
        std::printf("%20.4lf %9.2lf", t2, mflops/t2);
        std::printf(" %9.1e", res);
        std::printf("\n");
    }
}