#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <limits>
#include <random>
#include <type_traits>

#include <boost/timer.hpp>
#include <boost/numeric/ublas/matrix.hpp>
#include <boost/numeric/ublas/io.hpp>
#include <boost/numeric/ublas/operation.hpp>

#include "gemm.hpp"
template <typename T>
struct WallTime
{
    // Start (or restart) the stopwatch: remember the current time in t0.
    void
    tic()
    {
        using Clock = std::chrono::high_resolution_clock;
        t0 = Clock::now();
    }

    // Stop the stopwatch: store the time since the last tic() in `elapsed`
    // and return that interval in seconds, represented as T.
    T
    toc()
    {
        using Clock = std::chrono::high_resolution_clock;
        const Clock::time_point now = Clock::now();
        elapsed = now - t0;
        const std::chrono::duration<T> secs(elapsed);
        return secs.count();
    }

    // NOTE(review): high_resolution_clock is not required to be steady (it
    // may be an alias of system_clock), so intervals can be affected by
    // wall-clock adjustments.
    std::chrono::high_resolution_clock::time_point t0;       // set by tic()
    std::chrono::high_resolution_clock::duration   elapsed;  // set by toc()
};
// An equivalent trait may already exist, or this one could be implemented more elegantly.
template <typename M>
struct MatrixType
{
    // Fallback for any type not recognised as a matrix: every structural
    // property is false, so functions constrained on these flags through
    // std::enable_if (such as fill) reject the type.
    static constexpr bool isGeneral    = false;
    static constexpr bool isSymmetric  = false;
    static constexpr bool isHermitian  = false;
    static constexpr bool isTriangular = false;
};
// uBLAS dense matrices are treated as general (rectangular) matrices with no
// symmetric, Hermitian or triangular structure implied. All three template
// parameters (value type, storage orientation, storage array) are matched,
// so matrices using a non-default storage array are recognised as well, not
// only matrix<T, SO> with its default array.
template <typename T, typename SO, typename A>
struct MatrixType<boost::numeric::ublas::matrix<T,SO,A> > {
    static constexpr bool  isGeneral    = true;
    static constexpr bool  isSymmetric  = false;
    static constexpr bool  isHermitian  = false;
    static constexpr bool  isTriangular = false;
};
// Overwrite every entry of a general (rectangular) matrix with independent
// samples drawn uniformly from [-100, 100). A fresh engine, seeded from
// std::random_device, is created on every call.
template <typename MATRIX>
typename std::enable_if<MatrixType<MATRIX>::isGeneral,
         void>::type
