1
       2
       3
       4
       5
       6
       7
       8
       9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      32
      33
      34
      35
      36
      37
      38
      39
      40
      41
      42
      43
      44
      45
      46
      47
      48
      49
      50
      51
      52
      53
      54
      55
      56
      57
      58
      59
      60
      61
      62
      63
      64
      65
      66
      67
      68
      69
      70
      71
      72
      73
      74
      75
      76
      77
      78
      79
      80
      81
      82
      83
      84
      85
      86
      87
      88
      89
      90
      91
      92
      93
      94
      95
      96
      97
      98
      99
     100
     101
     102
     103
     104
     105
     106
     107
     108
     109
     110
     111
     112
     113
     114
     115
     116
     117
     118
     119
     120
     121
     122
     123
     124
     125
     126
     127
     128
     129
     130
     131
     132
     133
     134
     135
     136
     137
     138
     139
     140
     141
     142
     143
     144
     145
     146
     147
     148
     149
     150
     151
     152
     153
     154
     155
     156
     157
     158
#ifndef HPC_ULMBLAS_GEMM_H
#define HPC_ULMBLAS_GEMM_H 1

#include <complex>
#include <cstdlib>
#include <new>
#include <type_traits>
#include <hpc/aux/isfundamental.h>
#include <hpc/ulmblas/blocksize.h>
#include <hpc/ulmblas/gescal.h>
#include <hpc/ulmblas/pack.h>
#include <hpc/ulmblas/mgemm.h>
#include <hpc/ulmblas/kernels/ugemm.h>
#if defined(_OPENMP)
#include <omp.h>
#endif
#ifdef GLOBAL_THREAD_POOL
#include <hpc/mt/thread_pool.h>
#endif
#ifdef GLOBAL_THREAD_POOL
extern hpc::mt::ThreadPool GLOBAL_THREAD_POOL;
#endif

#ifdef MEM_ALIGN
#   include <xmmintrin.h>
#endif

