1
       2
       3
       4
       5
       6
       7
       8
       9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      32
      33
      34
      35
      36
      37
      38
      39
      40
      41
      42
      43
      44
      45
      46
      47
      48
      49
      50
      51
      52
      53
      54
      55
      56
      57
      58
      59
      60
      61
      62
      63
      64
      65
      66
      67
      68
      69
      70
      71
      72
      73
      74
      75
      76
      77
      78
      79
      80
      81
      82
      83
      84
      85
      86
      87
      88
      89
      90
      91
      92
      93
      94
      95
      96
      97
      98
      99
     100
     101
     102
     103
     104
     105
     106
     107
     108
     109
     110
     111
     112
     113
     114
     115
     116
     117
     118
     119
     120
     121
     122
     123
     124
     125
     126
     127
     128
     129
     130
     131
     132
     133
     134
#ifndef HPC_ULMBLAS_GEMM_H
#define HPC_ULMBLAS_GEMM_H 1

#include <complex>
#include <cstdlib>
#include <memory>
#include <new>
#include <type_traits>
#include <hpc/aux/isfundamental.h>
#include <hpc/ulmblas/blocksize.h>
#include <hpc/ulmblas/gescal.h>
#include <hpc/ulmblas/pack.h>
#include <hpc/ulmblas/mgemm.h>
#include <hpc/ulmblas/kernels/ugemm.h>

#ifdef MEM_ALIGN
#   include <xmmintrin.h>
#endif

namespace hpc { namespace ulmblas {

//-----------------------------------------------------------------------------

template <typename T>
typename std::enable_if<hpc::aux::IsFundamental<T>::value,
         T *>::type
malloc(size_t n)
{
#ifdef MEM_ALIGN
    return reinterpret_cast<T *>(_mm_malloc(n*sizeof(T), MEM_ALIGN));
#   else
    return new T[n];
#   endif
}

template <typename T>
typename std::enable_if<! hpc::aux::IsFundamental<T>::value,
         T *>::type
malloc(size_t n)
{
    return new T[n];
}

template <typename T>
typename std::enable_if<hpc::aux::IsFundamental<T>::value,
         void>::type
free(T *block)
{
#ifdef MEM_ALIGN
    _mm_free(reinterpret_cast<void *>(block));
#   else
    delete [] block;
#   endif
}

template <typename T>
typename std::enable_if<! hpc::aux::IsFundamental<T>::value,
         void>::type
free(T *block)
{
    delete [] block;
}

//-----------------------------------------------------------------------------

template <typename Index, typename Alpha,
         typename TA, typename TB,
         typename Beta,
         typename TC>
void
gemm(Index m, Index n, Index k,
     Alpha alpha,
     const TA *A, Index incRowA, Index incColA,
     const TB *B, Index incRowB, Index incColB,
     Beta beta,
     TC *C, Index incRowC, Index incColC)
{
    typedef typename std::common_type<Alpha, TA, TB>::type  T;

    const Index MC = BlockSize<T>::MC;
    const Index NC = BlockSize<T>::NC;
    const Index KC = BlockSize<T>::KC;

    const Index MR = BlockSize<T>::MR;
    const Index NR = BlockSize<T>::NR;

    const Index mb = (m+MC-1) / MC;
    const Index nb = (n+NC-1) / NC;
    const Index kb = (k+KC-1) / KC;

    const Index mc_ = m % MC;
    const Index nc_ = n % NC;
    const Index kc_ = k % KC;

    T *A_ = malloc<T>(MC*KC + MR);
    T *B_ = malloc<T>(KC*NC + NR);

    if (alpha==Alpha(0) || k==0) {
        gescal(m, n, beta, C, incRowC, incColC);
        return;
    }

    for (Index j=0; j<nb; ++j) {
        Index nc = (j!=nb-1 || nc_==0) ? NC : nc_;

        for (Index l=0; l<kb; ++l) {
            Index   kc  = (l!=kb-1 || kc_==0) ? KC : kc_;
            Beta beta_  = (l==0) ? beta : Beta(1);

            pack_B(kc, nc,
                   &B[l*KC*incRowB+j*NC*incColB],
                   incRowB, incColB,
                   B_);

            for (Index i=0; i<mb; ++i) {
                Index mc = (i!=mb-1 || mc_==0) ? MC : mc_;

                pack_A(mc, kc,
                       &A[i*MC*incRowA+l*KC*incColA],
                       incRowA, incColA,
                       A_);

                mgemm(mc, nc, kc,
                      T(alpha), A_, B_, beta_,
                      &C[i*MC*incRowC+j*NC*incColC],
                      incRowC, incColC);
            }
        }
    }
    free(A_);
    free(B_);
}

} } // namespace ulmblas, hpc

#endif // HPC_ULMBLAS_GEMM_H