#ifndef LU_HPP
#define LU_HPP 1
#include <algorithm>
#include <boost/numeric/ublas/operation.hpp>
#include <boost/numeric/ublas/vector_proxy.hpp>
#include <boost/numeric/ublas/matrix_proxy.hpp>
#include <boost/numeric/ublas/vector.hpp>
#include <boost/numeric/ublas/triangular.hpp>
#include "gemm.hpp"
namespace foo {
template <typename MP, typename MA>
void
swap_rows(const MP &P, MA &A, bool reverse=false)
{
typedef typename MP::size_type size_type;
size_type size = P.size();
if (!reverse) {
for (size_type i=0; i<size; ++i) {
if (i!=P(i)) {
row(A, i).swap(row(A, P(i)));
}
}
} else {
for (size_type I=size; I>=1; --I) {
size_type i = I-1;
if (i!=P(i)) {
row(A, i).swap(row(A, P(i)));
}
}
}
}
// Unblocked LU factorization with partial pivoting
template <typename MA, typename MP>
typename MA::size_type
lu_unblocked(MA &A, MP &P)
{
namespace ublas = boost::numeric::ublas;
using ublas::range;
using ublas::matrix_column;
using ublas::matrix_row;
typedef typename MA::size_type size_type;
typedef typename MA::value_type value_type;
size_type singular = 0;
size_type m = A.size1 ();
size_type n = A.size2 ();
size_type mn = std::min(m, n);
for (size_type i=0; i<mn; ++ i) {
matrix_column<MA> A_i(column(A,i));
matrix_row<MA> Ai_(row(A,i));
P(i) = i + index_norm_inf(project(A_i, range(i, m)));
if (A(P(i),i) != value_type()) {
if (P(i)!=i) {
row(A, P(i)).swap (Ai_);
}
} else {
singular = i+1;
}
value_type alpha = value_type(1)/A(i,i);
project (A_i, range(i+1, m)) *= alpha;
project (A, range(i+1, m), range (i+1, n)).minus_assign (
outer_prod (project(A_i, range(i+1, m)),
project(Ai_, range(i+1, n))));
}
return singular;
}
// Blocked LU factorization with partial pivoting
template <typename MA, typename MP>
typename MA::size_type
lu_blocked(MA &A, MP &P)
{
namespace ublas = boost::numeric::ublas;
using ublas::range;
using ublas::matrix_column;
using ublas::matrix_row;
typedef typename MA::size_type size_type;
typedef typename MA::value_type value_type;
size_type singular = 0;
size_type singular_ = 0;
size_type m = A.size1 ();
size_type n = A.size2 ();
size_type mn = std::min(m, n);
size_type bs = 64;
if (bs>=mn) {
singular = lu_unblocked(A, P);
} else {
for (size_type j=0; j<mn; j+=bs) {
auto jb = std::min(mn-j, bs);
auto A_ = project(A, range(j,m), range(j,j+jb));
auto P_ = project(P, range(j,m));
singular_ = lu_unblocked(A_, P_);
if (singular==0 && singular_>0) {
singular = singular_ + j;
}
auto A_left = project(A, range(j,m), range(0,j));
foo::swap_rows(project(P, range(j,j+jb)), A_left);
if (j+jb<=n) {
auto A_right = project(A, range(j,m), range(j+jb,n));
foo::swap_rows(project(P, range(j,j+jb)), A_right);
const auto L = project(A, range(j,j+jb), range(j,j+jb));
auto U_right = project(A, range(j,j+jb), range(j+jb,n));
//inplace_solve(L, U_right, ublas::unit_lower_tag());
trlsm(value_type(1), true, L, U_right);
if (j+jb<=m) {
auto A_ = project(A, range(j+jb,m), range(j+jb,n));
gemm(value_type(-1),
project(A, range(j+jb,m), range(j,j+jb)),
project(A, range(j,j+jb), range(j+jb,n)),
value_type(1),
A_);
}
}
for (size_type i=j; i<std::min(m, j+jb); ++i) {
P(i) += j;
}
}
}
return singular;
}
// Blocked recursive LU factorization with partial pivoting
template <typename MA, typename MP>
typename MA::size_type
lu_blocked_recursive(MA &A, MP &P)
{
namespace ublas = boost::numeric::ublas;
using ublas::range;
using ublas::matrix_column;
using ublas::matrix_row;
typedef typename MA::size_type size_type;
typedef typename MA::value_type value_type;
size_type singular = 0;
size_type singular_ = 0;
size_type m = A.size1();
size_type n = A.size2();
size_type mn = std::min(m, n);
size_type bs = 8;
if (bs>=mn) {
singular = lu_unblocked(A, P);
} else {
size_type k;
for (k=1; k<mn/4; k*=2);
auto A_left = project(A, range(0,m), range(0,k));
singular_ = lu_blocked_recursive(A_left, P);
if (singular==0 && singular_>0) {
singular = singular_;
}
auto A_right = project(A, range(0,m), range(k,n));
auto mk = std::min(m, k);
foo::swap_rows(project(P, range(0,mk)), A_right);
const auto L = project(A, range(0,mk), range(0,mk));
auto U_right = project(A, range(0,mk), range(mk,n));
//inplace_solve(L, U_right, ublas::unit_lower_tag());
trlsm(value_type(1), true, L, U_right);
auto A_ = project(A, range(mk,m), range(mk,n));
auto P_ = project(P, range(mk,m));
gemm(value_type(-1),
project(A, range(mk,m), range(0,mk)),
project(A, range(0,mk), range(mk,n)),
value_type(1),
A_);
//std::cout << std::endl;
//std::cout << "M = " << m << ", N = " << n << ", K = " << k << std::endl;
//std::cout << "m = " << (m-mk) << ", n = " << (n-mk) << ", k = " << mk << std::endl;
singular_ = lu_blocked_recursive(A_, P_);
if (singular==0 && singular_>0) {
singular = singular_ + mk;
}
auto A_left_bottom = project(A, range(mk,m), range(0,k));
foo::swap_rows(project(P, range(mk,mn)), A_left_bottom);
for (size_type i=mk; i<mn; ++i) {
P(i) += mk;
}
}
return singular;
}
} // namespace foo
#endif // LU_HPP