#ifndef HPC_GEMM_BLOCKED_H
#define HPC_GEMM_BLOCKED_H 1
#include <complex>
#include <memory>
#include <type_traits>
#include "ulmblas.h"
namespace blocked {
// Compile-time holder for one set of GEMM blocking parameters:
//   MC x KC : dimensions of the packed block of A
//   KC x NC : dimensions of the packed block of B
//   MR x NR : dimensions of the block of C handled by one ugemm call
// The static_asserts reject sets the packing routines rely on: all sizes
// positive, MC a multiple of MR and NC a multiple of NR.
template <int MC_, int KC_, int NC_, int MR_, int NR_>
struct BlockSizeParams
{
    static constexpr int MC = MC_;
    static constexpr int KC = KC_;
    static constexpr int NC = NC_;
    static constexpr int MR = MR_;
    static constexpr int NR = NR_;

    static_assert(MC>0 && KC>0 && NC>0 && MR>0 && NR>0, "Invalid block size.");
    static_assert(MC % MR == 0, "MC must be a multiple of MR.");
    static_assert(NC % NR == 0, "NC must be a multiple of NR.");
};

// Fallback blocking for element types without a dedicated specialization.
template <typename T>
struct BlockSize : BlockSizeParams<256, 512, 4096, 8, 8> {};

template <>
struct BlockSize<float> : BlockSizeParams<256, 512, 4096, 8, 8> {};

template <>
struct BlockSize<double> : BlockSizeParams<256, 256, 4096, 4, 4> {};

template <>
struct BlockSize<std::complex<float> > : BlockSizeParams<256, 256, 4096, 4, 8> {};

template <>
struct BlockSize<std::complex<double> > : BlockSizeParams<256, 128, 4096, 4, 4> {};
template <typename T, typename Index>
void
pack_A(Index mc, Index kc,
const T *A, Index incRowA, Index incColA,
T *p)
{
Index MR = BlockSize<T>::MR;
Index mp = (mc+MR-1) / MR;
for (Index j=0; j<kc; ++j) {
for (Index l=0; l<mp; ++l) {
for (Index i0=0; i0<MR; ++i0) {
Index i = l*MR + i0;
Index nu = l*MR*kc + j*MR + i0;
p[nu] = (i<mc) ? A[i*incRowA+j*incColA]
: T(0);
}
}
}
}
template <typename T, typename Index>
void
pack_B(Index kc, Index nc,
const T *B, Index incRowB, Index incColB,
T *p)
{
Index NR = BlockSize<T>::NR;
Index np = (nc+NR-1) / NR;
for (Index l=0; l<np; ++l) {
for (Index j0=0; j0<NR; ++j0) {
for (Index i=0; i<kc; ++i) {
Index j = l*NR+j0;
Index nu = l*NR*kc + i*NR + j0;
p[nu] = (j<nc) ? B[i*incRowB+j*incColB]
: T(0);
}
}
}
}
template <typename T, typename Index>
void
ugemm(Index kc, T alpha,
const T *A, const T *B,
T beta,
T *C, Index incRowC, Index incColC)
{
const Index MR = BlockSize<T>::MR;
const Index NR = BlockSize<T>::NR;
T P[BlockSize<T>::MR*BlockSize<T>::NR];
for (Index l=0; l<MR*NR; ++l) {
P[l] = 0;
}
for (Index l=0; l<kc; ++l) {
for (Index j=0; j<NR; ++j) {
for (Index i=0; i<MR; ++i) {
P[i+j*MR] += A[i+l*MR]*B[l*NR+j];
}
}
}
for (Index j=0; j<NR; ++j) {
for (Index i=0; i<MR; ++i) {
C[i*incRowC+j*incColC] *= beta;
C[i*incRowC+j*incColC] += alpha*P[i+j*MR];
}
}
}
template <typename T, typename Index>
void
mgemm(Index mc, Index nc, Index kc,
      T alpha,
      const T *A, const T *B,
      T beta,
      T *C, Index incRowC, Index incColC)
{
    // Macro kernel: C := beta*C + alpha*A*B for an mc x nc block of C.
    // A and B are the buffers filled by pack_A / pack_B.  Full MR x NR tiles
    // are updated in place by ugemm.  Edge tiles (last row/column of tiles
    // when mc or nc is not a multiple of MR / NR) are computed into the
    // scratch tile C_ first, and only the valid mr x nr part is merged into C.
    const Index MR = BlockSize<T>::MR;
    const Index NR = BlockSize<T>::NR;

    // Scratch tile, column-major (incRow 1, incCol MR).  It is
    // value-initialized because ugemm is called on it with beta=0, and
    // scaling indeterminate contents by zero is undefined and can give NaN.
    T C_[BlockSize<T>::MR*BlockSize<T>::NR] = {};

    const Index mp = (mc+MR-1) / MR;
    const Index np = (nc+NR-1) / NR;
    const Index mr_ = mc % MR;
    const Index nr_ = nc % NR;

    for (Index j=0; j<np; ++j) {
        const Index nr = (j!=np-1 || nr_==0) ? NR : nr_;
        for (Index i=0; i<mp; ++i) {
            const Index mr = (i!=mp-1 || mr_==0) ? MR : mr_;

            const T *Ai  = &A[i*kc*MR];
            const T *Bj  = &B[j*kc*NR];
            T       *Cij = &C[i*MR*incRowC+j*NR*incColC];

            if (mr==MR && nr==NR) {
                ugemm(kc, alpha, Ai, Bj, beta, Cij, incRowC, incColC);
            } else {
                ugemm(kc, alpha, Ai, Bj, T(0), C_, Index(1), MR);
                ulmBLAS::gescal(mr, nr, beta, Cij, incRowC, incColC);
                ulmBLAS::geaxpy(mr, nr, T(1), C_, Index(1), MR,
                                Cij, incRowC, incColC);
            }
        }
    }
}
template <typename T, typename Index>
void
gemm(Index m, Index n, Index k,
     T alpha,
     const T *A, Index incRowA, Index incColA,
     const T *B, Index incRowB, Index incColB,
     T beta,
     T *C, Index incRowC, Index incColC)
{
    // Blocked GEMM: C := beta*C + alpha*A*B, with A m x k, B k x n and
    // C m x n, each addressed through its own row/column strides.
    // B is packed in KC x NC blocks and A in MC x KC blocks; each pair is
    // multiplied by mgemm.
    const Index MC = BlockSize<T>::MC;
    const Index NC = BlockSize<T>::NC;
    const Index KC = BlockSize<T>::KC;

    // Quick return: C is empty, so there is nothing to scale or update.
    if (m==0 || n==0) {
        return;
    }

    // A*B contributes nothing, so only C := beta*C remains.  This is handled
    // before any packing buffer exists, so this path allocates nothing.
    if (alpha==T(0) || k==0) {
        ulmBLAS::gescal(m, n, beta, C, incRowC, incColC);
        return;
    }

    const Index mb = (m+MC-1) / MC;
    const Index nb = (n+NC-1) / NC;
    const Index kb = (k+KC-1) / KC;
    const Index mc_ = m % MC;
    const Index nc_ = n % NC;
    const Index kc_ = k % KC;

    // Packing buffers.  They are owned by unique_ptr so they are released on
    // every exit path, including when the second allocation throws.  Plain
    // new T[] (no zero-fill) is enough: pack_A / pack_B write every entry
    // that mgemm subsequently reads.
    std::unique_ptr<T[]> A_(new T[MC*KC]);
    std::unique_ptr<T[]> B_(new T[KC*NC]);

    for (Index j=0; j<nb; ++j) {
        const Index nc = (j!=nb-1 || nc_==0) ? NC : nc_;
        for (Index l=0; l<kb; ++l) {
            const Index kc = (l!=kb-1 || kc_==0) ? KC : kc_;
            // Only the first k-panel applies the caller's beta; later panels
            // accumulate onto the partial result.
            const T beta_ = (l==0) ? beta : T(1);
            pack_B(kc, nc,
                   &B[l*KC*incRowB+j*NC*incColB],
                   incRowB, incColB,
                   B_.get());
            for (Index i=0; i<mb; ++i) {
                const Index mc = (i!=mb-1 || mc_==0) ? MC : mc_;
                pack_A(mc, kc,
                       &A[i*MC*incRowA+l*KC*incColA],
                       incRowA, incColA,
                       A_.get());
                mgemm(mc, nc, kc,
                      alpha, A_.get(), B_.get(), beta_,
                      &C[i*MC*incRowC+j*NC*incColC],
                      incRowC, incColC);
            }
        }
    }
}
} // namespace blocked
#endif // HPC_GEMM_BLOCKED_H