// Code extracted from ulmBLAS: https://github.com/michael-lehn/ulmBLAS-core
#ifndef GEMM_HPP
#define GEMM_HPP
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <type_traits>
#if defined(_OPENMP)
#include <omp.h>
#endif
namespace foo {
//-- Aligned allocation --------------------------------------------------------
void *
malloc_(std::size_t alignment, std::size_t size)
{
    // Over-allocate, round the pointer up to the requested (power-of-two)
    // alignment, and stash the pointer returned by std::malloc in the word
    // just in front of the aligned block so that free_ can recover it.
    alignment = std::max(alignment, alignof(void *));
    size     += alignment;
    void *ptr  = std::malloc(size);
    if (!ptr) {
        return nullptr;
    }
    void *ptr2 = (void *)(((std::uintptr_t)ptr + alignment) & ~(alignment-1));
    void **vp  = (void**) ptr2 - 1;
    *vp        = ptr;
    return ptr2;
}
void
free_(void *ptr)
{
    std::free(*((void**)ptr-1));
}
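// A minimal usage sketch for the allocator pair above (illustration only;
// malloc_ assumes a power-of-two alignment, which the mask arithmetic relies on):
//
//     double *buf = (double *) malloc_(32, 1024*sizeof(double)); // 32-byte aligned
//     // ... use buf[0..1023] ...
//     free_(buf);  // must be free_, not std::free: the pointer returned by
//                  // std::malloc is stored in the word just in front of buf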
//-- Config --------------------------------------------------------------------
// SIMD register width in bits:
// SSE:         128
// AVX/FMA:     256
// AVX-512:     512
#ifndef SIMD_REGISTER_WIDTH
#define SIMD_REGISTER_WIDTH 256
#endif
#ifdef HAVE_FMA
#   ifndef BS_D_MR
#   define BS_D_MR 4
#   endif
#   ifndef BS_D_NR
#   define BS_D_NR 12
#   endif
#   ifndef BS_D_MC
#   define BS_D_MC 256
#   endif
#   ifndef BS_D_KC
#   define BS_D_KC 512
#   endif
#   ifndef BS_D_NC
#   define BS_D_NC 4092
#   endif
#endif
#ifndef BS_D_MR
#define BS_D_MR 4
#endif
#ifndef BS_D_NR
#define BS_D_NR 8
#endif
#ifndef BS_D_MC
#define BS_D_MC 256
#endif
#ifndef BS_D_KC
#define BS_D_KC 256
#endif
#ifndef BS_D_NC
#define BS_D_NC 4096
#endif
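// The block sizes and the register width are ordinary macros, so they can be
// tuned per target from the compiler command line; the flags below are only an
// illustration, not recommended values:
//
//     g++ -O3 -std=c++11 -mfma -DHAVE_FMA -DSIMD_REGISTER_WIDTH=256 \
//         -DBS_D_MC=384 -DBS_D_KC=384 demo.cpp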
template <typename T>
struct BlockSize
{
    static constexpr int MC = 64;
    static constexpr int KC = 64;
    static constexpr int NC = 256;
    static constexpr int MR = 8;
    static constexpr int NR = 8;
    static constexpr int rwidth = 0;
    static constexpr int align  = alignof(T);
    static constexpr int vlen   = 0;
    static_assert(MC>0 && KC>0 && NC>0 && MR>0 && NR>0, "Invalid block size.");
    static_assert(MC % MR == 0, "MC must be a multiple of MR.");
    static_assert(NC % NR == 0, "NC must be a multiple of NR.");
};
template <>
struct BlockSize<double>
{
    static constexpr int MC     = BS_D_MC;
    static constexpr int KC     = BS_D_KC;
    static constexpr int NC     = BS_D_NC;
    static constexpr int MR     = BS_D_MR;
    static constexpr int NR     = BS_D_NR;
    static constexpr int rwidth = SIMD_REGISTER_WIDTH;
    static constexpr int align  = rwidth / 8;
#if defined(HAVE_AVX) || defined(HAVE_FMA) || defined(HAVE_GCCVEC)
    static constexpr int vlen   = rwidth / (8*sizeof(double));
#else
    static constexpr int vlen   = 0;
#endif
    static_assert(MC>0 && KC>0 && NC>0 && MR>0 && NR>0, "Invalid block size.");
    static_assert(MC % MR == 0, "MC must be a multiple of MR.");
    static_assert(NC % NR == 0, "NC must be a multiple of NR.");
    static_assert(rwidth % (8*sizeof(double)) == 0, "SIMD register width not sane.");
};
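// With the defaults above (SIMD_REGISTER_WIDTH = 256 and, say, HAVE_AVX
// defined) the specialization works out to align = 256/8 = 32 bytes and
// vlen = 256/(8*sizeof(double)) = 4 doubles per SIMD register. For the
// register widths listed above, a compile-time check along these lines holds
// by construction:
//
//     static_assert(BlockSize<double>::align * 8 == SIMD_REGISTER_WIDTH,
//                   "alignment is the register width in bytes");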
//-- aux routines --------------------------------------------------------------
template <typename Alpha, typename MX, typename MY>
void
geaxpy(const Alpha &alpha, const MX &X, MY &Y)
{
    assert(X.size1()==Y.size1());
    assert(X.size2()==Y.size2());
    typedef typename MX::size_type  size_type;
    for (size_type j=0; j<X.size2(); ++j) {
        for (size_type i=0; i<X.size1(); ++i) {
            Y(i,j) += alpha*X(i,j);
        }
    }
}
template <typename Alpha, typename MX>
void
gescal(const Alpha &alpha, MX &X)
{
    typedef typename MX::size_type  size_type;
    for (size_type j=0; j<X.size2(); ++j) {
        for (size_type i=0; i<X.size1(); ++i) {
            X(i,j) *= alpha;
        }
    }
}
template <typename Index, typename Alpha, typename TX, typename TY>
void
geaxpy(Index m, Index n,
       const Alpha &alpha,
       const TX *X, Index incRowX, Index incColX,
       TY       *Y, Index incRowY, Index incColY)
{
    for (Index j=0; j<n; ++j) {
        for (Index i=0; i<m; ++i) {
            Y[i*incRowY+j*incColY] += alpha*X[i*incRowX+j*incColX];
        }
    }
}
template <typename Index, typename Alpha, typename TX>
void
gescal(Index m, Index n,
       const Alpha &alpha,
       TX *X, Index incRowX, Index incColX)
{
    for (Index j=0; j<n; ++j) {
        for (Index i=0; i<m; ++i) {
            X[i*incRowX+j*incColX] *= alpha;
        }
    }
}
template <typename IndexType, typename MX, typename MY>
void
gecopy(IndexType m, IndexType n,
       const MX *X, IndexType incRowX, IndexType incColX,
       MY *Y, IndexType incRowY, IndexType incColY)
{
    for (IndexType j=0; j<n; ++j) {
        for (IndexType i=0; i<m; ++i) {
            Y[i*incRowY+j*incColY] = X[i*incRowX+j*incColX];
        }
    }
}
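// The pointer-based routines above address element (i,j) as
// X[i*incRowX + j*incColX], so a row-major m x n array uses incRow = n,
// incCol = 1 and a column-major one uses incRow = 1, incCol = m. A tiny
// sketch (illustration only):
//
//     double X[2*3] = {1, 2, 3, 4, 5, 6};   // 2 x 3, row-major
//     double Y[2*3] = {0};
//     geaxpy(std::size_t(2), std::size_t(3), 2.0,
//            X, std::size_t(3), std::size_t(1),
//            Y, std::size_t(3), std::size_t(1));   // Y += 2*X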
//-- Micro Kernel --------------------------------------------------------------
template <typename Index, typename T>
typename std::enable_if<BlockSize<T>::vlen == 0,
         void>::type
ugemm(Index kc, T alpha, const T *A, const T *B, T beta,
      T *C, Index incRowC, Index incColC)
{
    const Index MR = BlockSize<T>::MR;
    const Index NR = BlockSize<T>::NR;
    T P[BlockSize<T>::MR*BlockSize<T>::NR];
    for (Index l=0; l<MR*NR; ++l) {
        P[l] = 0;
    }
    for (Index l=0; l<kc; ++l) {
        for (Index j=0; j<NR; ++j) {
            for (Index i=0; i<MR; ++i) {
                P[i+j*MR] += A[i+l*MR]*B[l*NR+j];
            }
        }
    }
    for (Index j=0; j<NR; ++j) {
        for (Index i=0; i<MR; ++i) {
            C[i*incRowC+j*incColC] *= beta;
            C[i*incRowC+j*incColC] += alpha*P[i+j*MR];
        }
    }
}
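// The fallback micro kernel above consumes the packed panel format produced by
// pack_A/pack_B further down: A is an MR x kc panel with element (i,l) at
// A[i + l*MR], B is a kc x NR panel with element (l,j) at B[l*NR + j]. One
// rank-1 update step spelled out for l = 0 and MR = NR = 2 (illustration only):
//
//     // P(i,j) += A(i,0) * B(0,j):
//     //   P[0] += A[0]*B[0];   P[2] += A[0]*B[1];
//     //   P[1] += A[1]*B[0];   P[3] += A[1]*B[1];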
#if defined HAVE_AVX
#include "avx.hpp"
#elif defined HAVE_FMA
#include "fma.hpp"
#elif defined HAVE_GCCVEC
#include "gccvec.hpp"
#endif
#ifndef HAVE_GCCVEC
template <typename T>
void
utrlsm(const T *A, T *B)
{
    typedef std::size_t IndexType;
    const IndexType MR = BlockSize<T>::MR;
    const IndexType NR = BlockSize<T>::NR;
    T   C_[MR*NR];
    for (IndexType i=0; i<MR; ++i) {
        for (IndexType j=0; j<NR; ++j) {
            C_[i+j*MR] = B[i*NR+j];
        }
    }
    for (IndexType i=0; i<MR; ++i) {
        for (IndexType j=0; j<NR; ++j) {
            C_[i+j*MR] *= A[i];
            for (IndexType l=i+1; l<MR; ++l) {
                C_[l+j*MR] -= A[l]*C_[i+j*MR];
            }
        }
        A += MR;
    }
    for (IndexType i=0; i<MR; ++i) {
        for (IndexType j=0; j<NR; ++j) {
            B[i*NR+j] = C_[i+j*MR];
        }
    }
}
#else
template <typename T>
typename std::enable_if<BlockSize<T>::vlen != 0,
         void>::type
utrlsm(const T *A, T *B)
{
    typedef std::size_t IndexType;
    typedef T           vx __attribute__((vector_size(BlockSize<T>::rwidth/8)));
    static constexpr IndexType vlen = BlockSize<T>::vlen;
    static constexpr IndexType MR   = BlockSize<T>::MR;
    static constexpr IndexType NR   = BlockSize<T>::NR/vlen;
    A = (const T*) __builtin_assume_aligned (A, BlockSize<T>::align);
    B = ( T*)      __builtin_assume_aligned (B, BlockSize<T>::align);
    vx C_[MR*NR];
    vx *B_ = (vx *)B;
    for (IndexType i=0; i<MR*NR; ++i) {
        C_[i] = B_[i];
    }
    for (IndexType i=0; i<MR; ++i) {
        for (IndexType j=0; j<NR; ++j) {
            C_[i*NR+j] *= A[i];
        }
        for (IndexType l=i+1; l<MR; ++l) {
            for (IndexType j=0; j<NR; ++j) {
                C_[l*NR+j] -= A[l]*C_[i*NR+j];
            }
        }
        A += MR;
    }
    for (IndexType i=0; i<MR*NR; ++i) {
        B_[i] = C_[i];
    }
}
#endif
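// Both utrlsm variants perform the same small forward substitution in place on
// one MR x NR panel of packed B, using one MR x MR diagonal block of A in the
// format produced by pack_L below (column-major, diagonal stored as 1): for
// each i, row i of the panel is multiplied by A[i] (the diagonal slot, which
// pack_L fills with 1), then every row l > i is updated as row_l -= A[l]*row_i,
// and A advances by MR to the next packed column.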
//-- Macro Kernel --------------------------------------------------------------
template <typename Index, typename T, typename Beta, typename TC>
void
mgemm(Index mc, Index nc, Index kc, const T &alpha,
      const T *A, const T *B, Beta beta,
      TC *C, Index incRowC, Index incColC)
{
    const Index MR = BlockSize<T>::MR;
    const Index NR = BlockSize<T>::NR;
    const Index mp  = (mc+MR-1) / MR;
    const Index np  = (nc+NR-1) / NR;
    const Index mr_ = mc % MR;
    const Index nr_ = nc % NR;
    #if defined(_OPENMP)
    #pragma omp parallel for
    #endif
    for (Index j=0; j<np; ++j) {
        const Index nr = (j!=np-1 || nr_==0) ? NR : nr_;
        T C_[BlockSize<T>::MR*BlockSize<T>::NR];
        for (Index i=0; i<mp; ++i) {
            const Index mr = (i!=mp-1 || mr_==0) ? MR : mr_;
            if (mr==MR && nr==NR) {
                ugemm(kc, alpha,
                      &A[i*kc*MR], &B[j*kc*NR],
                      beta,
                      &C[i*MR*incRowC+j*NR*incColC],
                      incRowC, incColC);
            } else {
                std::fill_n(C_, MR*NR, T(0));
                ugemm(kc, alpha,
                      &A[i*kc*MR], &B[j*kc*NR],
                      T(0),
                      C_, Index(1), MR);
                gescal(mr, nr, beta,
                       &C[i*MR*incRowC+j*NR*incColC],
                       incRowC, incColC);
                geaxpy(mr, nr, T(1), C_, Index(1), MR,
                       &C[i*MR*incRowC+j*NR*incColC],
                       incRowC, incColC);
            }
        }
    }
}
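// mgemm streams full MR x NR tiles straight into C and handles the fringe
// tiles through the loop-local buffer C_: the product is formed with beta = 0
// into C_, then only the valid mr x nr corner is scaled and added into C.
// Worked fringe case (illustration only; Apanel, Bpanel, Ctile are placeholder
// names):
//
//     // mc = 6, MR = 4  ->  mr_ = 2, so the last row panel has mr = 2
//     // ugemm(kc, alpha, Apanel, Bpanel, 0, C_, 1, MR);         // full 4 x NR tile
//     // gescal(2, nr, beta, Ctile, incRowC, incColC);           // beta*C on 2 rows
//     // geaxpy(2, nr, 1.0, C_, 1, MR, Ctile, incRowC, incColC); // add alpha*A*B part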
template <typename IndexType, typename T>
void
mtrlsm(IndexType mc, IndexType nc, const T &alpha, const T *A_, T *B_)
{
    const IndexType MR = BlockSize<T>::MR;
    const IndexType NR = BlockSize<T>::NR;
    const IndexType mp = (mc+MR-1) / MR;
    const IndexType np = (nc+NR-1) / NR;
    #if defined(_OPENMP)
    #pragma omp parallel for
    #endif
    for (IndexType j=0; j<np; ++j) {
        IndexType ia = 0;
        for (IndexType i=0; i<mp; ++i) {
            IndexType kc    = i*MR;
            ugemm(kc,
                  T(-1), &A_[ia*MR*MR], &B_[j*mc*NR],
                  alpha,
                  &B_[(j*mc+kc)*NR], NR, IndexType(1));
            utrlsm(&A_[(ia*MR+kc)*MR], &B_[(j*mc+kc)*NR]);
            ia += i+1;
        }
    }
}
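// In mtrlsm the packed A_ holds the lower block triangle produced by pack_L:
// block row i consists of i off-diagonal MR x MR blocks followed by the
// diagonal block, so ia walks the triangular block count 0, 1, 3, 6, ...
// Worked offsets for mp = 3 (illustration only):
//
//     // i = 0: ia = 0, kc = 0     -> diagonal block at A_[0*MR*MR]
//     // i = 1: ia = 1, kc = MR    -> off-diag at A_[1*MR*MR], diag at A_[2*MR*MR]
//     // i = 2: ia = 3, kc = 2*MR  -> off-diag at A_[3*MR*MR] and A_[4*MR*MR],
//     //                              diag at A_[5*MR*MR]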
//-- Packing blocks ------------------------------------------------------------
template <typename MA, typename T>
void
pack_A(const MA &A, T *p)
{
    typedef typename MA::size_type  size_type;
    size_type mc = A.size1();
    size_type kc = A.size2();
    size_type MR = BlockSize<T>::MR;
    size_type mp = (mc+MR-1) / MR;
    for (size_type j=0; j<kc; ++j) {
        for (size_type l=0; l<mp; ++l) {
            for (size_type i0=0; i0<MR; ++i0) {
                size_type i  = l*MR + i0;
                size_type nu = l*MR*kc + j*MR + i0;
                p[nu]        = (i<mc) ? A(i,j) : T(0);
            }
        }
    }
}
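// pack_A stores A in horizontal panels of MR rows; inside a panel the entries
// are column-major, and rows past mc are zero padded. Small example with
// MR = 4, mc = 6, kc = 2 (illustration only):
//
//     // panel 0 (rows 0..3): p[0..7]  = a00 a10 a20 a30 | a01 a11 a21 a31
//     // panel 1 (rows 4..5): p[8..15] = a40 a50  0   0  | a41 a51  0   0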
template <typename MB, typename T>
void
pack_B(const MB &B, T *p)
{
    typedef typename MB::size_type  size_type;
    size_type kc = B.size1();
    size_type nc = B.size2();
    size_type NR = BlockSize<T>::NR;
    size_type np = (nc+NR-1) / NR;
    for (size_type l=0; l<np; ++l) {
        for (size_type j0=0; j0<NR; ++j0) {
            for (size_type i=0; i<kc; ++i) {
                size_type j  = l*NR+j0;
                size_type nu = l*NR*kc + i*NR + j0;
                p[nu]        = (j<nc) ? B(i,j) : T(0);
            }
        }
    }
}
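// pack_B is the column-side analogue: B is stored in vertical panels of NR
// columns, row-major inside a panel, with columns past nc zero padded.
// Example with NR = 4, kc = 2, nc = 6 (illustration only):
//
//     // panel 0 (cols 0..3): p[0..7]  = b00 b01 b02 b03 | b10 b11 b12 b13
//     // panel 1 (cols 4..5): p[8..15] = b04 b05  0   0  | b14 b15  0   0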
template <typename T, typename MB>
void
unpack_B(const T *p, MB &B)
{
    typedef typename MB::size_type  size_type;
    size_type kc = B.size1();
    size_type nc = B.size2();
    size_type NR = BlockSize<T>::NR;
    size_type np = (nc+NR-1) / NR;
    for (size_type l=0; l<np; ++l) {
        for (size_type j0=0; j0<NR; ++j0) {
            for (size_type i=0; i<kc; ++i) {
                size_type j  = l*NR+j0;
                size_type nu = l*NR*kc + i*NR + j0;
                if (j<nc) {
                    B(i,j) = p[nu];
                }
            }
        }
    }
}
template <typename ML, typename T>
void
pack_L(const ML &L, T *p)
{
    typedef typename ML::size_type  size_type;
    assert(L.size1()==L.size2());
    size_type mc = L.size1();
    size_type MR = BlockSize<T>::MR;
    size_type mp = (mc+MR-1) / MR;
    for (size_type j=0; j<mp; ++j) {
        for (size_type j0=0; j0<MR; ++j0) {
            for (size_type i=j; i<mp; ++i) {
                for (size_type i0=0; i0<MR; ++i0) {
                    size_type I  = i*MR+i0;
                    size_type J  = j*MR+j0;
                    size_type nu = (i+1)*i/2*MR*MR + j*MR*MR + j0*MR +i0;
                    p[nu]        = (I==J) ? T(1)
                                          : (I>=mc || J>=mc) ? T(0)
                                                             : (I>J) ? L(I,J)
                                                                     : T(0);
                }
            }
        }
    }
}
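// pack_L keeps only the lower block triangle: block (i,j) with i >= j lands at
// offset (i*(i+1)/2 + j)*MR*MR, column-major inside the block, with the
// diagonal stored as T(1) and everything above it (or outside mc) as T(0).
// Packed block order for mp = 3 (illustration only):
//
//     // (0,0) | (1,0) (1,1) | (2,0) (2,1) (2,2)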
//-- Frame routine -------------------------------------------------------------
template <typename Alpha, typename MatrixA, typename MatrixB,
         typename Beta, typename MatrixC>
void
gemm(Alpha alpha, const MatrixA &A, const MatrixB &B, Beta beta, MatrixC &C)
{
    assert(A.size2()==B.size1());
    namespace ublas = boost::numeric::ublas;
    typedef typename MatrixC::size_type  size_type;
    typedef typename MatrixA::value_type TA;
    typedef typename MatrixB::value_type TB;
    typedef typename MatrixC::value_type TC;
    typedef typename std::common_type<Alpha, TA, TB>::type  T;
    const size_type MC = BlockSize<T>::MC;
    const size_type NC = BlockSize<T>::NC;
    const size_type MR = BlockSize<T>::MR;
    const size_type NR = BlockSize<T>::NR;
    const size_type m = C.size1();
    const size_type n = C.size2();
    const size_type k = A.size2();
    const size_type KC = BlockSize<T>::KC;
    const size_type mb = (m+MC-1) / MC;
    const size_type nb = (n+NC-1) / NC;
    const size_type kb = (k+KC-1) / KC;
    const size_type mc_ = m % MC;
    const size_type nc_ = n % NC;
    const size_type kc_ = k % KC;
    if (m==0 || n==0 || ((alpha==Alpha(0) || k==0) && (beta==Beta(1)))) {
        return;
    }
    if (alpha==Alpha(0) || k==0) {
        gescal(beta, C);
        return;
    }
    TC *C_ = &C(0,0);
    const size_type incRowC = &C(1,0) - &C(0,0);
    const size_type incColC = &C(0,1) - &C(0,0);
    T *A_ = (T*) malloc_(BlockSize<T>::align, sizeof(T)*(MC*KC+MR));
    T *B_ = (T*) malloc_(BlockSize<T>::align, sizeof(T)*(KC*NC+NR));
    for (size_type j=0; j<nb; ++j) {
        size_type nc = (j!=nb-1 || nc_==0) ? NC : nc_;
        for (size_type l=0; l<kb; ++l) {
            size_type kc = (l!=kb-1 || kc_==0) ? KC : kc_;
            Beta beta_   = (l==0) ? beta : Beta(1);
            const auto Bs = subrange(B, l*KC, l*KC+kc, j*NC, j*NC+nc);
            pack_B(Bs, B_);
            for (size_type i=0; i<mb; ++i) {
                size_type mc = (i!=mb-1 || mc_==0) ? MC : mc_;
                const auto As = subrange(A, i*MC, i*MC+mc, l*KC, l*KC+kc);
                pack_A(As, A_);
                mgemm(mc, nc, kc,
                      T(alpha), A_, B_, beta_,
                      &C_[i*MC*incRowC+j*NC*incColC],
                      incRowC, incColC);
            }
        }
    }
    free_(A_);
    free_(B_);
}
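// A minimal call sketch for the frame routine (illustration only; it assumes
// Boost.uBLAS is available in the translation unit that instantiates gemm,
// since subrange() above is found via argument-dependent lookup):
//
//     #include <boost/numeric/ublas/matrix.hpp>
//     #include <boost/numeric/ublas/matrix_proxy.hpp>
//
//     boost::numeric::ublas::matrix<double> A(m, k), B(k, n), C(m, n);
//     // ... fill A and B ...
//     foo::gemm(1.0, A, B, 0.0, C);    // C = 1.0*A*B + 0.0*C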
template <typename Alpha, typename MatrixA, typename MatrixB>
void
trlsm(const Alpha   &alpha, bool unitDiag, const MatrixA &A, MatrixB &B)
{
    assert(A.size2()==A.size1());
    namespace ublas = boost::numeric::ublas;
    typedef typename MatrixA::size_type  size_type;
    typedef typename MatrixA::value_type TA;
    typedef typename MatrixB::value_type TB;
    typedef typename std::common_type<Alpha, TA, TB>::type   T_;
    typedef typename std::remove_const<T_>::type             T;
    const size_type MC = BlockSize<T>::MC;
    const size_type NC = BlockSize<T>::NC;
    const size_type MR = BlockSize<T>::MR;
    const size_type NR = BlockSize<T>::NR;
    const size_type m = B.size1();
    const size_type n = B.size2();
    const size_type mb = (m+MC-1) / MC;
    const size_type nb = (n+NC-1) / NC;
    const size_type mc_ = m % MC;
    const size_type nc_ = n % NC;
    if (m==0 || n==0) {
        return;
    }
    const size_type incRowB = &B(1,0) - &B(0,0);
    const size_type incColB = &B(0,1) - &B(0,0);
    if (alpha==Alpha(0)) {
        gescal(Alpha(0), B);
        return;
    }
    T *A_ = (T*) malloc_(BlockSize<T>::align, sizeof(T)*(MC*MC+MR));
    T *B_ = (T*) malloc_(BlockSize<T>::align, sizeof(T)*(MC*NC+NR));
    for (size_type j=0; j<nb; ++j) {
        size_type nc = (j!=nb-1 || nc_==0) ? NC : nc_;
        for (size_type i=0; i<mb; ++i) {
            size_type mc  = (i!=mb-1 || mc_==0) ? MC : mc_;
            Alpha  alpha_ = (i==0) ? alpha : Alpha(1);
            auto Bs = subrange(B, i*MC, i*MC+mc, j*NC, j*NC+nc);
            pack_B(Bs, B_);
            const auto Ls = subrange(A, i*MC, i*MC+mc, i*MC, i*MC+mc);
            pack_L(Ls, A_);
            mtrlsm(mc, nc, T(alpha_), A_, B_);
            unpack_B(B_, Bs);
            for (size_type l=i+1; l<mb; ++l) {
                mc  = (l!=mb-1 || mc_==0) ? MC : mc_;
                const auto As = subrange(A, l*MC, l*MC+mc, i*MC, i*MC+mc);
                pack_A(As, A_);
                mgemm(mc, nc, MC, T(-1), A_, B_, alpha_,
                      &B(l*MC,j*NC), incRowB, incColB);
            }
        }
    }
    free_(A_);
    free_(B_);
}
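// As written, trlsm overwrites B with the solution X of L*X = alpha*B, where L
// is taken from the lower triangle of A and treated as unit diagonal (pack_L
// stores the diagonal as T(1); the unitDiag flag is not consulted in the body
// above). Call sketch (illustration only, same Boost.uBLAS assumptions as for
// gemm):
//
//     boost::numeric::ublas::matrix<double> L(m, m), B(m, n);
//     // ... fill the lower triangle of L and all of B ...
//     foo::trlsm(1.0, true, L, B);     // B := inv(L) * B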
} // namespace foo
#endif