#include <ulmblas/level3.h>
#include <ulmblas/level2.h>
#include <ulmblas/level1.h>
#include <stdlib.h>
#include <stdint.h>
/*
 * Allocate `size` usable bytes at an address that is a multiple of
 * `alignment`.
 *
 * The memory comes from malloc() with extra room; the returned address is
 * rounded up inside that allocation, and the pointer malloc() handed out is
 * stored in the pointer-sized slot directly below the returned address.
 * Release the block with free_() (or by passing that stored pointer to
 * free()); never pass the returned pointer to free() directly.
 *
 * alignment  Required alignment in bytes. Must be a non-zero power of two.
 *            Values below sizeof(void *) are raised to sizeof(void *) so the
 *            bookkeeping slot is itself suitably aligned.
 * size       Number of usable bytes requested.
 *
 * Returns the aligned pointer, or NULL when `alignment` is invalid, the
 * padded request would overflow size_t, or malloc() fails.
 */
void *
malloc_(size_t alignment, size_t size)
{
    /* The rounding mask below is only meaningful for powers of two. */
    if (alignment == 0 || (alignment & (alignment - 1)) != 0) {
        return NULL;
    }
    if (alignment < sizeof(void *)) {
        alignment = sizeof(void *);
    }

    /* Room for the bookkeeping slot plus worst-case alignment padding. */
    const size_t overhead = alignment + sizeof(void *);
    if (size > SIZE_MAX - overhead) {
        return NULL;
    }

    void *raw = malloc(size + overhead);
    if (raw == NULL) {
        return NULL;
    }

    /*
     * Skip past the slot first, then round up: this guarantees at least
     * sizeof(void *) bytes between `raw` and the aligned address.
     */
    uintptr_t addr = (uintptr_t)raw + sizeof(void *) + (alignment - 1);
    addr &= ~((uintptr_t)alignment - 1);

    void *aligned = (void *)addr;
    ((void **)aligned)[-1] = raw;
    return aligned;
}
/*
 * Release a block obtained from malloc_().
 *
 * The pointer that malloc() originally returned is read from the
 * pointer-sized slot directly below `ptr` and handed to free().
 *
 * ptr  Pointer returned by malloc_(), or NULL. Like free(), passing NULL
 *      is a no-op.
 */
void
free_(void *ptr)
{
    if (ptr == NULL) {
        return;
    }

    void **slot = (void **)ptr - 1;
    free(*slot);
}
/*
 * Blocking parameters for the packed dgemm in this file.
 *
 * MC, NC and KC bound the blocks that dgemm() packs and hands to
 * dgemm_macro(); MR and NR are the panel heights/widths used by the packing
 * routines and the tile size passed to ugemm_4_8().
 */
/* Largest number of rows of A packed per dpack_A() call in dgemm(). */
#define DGEMM_MC 256
/* Largest number of columns of B packed per dpack_B() call in dgemm(). */
#define DGEMM_NC 512
/* Largest slice of the shared dimension k processed per pass in dgemm(). */
#define DGEMM_KC 256
/* Rows per packed panel of A; row count of a full tile in dgemm_macro(). */
#define DGEMM_MR 4
/* Columns per packed panel of B; column count of a full tile in dgemm_macro(). */
#define DGEMM_NR 8
/*
 * Pack an m-by-k block of A into `p` as consecutive row panels.
 *
 * Element (i, j) of the block is read from A[i*incRowA + j*incColA]. The
 * block is cut into ceil(m / DGEMM_MR) panels of DGEMM_MR rows each. Within
 * a panel the layout is column after column, DGEMM_MR values per column, so
 * a panel occupies DGEMM_MR*k doubles. Rows at or beyond m in the last panel
 * are filled with zeros.
 *
 * p must have room for ceil(m / DGEMM_MR) * DGEMM_MR * k doubles.
 */