fill(MATRIX &A)
{
    typedef typename MATRIX::value_type  value_type;
    typedef typename MATRIX::size_type   index_type;

    std::random_device                          seeder;
    std::default_random_engine                  engine(seeder());
    std::uniform_real_distribution<value_type>  dist(-100, 100);

    const index_type rows = A.size1();
    const index_type cols = A.size2();
    for (index_type r = 0; r < rows; ++r) {
        for (index_type c = 0; c < cols; ++c) {
            A(r, c) = dist(engine);
        }
    }
}
// Entry-wise 1-norm of A: the sum of |A(i,j)| over all rows i and columns j,
// accumulated row by row in A's value_type. An empty matrix yields 0.
template <typename MATRIX>
typename MATRIX::value_type
asum(const MATRIX &A)
{
    typedef typename MATRIX::size_type  index_type;

    typename MATRIX::value_type total = 0;
    const index_type rows = A.size1();
    const index_type cols = A.size2();

    for (index_type r = 0; r < rows; ++r)
        for (index_type c = 0; c < cols; ++c)
            total += std::abs(A(r, c));

    return total;
}
// Scaled difference between two GEMM results C0 and C1 computed from the
// same operands A and B:
//
//     res = asum(C1 - C0) / (asum(A) * asum(B) * asum(C1) * eps * max(m,n,k))
//
// where asum() is the entry-wise sum of absolute values, m x n is the shape
// of C1, k is the number of columns of A, and eps is the machine epsilon of
// double.
//
// Returns exactly 0 when C0 and C1 agree entry by entry. Without this check,
// the result would be 0/0 = NaN whenever A, B or C1 is all zeros or empty.
// A nonzero difference combined with a zero norm yields +inf.
template <typename MA, typename MB, typename MC0, typename MC1>
double
estimateGemmResidual(const MA &A, const MB &B,
                     const MC0 &C0, const MC1 &C1)
{
    typedef typename MC0::size_type    size_type;

    size_type m= C1.size1();
    size_type n= C1.size2();
    size_type k= A.size2();

    double diff  = asum(C1-C0);
    // Identical results: report a zero residual rather than risking 0/0.
    if (diff==0) {
        return 0;
    }

    double aNorm = asum(A);
    double bNorm = asum(B);
    double cNorm = asum(C1);

    // Using eps for double gives an upper bound in case the elements have
    // lower precision.
    double eps = std::numeric_limits<double>::epsilon();
    double res = diff/(aNorm*bNorm*cNorm*eps*std::max(std::max(m,n),k));
    return res;
}
namespace foo {

// Counterpart of ublas::axpy_prod(A, B, C, init) backed by gemm() from
// gemm.hpp. The flag keeps uBLAS's `init` meaning: true passes beta = 0
// (the old contents of C are discarded), false passes beta = 1 (the product
// is accumulated into C). alpha is always 1.
// NOTE(review): assumes gemm(alpha, A, B, beta, C) follows the BLAS
// convention C <- beta*C + alpha*A*B; confirm against gemm.hpp.
//
// Preconditions (checked with assert): A is m x k, B is k x n, C is m x n.
template <typename MATRIXA, typename MATRIXB, typename MATRIXC>
void
axpy_prod(const MATRIXA &A, const MATRIXB &B, MATRIXC &C, bool init)
{
    typedef typename MATRIXA::value_type TA;
    typedef typename MATRIXC::value_type TC;

    assert(A.size2()==B.size1());
    assert(C.size1()==A.size1());
    assert(C.size2()==B.size2());

    // init == true  -> beta = 0 : C = A*B
    // init == false -> beta = 1 : C = C + A*B
    gemm(TA(1), A, B, TC(init ? 0 : 1), C);
}

} // namespace foo
// Upper bounds for the problem sizes m, k and n swept by main(). Each default
// (4000) is only used when the macro is not already defined, so it can be
// overridden on the compiler command line, e.g. -DM_MAX=2000.
#ifndef M_MAX
#define M_MAX 4000
#endif
#ifndef K_MAX
#define K_MAX 4000
#endif
#ifndef N_MAX
#define N_MAX 4000
#endif
int
main()
{
    namespace ublas = boost::numeric::ublas;
    const std::size_t m_min = 100;
    const std::size_t k_min = 100;
    const std::size_t n_min = 100;
    const std::size_t m_max = M_MAX;
    const std::size_t k_max = K_MAX;
    const std::size_t n_max = N_MAX;
    const std::size_t m_inc = 100;
    const std::size_t k_inc = 100;
    const std::size_t n_inc = 100;
    const bool matprodUpdate = true;
    typedef double              T;
    typedef ublas::row_major    SO;
    std::cout << "#   m";
    std::cout << "     n";
    std::cout << "     k";
    std::cout << "  uBLAS:   t1";
    std::cout << "       MFLOPS";
    std::cout << "   Blocked:   t2";
    std::cout << "      MFLOPS";
    std::cout << "        Diff nrm1";
    std::cout << std::endl;
    WallTime<double>  walltime;
    for (std::size_t m=m_min, k=k_min, n=n_min;
         m<=m_max && k<=k_max && n<=n_max;
         m += m_inc, k += k_inc, n += n_inc)
    {
        ublas::matrix<T,SO>     A(m, k);
        ublas::matrix<T,SO>     B(k, n);
        ublas::matrix<T,SO>     C1(m, n);
        ublas::matrix<T,SO>     C2(m, n);
        fill(A);
        fill(B);
        fill(C1);
        C2 = C1;
        walltime.tic();
        ublas::axpy_prod(A, B, C1, matprodUpdate);
        double t1 = walltime.toc();
        walltime.tic();
        foo::axpy_prod(A, B, C2, matprodUpdate);
        double t2 = walltime.toc();
        double res = estimateGemmResidual(A, B, C1, C2);
        std::cout.width(5);  std::cout << m << " ";
        std::cout.width(5);  std::cout << n << " ";
        std::cout.width(5);  std::cout << k << " ";
        std::cout.width(12); std::cout << t1 << " ";
        std::cout.width(12); std::cout << 2.*m/1000.*n/1000.*k/t1 << " ";
        std::cout.width(15); std::cout << t2 << " ";
        std::cout.width(12); std::cout << 2.*m/1000.*n/1000.*k/t2 << " ";
        std::cout.width(15); std::cout << res;
        std::cout << std::endl;
    }
}