1
      2
      3
      4
      5
      6
      7
      8
      9
     10
     11
     12
     13
     14
     15
     16
     17
     18
     19
     20
     21
     22
     23
     24
     25
     26
     27
     28
     29
     30
     31
     32
     33
     34
     35
     36
     37
     38
     39
     40
     41
     42
     43
     44
     45
     46
     47
     48
     49
     50
     51
     52
     53
     54
     55
     56
     57
     58
     59
     60
     61
     62
     63
     64
     65
     66
     67
     68
     69
     70
     71
     72
     73
     74
     75
     76
     77
     78
     79
     80
     81
     82
     83
     84
#ifndef FLENS_EXAMPLES_CG_BLAS_H
#define FLENS_EXAMPLES_CG_BLAS_H 1

#include <flens/flens.h>
#include <cmath>
#include <iostream>
#include <limits>

template <typename MA, typename VX, typename VB>
int
cg(const MA &A, const VB &b, VX &&x,
   double tol = std::numeric_limits<double>::epsilon(),
   int    maxIterations = std::numeric_limits<int>::max())
{
    using namespace flens;

    typedef typename VB::ElementType  ElementType;
    typedef typename VB::IndexType    IndexType;
    typedef typename VB::NoView       VectorType;

    ElementType  alpha, beta, rNormSquare, rNormSquarePrev;
    VectorType   Ap, r, p;

    const ElementType  Zero(0), One(1);

///
/// `r = b - A*x;`
///
    blas::copy(b, r);
    blas::mv(NoTrans, -One, A, x, One, r);

///
/// `p = r;`
///
    blas::copy(r, p);

///
/// `rNormSquare = r*r;`
///
    rNormSquare = blas::dot(r, r);

    for (int k=1; k<=maxIterations; ++k) {
        std::cout << "k = " << k << std::endl;
        if (sqrt(rNormSquare)<=tol) {
            return k-1;
        }

///
///     `Ap = A*p;`
///
        blas::mv(NoTrans, One, A, p, Zero, Ap);

///
///     `alpha = rNormSquare/(p * Ap);`
///
        alpha = rNormSquare/blas::dot(p, Ap);

///
///     `x += alpha*p;`
///
        blas::axpy(alpha, p, x);

///
///     `r -= alpha*Ap;`
///
        blas::axpy(-alpha, Ap, r);

        rNormSquarePrev = rNormSquare;

///
///     `rNormSquare = r*r;`
///
        rNormSquare = blas::dot(r, r);

        beta = rNormSquare/rNormSquarePrev;

///
///     `p = beta*p + r;`
///
        blas::scal(beta, p);
        blas::axpy(One, r, p);
    }
    return maxIterations;
}

#endif // FLENS_EXAMPLES_CG_BLAS_H