#include <cassert>
#include <cmath>
#include <mpi.h>
#include <printf.hpp>
#include <hpc/aux/hsvcolor.hpp>
#include <hpc/aux/slices.hpp>
#include <hpc/matvec/copy.hpp>
#include <hpc/matvec/gematrix.hpp>
#include <hpc/matvec/iterators.hpp>
#include <hpc/matvec/matrix2pixbuf.hpp>
#include <hpc/mpi/matrix.hpp>
#include <hpc/mpi/vector.hpp>

/*
 * Mathematical constants for a floating point type T.
 *
 * Every constant is computed from literals only. Specializations of
 * variable templates are dynamically initialized in an unordered fashion,
 * so an initializer that reads another variable template could observe
 * a still zero-initialized value.
 */
template<typename T>
const T PI = std::acos(T(-1.0));
template<typename T>
const T E = std::exp(T(1.0));
/* e^(-pi), deliberately not expressed through E<T> and PI<T> (see above) */
template<typename T>
const T E_POWER_MINUS_PI = std::exp(-std::acos(T(-1.0)));

using namespace hpc;

/*
 * Performs one Jacobi sweep: every interior element of B (all rows and
 * columns except the first and the last one) is set to the average of
 * the four direct neighbors of the corresponding element of A.
 * The outermost rows and columns of B are left untouched.
 *
 * A: input matrix, must have more than two rows and more than two columns
 * B: output matrix, must not exceed A in either dimension
 *
 * Returns the largest absolute difference |A(i, j) - B(i, j)| over all
 * updated elements (0 if no element was updated).
 */
template<typename T, template<typename> class Matrix,
   Require<Ge<Matrix<T>>> = true>
T jacobi_iteration(const Matrix<T>& A, Matrix<T>& B) {
   assert(A.numRows() > 2 && A.numCols() > 2);
   /* the loops are bounded by the extents of B but index into A as well,
      hence B must fit into A to keep all accesses to A within bounds */
   assert(B.numRows() <= A.numRows() && B.numCols() <= A.numCols());
   T maxdiff = 0;
   for (std::size_t i = 1; i + 1 < B.numRows(); ++i) {
      for (std::size_t j = 1; j + 1 < B.numCols(); ++j) {
         B(i, j) = 0.25 *
            (A(i - 1, j) + A(i + 1, j) + A(i, j - 1) + A(i, j + 1));
         T diff = std::fabs(A(i, j) - B(i, j));
         if (diff > maxdiff) maxdiff = diff;
      }
   }
   return maxdiff;
}

template<typename T, template<typename> class Matrix,
   Require<Ge<Matrix<T>>> = true>
void exchange_with_neighbors(Matrix<T>& A,
      /* ranks of the neighbors */
      int previous, int next,
      /* data type for an inner row, i.e. without the border */
      MPI_Datatype rowtype) {
   MPI_Request requests[4]; int request_index = 0;
   MPI_Irecv(&A(0, 1), 1, rowtype, previous, 0,
      MPI_COMM_WORLD, &requests[request_index++]);
   MPI_Irecv(&A(A.numRows()-1, 1), 1, rowtype, next, 0,
      MPI_COMM_WORLD, &requests[request_index++]);
   MPI_Isend(&A(1, 1), 1, rowtype, previous, 0,
      MPI_COMM_WORLD, &requests[request_index++]);
   MPI_Isend(&A(A.numRows()-2, 1), 1, rowtype, next, 0,
      MPI_COMM_WORLD, &requests[request_index++]);

   for (auto& request: requests) {
      MPI_Status status;
      MPI_Wait(&request, &status);
   }
}