void
dpack_A(int m, int k, const double *A, int incRowA, int incColA, double *p)
{
    const int panels = (m + DGEMM_MR - 1) / DGEMM_MR;
    double   *dst    = p;

    /* Visiting panel -> column -> row-in-panel fills p front to back. */
    for (int panel = 0; panel < panels; ++panel) {
        const int firstRow = panel * DGEMM_MR;

        for (int col = 0; col < k; ++col) {
            for (int r = 0; r < DGEMM_MR; ++r) {
                const int row = firstRow + r;

                if (row < m) {
                    *dst = A[row*incRowA + col*incColA];
                } else {
                    /* zero padding below the last real row */
                    *dst = 0;
                }
                ++dst;
            }
        }
    }
}
/*
 * Pack a k-by-n block of B into `p` as consecutive column panels.
 *
 * Element (i, j) of the block is read from B[i*incRowB + j*incColB]. The
 * block is cut into ceil(n / DGEMM_NR) panels of DGEMM_NR columns each.
 * Within a panel the layout is row after row, DGEMM_NR values per row, so a
 * panel occupies DGEMM_NR*k doubles. Columns at or beyond n in the last
 * panel are filled with zeros.
 *
 * p must have room for ceil(n / DGEMM_NR) * DGEMM_NR * k doubles.
 */
void
dpack_B(int k, int n, const double *B, int incRowB, int incColB, double *p)
{
    const int panels = (n + DGEMM_NR - 1) / DGEMM_NR;
    double   *out    = p;

    /* Visiting panel -> row -> column-in-panel fills p front to back. */
    for (int panel = 0; panel < panels; ++panel) {
        const int firstCol = panel * DGEMM_NR;

        for (int row = 0; row < k; ++row) {
            for (int c = 0; c < DGEMM_NR; ++c) {
                const int col   = firstCol + c;
                double    value = 0;   /* padding right of the last real column */

                if (col < n) {
                    value = B[row*incRowB + col*incColB];
                }
                *out++ = value;
            }
        }
    }
}
/*
 * Micro kernel operating on one tile of C; defined outside this section.
 *
 * As used by dgemm_macro(): A points at one packed panel produced by
 * dpack_A(), B at one packed panel produced by dpack_B(), k is the shared
 * dimension, and C is addressed with row stride incRowC and column stride
 * incColC.
 *
 * NOTE(review): presumably computes C <- beta*C + alpha*A*B on a
 * DGEMM_MR x DGEMM_NR (4 x 8) tile, as the name and call sites suggest.
 * Confirm against the implementation, in particular whether C is read when
 * beta == 0, since dgemm_macro() passes an uninitialized scratch tile in
 * that case.
 */
void
ugemm_4_8(int64_t k,
          double alpha,
          const double *A,
          const double *B,
          double beta,
          double *C, int64_t incRowC, int64_t incColC);
/*
 * Macro kernel: update an m-by-n block of C from packed A and packed B.
 *
 * A and B are the buffers filled by dpack_A() and dpack_B(): the panel
 * covering rows starting at `row` begins at A[row*k], the panel covering
 * columns starting at `col` begins at B[col*k].
 *
 * C is walked in DGEMM_MR x DGEMM_NR tiles, column of tiles by column of
 * tiles. A tile that lies fully inside C is passed straight to ugemm_4_8().
 * A tile that overhangs the bottom or right edge is computed into a local
 * scratch tile with beta = 0; then only the valid mr x nr part of C is
 * scaled by beta (dgescal) and the scratch values are added in (dgeaxpy).
 */