namespace hpc { namespace ulmblas {

//-----------------------------------------------------------------------------

template <typename T>
typename std::enable_if<hpc::aux::IsFundamental<T>::value,
         T *>::type
malloc(size_t n)
{
    // Allocates an uninitialized buffer of n elements for fundamental T.
    // With MEM_ALIGN defined the block is aligned to MEM_ALIGN bytes via
    // _mm_malloc; otherwise array new is used.  Release with free<T>().
    // Throws std::bad_alloc on failure in both configurations.
#ifdef MEM_ALIGN
    // Guard against n*sizeof(T) wrapping around, which would silently
    // yield an undersized buffer (array new throws in that situation).
    if (n > static_cast<size_t>(-1) / sizeof(T)) {
        throw std::bad_alloc();
    }
    void *block = _mm_malloc(n*sizeof(T), MEM_ALIGN);
    // _mm_malloc reports failure with a null pointer whereas array new
    // throws; callers (e.g. gemm) never test for null, so throw here too.
    // A zero-sized request may legitimately yield null.
    if (!block && n!=0) {
        throw std::bad_alloc();
    }
    return static_cast<T *>(block);
#else
    return new T[n];
#endif
}

template <typename T>
typename std::enable_if<!hpc::aux::IsFundamental<T>::value, T *>::type
malloc(size_t n)
{
    // Non-fundamental element types always go through array new, which
    // default-constructs every element, regardless of MEM_ALIGN.
    // Release with the matching free<T>() overload.
    T *block = new T[n];
    return block;
}

template <typename T>
typename std::enable_if<hpc::aux::IsFundamental<T>::value, void>::type
free(T *block)
{
    // Releases a buffer obtained from malloc<T>() for fundamental T; the
    // deallocation path must mirror the allocation path selected there.
#ifdef MEM_ALIGN
    void *raw = static_cast<void *>(block);
    _mm_free(raw);
#else
    delete [] block;
#endif
}

template <typename T>
typename std::enable_if<!hpc::aux::IsFundamental<T>::value, void>::type
free(T *block)
{
    // Counterpart of malloc<T>() for non-fundamental T: always array delete.
    delete[] block;
}

//-----------------------------------------------------------------------------

/// Blocked matrix-matrix product  C <- beta*C + alpha*A*B.
///
/// A is m x k, B is k x n and C is m x n, each addressed through explicit
/// row/column element strides (incRow*, incCol*).  The loop nest walks
/// NC-wide column panels of B/C, then KC-deep slices of k, then MC-high row
/// panels of A/C.  Each slice of B and A is packed into a contiguous buffer
/// before mgemm is invoked on the corresponding block of C.  beta is applied
/// only with the first k-slice; later slices accumulate with beta == 1.
template <typename Index, typename Alpha,
         typename TA, typename TB,
         typename Beta,
         typename TC>
void
gemm(Index m, Index n, Index k,
     Alpha alpha,
     const TA *A, Index incRowA, Index incColA,
     const TB *B, Index incRowB, Index incColB,
     Beta beta,
     TC *C, Index incRowC, Index incColC)
{
    // Element type of the packed A/B buffers.
    typedef typename std::common_type<Alpha, TA, TB>::type  T;

    // C is empty: nothing to scale or accumulate (reference BLAS quick return).
    if (m==0 || n==0) {
        return;
    }

    // With alpha == 0 or k == 0 the product contributes nothing, so C only
    // needs scaling by beta.  This check must come before the packing
    // buffers are allocated; otherwise they would leak on this path.
    if (alpha==Alpha(0) || k==0) {
        gescal(m, n, beta, C, incRowC, incColC);
        return;
    }

#if defined(GLOBAL_THREAD_POOL)
    int nof_threads = ::GLOBAL_THREAD_POOL.get_num_threads();
#elif defined(_OPENMP)
    // Query the team size of a parallel region.  Initialized so the value
    // is always well-defined before being read below.
    int nof_threads = 1;
    #pragma omp parallel
    {
        if (omp_get_thread_num() == 0) {
            nof_threads = omp_get_num_threads();
        }
    }
#else
    int nof_threads = 1;
#endif

    // MC is scaled by the thread count; presumably mgemm splits an MC-panel
    // among the threads (NOTE(review): confirm against mgemm.h).
    const Index MC = BlockSize<T>::MC * nof_threads;
    const Index NC = BlockSize<T>::NC;
    const Index KC = BlockSize<T>::KC;

    const Index MR = BlockSize<T>::MR;
    const Index NR = BlockSize<T>::NR;

    // Number of panels in each dimension (last one may be partial).
    const Index mb = (m+MC-1) / MC;
    const Index nb = (n+NC-1) / NC;
    const Index kb = (k+KC-1) / KC;

    // Size of the trailing partial panel, or 0 if the size divides evenly.
    const Index mc_ = m % MC;
    const Index nc_ = n % NC;
    const Index kc_ = k % KC;

    // Packing buffers, allocated only once real work is known to be needed.
    // The extra MR/NR elements are slack, presumably for the pack routines
    // or micro-kernel (NOTE(review): confirm in pack.h / ugemm.h).
    T *A_ = malloc<T>(MC*KC + MR);
    T *B_ = malloc<T>(KC*NC + NR);

    for (Index j=0; j<nb; ++j) {
        Index nc = (j!=nb-1 || nc_==0) ? NC : nc_;

        for (Index l=0; l<kb; ++l) {
            Index   kc  = (l!=kb-1 || kc_==0) ? KC : kc_;
            // beta scales C only once; subsequent k-slices accumulate.
            Beta beta_  = (l==0) ? beta : Beta(1);

            pack_B(kc, nc,
                   &B[l*KC*incRowB+j*NC*incColB],
                   incRowB, incColB,
                   B_);

            for (Index i=0; i<mb; ++i) {
                Index mc = (i!=mb-1 || mc_==0) ? MC : mc_;

                pack_A(mc, kc,
                       &A[i*MC*incRowA+l*KC*incColA],
                       incRowA, incColA,
                       A_);

                mgemm(mc, nc, kc,
                      T(alpha), A_, B_, beta_,
                      &C[i*MC*incRowC+j*NC*incColC],
                      incRowC, incColC);
            }
        }
    }
    free(A_);
    free(B_);
}

} } // namespace ulmblas, hpc

#endif // HPC_ULMBLAS_GEMM_H