int main(int argc, char** argv) {
   MPI_Init(&argc, &argv);

   int nof_processes; MPI_Comm_size(MPI_COMM_WORLD, &nof_processes);
   int rank; MPI_Comm_rank(MPI_COMM_WORLD, &rank);

   using namespace hpc::matvec;
   using namespace hpc::mpi;
   using namespace hpc::aux;
   using T = double;
   using Matrix = GeMatrix<T>;

   /* initialize the entire matrix, including its borders */
   Matrix A(100, 100, Order::RowMajor);
   if (rank == 0) {
      for (auto [i, j, Aij]: A) {
	 if (j == 0) {
	    Aij = std::sin(PI<T> * (T(i)/(A.numRows()-1)));
	 } else if (j == A.numCols() - 1) {
	    Aij = std::sin(PI<T> * (T(i)/(A.numRows()-1))) *
	       E_POWER_MINUS_PI<T>;
	 } else {
	    Aij = 0;
	 }
      }
   }

   /* we use matrices B1 and B2 to work in our set of rows */
   UniformSlices<std::size_t> slices(nof_processes, A.numRows() - 2);
   Matrix B1(slices.size(rank) + 2, A.numCols(), Order::RowMajor);
   for (auto [i, j, Bij]: B1) {
      Bij = 0;
      (void) i; (void) j; // supress gcc warnings
   }
   auto B = B1.block(1, 0).dim(B1.numRows() - 2, B1.numCols());

   /* distribute main body of A include left and right border */
   auto A_ = A.block(1, 0).dim(A.numRows() - 2, A.numCols());
   scatter_by_row(A_, B, 0, MPI_COMM_WORLD);

   /* distribute first and last row of A */
   if (rank == 0) {
      copy(A.block(0, 0).dim(1, A.numCols()),
	 B1.block(0, 0).dim(1, B1.numCols()));
   }
   MPI_Datatype full_rowtype = get_row_type(B);
   if (nof_processes == 1) {
      copy(A.block(A.numRows()-1, 0).dim(1, A.numCols()),
	 B1.block(B.numRows()-1, 0).dim(1, B1.numCols()));
   } else if (rank == 0) {
      MPI_Send(&A(A.numRows()-1, 0), 1, full_rowtype, nof_processes-1, 0,
	 MPI_COMM_WORLD);
   } else if (rank == nof_processes - 1) {
      MPI_Status status;
      MPI_Recv(&B1(B.numRows()-1, 0), 1, full_rowtype,
	 0, 0, MPI_COMM_WORLD, &status);
   }

   Matrix B2(B1.numRows(), B1.numCols(), Order::RowMajor);
   copy(B1, B2); /* actually just the border needs to be copied */

   /* compute type for inner rows without the border */
   auto B_inner = B.block(0, 1).dim(1, A.numCols() - 2);
   MPI_Datatype inner_rowtype = get_type(B_inner);

   int previous = rank == 0? MPI_PROC_NULL: rank-1;
   int next = rank == nof_processes-1? MPI_PROC_NULL: rank+1;

   T eps = 1e-6; unsigned int iterations;
   for (iterations = 0; ; ++iterations) {
      T maxdiff = jacobi_iteration(B1, B2);
      exchange_with_neighbors(B2, previous, next, inner_rowtype);
      maxdiff = jacobi_iteration(B2, B1);
      if (iterations % 10 == 0) {
	 T global_max;
	 MPI_Reduce(&maxdiff, &global_max, 1, get_type(maxdiff),
	    MPI_MAX, 0, MPI_COMM_WORLD);
	 MPI_Bcast(&global_max, 1, get_type(maxdiff), 0, MPI_COMM_WORLD);
	 if (global_max < eps) break;
      }
      exchange_with_neighbors(B1, previous, next, inner_rowtype);
   }
   if (rank == 0) fmt::printf("%d iterations\n", iterations);

   gather_by_row(B, A_, 0, MPI_COMM_WORLD);

   MPI_Finalize();

   if (rank == 0) {
      auto pixbuf = create_pixbuf(A, [](T val) -> HSVColor<float> {
	 return HSVColor<float>((1-val) * 240, 1, 1);
      }, 8);
      gdk_pixbuf_save(pixbuf, "jacobi.jpg", "jpeg", nullptr,
	 "quality", "100", nullptr);
   }
}