void
dgemm_macro(int m, int n, int k,
            double alpha,
            const double *A,
            const double *B,
            double beta,
            double *C, int incRowC, int incColC)
{
    /* Scratch tile: row stride 1, column stride DGEMM_MR. */
    double tile[DGEMM_MR*DGEMM_NR];

    for (int col = 0; col < n; col += DGEMM_NR) {
        /* Columns of C actually covered by this tile. */
        int nr = n - col;
        if (nr > DGEMM_NR) {
            nr = DGEMM_NR;
        }

        for (int row = 0; row < m; row += DGEMM_MR) {
            /* Rows of C actually covered by this tile. */
            int mr = m - row;
            if (mr > DGEMM_MR) {
                mr = DGEMM_MR;
            }

            const double *panelA = &A[row*k];
            const double *panelB = &B[col*k];
            double       *Cij    = &C[row*incRowC+col*incColC];

            if (mr==DGEMM_MR && nr==DGEMM_NR) {
                /* Full tile: let the micro kernel write into C directly. */
                ugemm_4_8(k, alpha,
                          panelA, panelB,
                          beta,
                          Cij, incRowC, incColC);
                continue;
            }

            /* Edge tile: compute the full tile aside, merge the valid part. */
            ugemm_4_8(k, alpha,
                      panelA, panelB,
                      0.,
                      tile, 1, DGEMM_MR);
            dgescal(mr, nr, beta,
                    Cij, incRowC, incColC);
            dgeaxpy(mr, nr, 1., tile, 1, DGEMM_MR,
                    Cij, incRowC, incColC);
        }
    }
}
/*
 * Blocked general matrix-matrix update in the BLAS dgemm sense,
 * C <- beta*C + alpha*A*B, assuming ugemm_4_8(), dgescal() and dgeaxpy()
 * have their conventional meanings (their definitions are not in this
 * section).
 *
 * m, n, k           C is m-by-n, A is m-by-k, B is k-by-n.
 * alpha, beta       Scalars of the update.
 * A, incRowA/ColA   Element (i, l) of A is A[i*incRowA + l*incColA].
 * B, incRowB/ColB   Element (l, j) of B is B[l*incRowB + j*incColB].
 * C, incRowC/ColC   Element (i, j) of C is C[i*incRowC + j*incColC].
 *
 * Returns immediately when C has no elements (m <= 0 or n <= 0). When
 * alpha == 0 or k == 0 only the scaling of C by beta is performed.
 *
 * Otherwise two aligned pack buffers are allocated for the duration of the
 * call. The function has no error channel, so it calls abort() if either
 * allocation fails rather than leave C silently not updated.
 */
void
dgemm(int m, int n, int k,
      double alpha,
      const double *A, int incRowA, int incColA,
      const double *B, int incRowB, int incColB,
      double beta,
      double *C, int incRowC, int incColC)
{
    int i, j, l;
    const int MC = DGEMM_MC;
    const int NC = DGEMM_NC;
    const int KC = DGEMM_KC;

    /* Quick return: nothing to compute, so skip allocation and packing. */
    if (m <= 0 || n <= 0) {
        return;
    }

    if (alpha==0.0 || k==0) {
        dgescal(m, n, beta, C, incRowC, incColC);
        return;
    }

    /*
     * Pack buffers for one block of A and one block of B. The sizes,
     * including the factor 2, are kept exactly as in the original code.
     */
    double *A_ = (double *)malloc_(32, sizeof(double)*DGEMM_MC*DGEMM_KC*2);
    double *B_ = (double *)malloc_(32, sizeof(double)*DGEMM_KC*DGEMM_NC*2);
    if (A_ == NULL || B_ == NULL) {
        abort();
    }

    for (j=0; j<n; j+=NC) {
        int nc = (j+NC<=n) ? NC
                           : n - j;
        for (l=0; l<k; l+=KC) {
            int    kc    = (l+KC<=k) ? KC
                                     : k - l;
            /* Only the first k-slice applies beta; later ones accumulate. */
            double beta_ = (l==0) ? beta
                                  : 1;

            /* Offsets are widened so large matrices do not overflow int. */
            dpack_B(kc, nc,
                    &B[(int64_t)l*incRowB+(int64_t)j*incColB],
                    incRowB, incColB,
                    B_);
            for (i=0; i<m; i+=MC) {
                int mc = (i+MC<=m) ? MC
                                   : m - i;
                dpack_A(mc, kc,
                        &A[(int64_t)i*incRowA+(int64_t)l*incColA],
                        incRowA, incColA,
                        A_);
                dgemm_macro(mc, nc, kc,
                            alpha,
                            A_, B_,
                            beta_,
                            &C[(int64_t)i*incRowC+(int64_t)j*incColC],
                            incRowC, incColC);
            }
        }
    }

    free_(A_);
    free_(B_);
}