#include <stdlib.h>
#include <math.h>
#include <ulmblas/level1.h>
#include <bench/refblas.h>
/*
 * Smaller / larger of two values.  These are function-like macros: the
 * selected argument is evaluated twice, so do not pass arguments with side
 * effects (e.g. MAX(i++, j)).
 */
#define MIN(X,Y)   ((X)<(Y) ? (X) : (Y))
#define MAX(X,Y)   ((X)>(Y) ? (X) : (Y))
/*
 * Scaled error between two results of y <- beta*y + alpha*A*x.
 *
 * A is m x n (strides incRowA/incColA), x has n entries, and y0 (reference
 * result) and y1 (result under test) both have m entries.
 *
 * Side effect: y1 is overwritten by daxpy(m, -1.0, y0, ..., y1, ...),
 * presumably y1 - y0 (usual BLAS daxpy semantics; confirm in ulmblas).
 *
 * Returns |y1 - y0| / (|alpha|*||A|| * |x| * |beta|*|y0| * max(m,n)).
 * ||A|| is whatever dgenrm1 computes (presumably the 1-norm). |v| is what
 * damax returns (presumably the largest absolute entry).
 *
 * NOTE(review): normY0 is a factor of the denominator, so beta == 0 (or a
 * zero A, x or y0) yields inf/NaN.  err_dgemm avoids this with
 * MAX(norm, |beta|*norm); consider doing the same here.
 */
double
err_dgemv(int m, int n, double alpha,
          const double *A, int incRowA, int incColA,
          const double *x, int incX,
          double beta,
          const double *y0, int incY0,
          double *y1, int incY1)
{
    int    max_mn = (m>n) ? m : n;
    double normA  = fabs(alpha)*dgenrm1(m, n, A, incRowA, incColA);
    double normX  = damax(n, x, incX);
    double normY0 = fabs(beta)*damax(m, y0, incY0);
    double normD;

    /* y1 <- y1 - y0: both vectors have length m. */
    daxpy(m, -1.0, y0, incY0, y1, incY1);
    /* The difference lives in y1, which has m (not n) entries. */
    normD = damax(m, y1, incY1);
    return normD/(normA*normX*normY0*max_mn);
}
/*
 * Scaled error between two results of a triangular matrix-vector product.
 *
 * A is an n x n triangular matrix described by unitDiag/lower. x0 is the
 * reference result and x1 the result under test (n entries each).
 *
 * Side effect: x1 is overwritten by daxpy(n, -1.0, x0, ..., x1, ...),
 * presumably x1 - x0.
 *
 * Returns |x1 - x0| / (n * ||A|| * |x0|).  ||A|| comes from dtrnrm1
 * (presumably a 1-norm of the triangular part). |v| comes from damax
 * (presumably the largest absolute entry).
 */
double
err_dtrmv(int n, int unitDiag, int lower,
          const double *A, int incRowA, int incColA,
          const double *x0, int incX0,
          double *x1, int incX1)
{
    double triNorm = dtrnrm1(n, n, unitDiag, lower, A, incRowA, incColA);
    double refMax  = damax(n, x0, incX0);
    double denom   = n*triNorm*refMax;

    /* Turn x1 into the difference, then measure it. */
    daxpy(n, -1.0, x0, incX0, x1, incX1);
    return damax(n, x1, incX1)/denom;
}
/*
 * Scaled error between two solutions of a triangular system A*x = b.
 *
 * A is an n x n triangular matrix described by unitDiag/lower. x0 is the
 * reference solution and x1 the solution under test (n entries each).
 *
 * Side effect: x1 is overwritten by daxpy(n, -1.0, x0, ..., x1, ...),
 * presumably x1 - x0.
 *
 * Returns |x1 - x0| / (n * ||A|| * |x0|), the same measure as err_dtrmv.
 * NOTE(review): for a solve, a residual- or condition-based bound may be
 * more meaningful. This compares directly against the reference solution.
 */
double
err_dtrsv(int n, int unitDiag, int lower,
          const double *A, int incRowA, int incColA,
          const double *x0, int incX0,
          double *x1, int incX1)
{
    double nrmTri = dtrnrm1(n, n, unitDiag, lower, A, incRowA, incColA);
    double nrmRef = damax(n, x0, incX0);
    double scale  = n*nrmTri*nrmRef;
    double nrmDiff;

    daxpy(n, -1.0, x0, incX0, x1, incX1);   /* x1 now holds the difference */
    nrmDiff = damax(n, x1, incX1);

    return nrmDiff/scale;
}
/*
 * Scaled error between two results of the rank-1 update A <- A + alpha*x*y^T.
 *
 * x has m entries and y has n entries. A0 is the reference result and A the
 * result under test (both m x n).
 *
 * Side effect: A is overwritten by dgeaxpy(..., -1.0, A0, ..., A, ...),
 * presumably A - A0 (confirm dgeaxpy semantics in ulmblas).
 *
 * Returns ||A - A0|| / (max(m,n) * ||A0|| * |alpha|*|x| * |y|).  Matrix
 * norms come from dgenrm1 (presumably the 1-norm). Vector magnitudes come
 * from damax (presumably the largest absolute entry).
 */
double
err_dger(int m, int n, double alpha,
         const double *x, int incX,
         const double *y, int incY,
         const double *A0, int incRowA0, int incColA0,
         double *A, int incRowA, int incColA)
{
    double nrmRef  = dgenrm1(m, n, A0, incRowA0, incColA0);
    double nrmX    = fabs(alpha)*damax(m, x, incX);
    double nrmY    = damax(n, y, incY);
    int    dimMax  = (n>m) ? n : m;
    double scale   = dimMax*nrmRef*nrmX*nrmY;

    /* Replace A by the difference to the reference, then measure it. */
    dgeaxpy(m, n, -1.0, A0, incRowA0, incColA0, A, incRowA, incColA);
    return dgenrm1(m, n, A, incRowA, incColA)/scale;
}
/*
 * Scaled error between two results of C <- beta*C + alpha*A*B.
 *
 * A is m x k and B is k x n. C0 is the reference result and C the result
 * under test (both m x n, each with its own strides).
 *
 * Side effect: C is overwritten by dgeaxpy(..., -1.0, C0, ..., C, ...),
 * presumably C - C0 (confirm dgeaxpy semantics in ulmblas).
 *
 * Returns ||C - C0|| / (max(m,n,k) * a * ||B|| * c), where
 * a = max(1,|alpha|)*||A|| and c = max(1,|beta|)*||C0||. Norms are as
 * computed by dgenrm1 (presumably the 1-norm).
 */
double
err_dgemm(int m, int n, int k,
          double alpha,
          const double *A, int incRowA, int incColA,
          const double *B, int incRowB, int incColB,
          double beta,
          const double *C0, int incRowC0, int incColC0,
          double *C, int incRowC, int incColC)
{
    double normA = dgenrm1(m, k, A, incRowA, incColA);
    double normB = dgenrm1(k, n, B, incRowB, incColB);
    /* Norm of the reference C0. The strides incRowC0/incColC0 belong to C0,
       so reading C with them (as before) could index out of bounds. */
    double normC = dgenrm1(m, n, C0, incRowC0, incColC0);
    double normD;
    int    mn    = (m>n)  ? m  : n;
    int    mnk   = (mn>k) ? mn : k;

    /* Scale by max(1,|alpha|) / max(1,|beta|) so a zero alpha or beta
       does not zero out the denominator. */
    normA = MAX(normA, fabs(alpha)*normA);
    normC = MAX(normC, fabs(beta)*normC);

    dgeaxpy(m, n, -1.0, C0, incRowC0, incColC0, C, incRowC, incColC);
    normD = dgenrm1(m, n, C, incRowC, incColC);
    return normD/(mnk*normA*normB*normC);
}
/*
 * Relative residual of an LU factorization stored in packed form.
 *
 * A is the original m x n matrix. LU holds the factors with k = min(m,n):
 *  - the entries strictly below the diagonal in the first k columns give
 *    L (m x k), which has an implicit unit diagonal;
 *  - the entries on or above the diagonal in the first k rows give U (k x n).
 * No row permutation is applied. A is presumably either unpivoted or already
 * permuted; confirm with the caller.
 *
 * Side effect: LU is overwritten by
 * dgemm_ref(m, n, k, 1.0, L, ..., U, ..., 0.0, LU, ...), presumably L*U.
 *
 * Returns sum|A - L*U| / (k * sum|A|), with both sums taken over all entries.
 * Special cases:
 *  - returns 0.0 for an empty matrix (k == 0);
 *  - returns HUGE_VAL if the scratch buffers cannot be allocated, and leaves
 *    LU untouched in that case.
 * NOTE(review): a zero matrix A still gives a zero denominator (inf/NaN).
 */
double
err_lu(int m, int n,
       const double *A, int incRowA, int incColA,
       double *LU, int incRowLU, int incColLU)
{
    double aNrm1 = 0;
    double err   = 0;
    int    i, j;
    int    k     = (m<n) ? m : n;
    double *L, *U;

    /* Nothing to compare.  Returning here also avoids malloc(0), which may
       legitimately return NULL and would look like an allocation failure. */
    if (k<=0) {
        return 0.0;
    }

    /* Widen to size_t before multiplying so m*k / k*n cannot overflow int. */
    L = malloc((size_t)m*(size_t)k*sizeof(*L));
    U = malloc((size_t)k*(size_t)n*sizeof(*U));
    if (!L || !U) {
        free(L);            /* free(NULL) is a no-op */
        free(U);
        return HUGE_VAL;    /* cannot verify: report the worst possible error */
    }

    /* L: column-major m x k, unit lower triangular. */
    for (j=0; j<k; ++j) {
        for (i=0; i<m; ++i) {
            L[i+(size_t)j*m] = (i>j) ? LU[i*incRowLU+j*incColLU]
                                     : (i==j) ? 1
                                              : 0;
        }
    }
    /* U: column-major k x n, upper triangular (including the diagonal). */
    for (j=0; j<n; ++j) {
        for (i=0; i<k; ++i) {
            U[i+(size_t)j*k] = (i<=j) ? LU[i*incRowLU+j*incColLU]
                                      : 0;
        }
    }
    /* Entrywise sum of |A|. Computed before LU is overwritten, in case the
       caller passes overlapping storage. */
    for (j=0; j<n; ++j) {
        for (i=0; i<m; ++i) {
            aNrm1 += fabs(A[i*incRowA+j*incColA]);
        }
    }
    dgemm_ref(m, n, k, 1.0,
              L, 1, m,
              U, 1, k,
              0.0,
              LU, incRowLU, incColLU);
    /* Entrywise sum of |A - L*U|. */
    for (j=0; j<n; ++j) {
        for (i=0; i<m; ++i) {
            err += fabs(A[i*incRowA+j*incColA]-LU[i*incRowLU+j*incColLU]);
        }
    }
    err /= (aNrm1*k);
    free(L);
    free(U);
    return err;
}