#ifndef HPC_CUDA_AGGREGATE_HPP
#define HPC_CUDA_AGGREGATE_HPP 1
#ifdef __CUDACC__
#include <hpc/cuda/check.hpp>
#include <hpc/cuda/gematrix.hpp>
#include <hpc/cuda/properties.hpp>
namespace hpc { namespace cuda {
namespace aggregate_impl {
/* we expect our GPU to support BLOCK_DIM x BLOCK_DIM threads per block;
   this must be a power of 2 and is best a multiple of the warp size */
constexpr std::size_t BLOCK_DIM = 32;
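/* with BLOCK_DIM = 32 a full block has 32 x 32 = 1024 threads,
   which is the maximum block size supported by current CUDA devices */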
/* kernel that uses one thread per row of each block of matrix A;
   each thread aggregates the columns of its block into one value */
template<
   template<typename> class MatrixA,
   template<typename> class MatrixB,
   typename T,
   typename Aggregator,
   Require<
      DeviceGe<MatrixA<T>>, DeviceView<MatrixA<T>>,
      DeviceGe<MatrixB<T>>, DeviceView<MatrixB<T>>
   > = true
>
__global__ void aggregate_row_wise(const MatrixA<T> A, MatrixB<T> B,
      Aggregator aggregator) {
   std::size_t i = threadIdx.x + blockIdx.x * BLOCK_DIM;
   std::size_t j = threadIdx.y + blockIdx.y * BLOCK_DIM;
   if (i < A.numRows()) {
      T val = A(i, j);
      std::size_t maxoffset = BLOCK_DIM;
      if (j + maxoffset >= A.numCols()) {
	 maxoffset = A.numCols() - j;
      }
      for (std::size_t offset = 1; offset < maxoffset; ++offset) {
	 val = aggregator(val, A(i, j + offset));
      }
      B(i, blockIdx.y) = val;
   }
}
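/* a launch configuration as used by aggregate() below looks,
   for example, as follows:

      dim3 grid_dim(ceildiv(A.numRows(), BLOCK_DIM),
         ceildiv(A.numCols(), BLOCK_DIM));
      dim3 block_dim(BLOCK_DIM, 1);
      aggregate_row_wise<<<grid_dim, block_dim>>>(A, B, aggregator);

   where B has one column per block column of A, i.e. B is of
   dimension numRows x ceildiv(numCols, BLOCK_DIM) */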
/* kernel that uses one thread per column of each block of matrix A;
   each thread aggregates the rows of its block into one value */
template<
   template<typename> class MatrixA,
   template<typename> class MatrixB,
   typename T,
   typename Aggregator,
   Require<
      DeviceGe<MatrixA<T>>, DeviceView<MatrixA<T>>,
      DeviceGe<MatrixB<T>>, DeviceView<MatrixB<T>>
   > = true
>
__global__ void aggregate_col_wise(const MatrixA<T> A, MatrixB<T> B,
      Aggregator aggregator) {
   std::size_t i = threadIdx.x + blockIdx.x * BLOCK_DIM;
   std::size_t j = threadIdx.y + blockIdx.y * BLOCK_DIM;
   if (j < A.numCols()) {
      T val = A(i, j);
      std::size_t maxoffset = BLOCK_DIM;
      if (i + maxoffset >= A.numRows()) {
	 maxoffset = A.numRows() - i;
      }
      for (std::size_t offset = 1; offset < maxoffset; ++offset) {
	 val = aggregator(val, A(i + offset, j));
      }
      B(blockIdx.x, j) = val;
   }
}
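/* analogously to aggregate_row_wise, aggregate() below launches this
   kernel with block_dim = dim3(1, BLOCK_DIM) so that each thread of a
   block handles one column; B then receives one row per block row of A,
   i.e. B is of dimension ceildiv(numRows, BLOCK_DIM) x numCols */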
/* kernel that operates with a thread for each element of matrix A;
   note that this kernel is expected to be launched with blockDim.x
   and blockDim.y being powers of 2 */
template<
   template<typename> class MatrixA,
   template<typename> class MatrixB,
   typename T,
   typename Aggregator,
   Require<
      DeviceGe<MatrixA<T>>, DeviceView<MatrixA<T>>,
      DeviceGe<MatrixB<T>>, DeviceView<MatrixB<T>>
   > = true
>
__global__ void aggregate2d(const MatrixA<T> A, MatrixB<T> B,
      Aggregator aggregator) {
   std::size_t i = threadIdx.x + blockIdx.x * blockDim.x;
   std::size_t j = threadIdx.y + blockIdx.y * blockDim.y;
   __shared__ T local_block[BLOCK_DIM][BLOCK_DIM];
   std::size_t me_i = threadIdx.x;
   std::size_t me_j = threadIdx.y;
   if (i < A.numRows() && j < A.numCols()) {
      local_block[me_i][me_j] = A(i, j);
   }
   std::size_t active_i = blockDim.x / 2;
   std::size_t active_j = blockDim.y / 2;
   while (active_i > 0) {
      __syncthreads();
      if (me_i < active_i && i + active_i < A.numRows()) {
	 local_block[me_i][me_j] = aggregator(local_block[me_i][me_j],
	    local_block[me_i + active_i][me_j]);
      }
      active_i /= 2;
   }
   while (active_j > 0) {
      __syncthreads();
      if (me_j < active_j && j + active_j < A.numCols()) {
	 local_block[me_i][me_j] = aggregator(local_block[me_i][me_j],
	    local_block[me_i][me_j + active_j]);
      }
      active_j /= 2;
   }
   if (me_i == 0 && me_j == 0) {
      B(blockIdx.x, blockIdx.y) = local_block[0][0];
   }
}
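/* sketch of the configuration this kernel expects, matching its use
   in aggregate() below for matrices that fit into a single block:

      dim3 block_dim(align_to_power_of_2(A.numRows()),
         align_to_power_of_2(A.numCols()));
      aggregate2d<<<1, block_dim>>>(A, B, aggregator);

   the two reduction loops fold rows and then columns in log2 steps,
   leaving the aggregate of the whole block in local_block[0][0] */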
template<
   template<typename> class Matrix,
   typename T,
   Require< DeviceGe<Matrix<T>> > = true
>
auto transposed_view(const Matrix<T>& A) {
   return DeviceGeMatrixConstView<T>(A.numCols(), A.numRows(), A.conj(),
      A.data(), A.incCol(), A.incRow());
}
template<
   template<typename> class Matrix,
   typename T,
   Require< DeviceGe<Matrix<T>> > = true
>
auto row_major_view(const Matrix<T>& A) {
   if (A.incRow() < A.incCol()) {
      return transposed_view(A);
   } else {
      return A.view();
   }
}
template<
   template<typename> class Matrix,
   typename T,
   Require< DeviceGe<Matrix<T>> > = true
>
auto col_major_view(const Matrix<T>& A) {
   if (A.incRow() > A.incCol()) {
      return transposed_view(A);
   } else {
      return A.view();
   }
}
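/* row_major_view and col_major_view are presumably chosen such that the
   kernels above access global memory coalesced: aggregate_col_wise lets
   consecutive threads of a warp visit consecutive columns (hence a
   row-major source), whereas aggregate_row_wise lets them visit
   consecutive rows (hence a column-major source) */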
inline constexpr std::size_t ceildiv(std::size_t x, std::size_t y) {
   /* note that we expect x > 0 && y > 0;
      not safe against overflows but we expect y to be small */
   return (x + y - 1) / y;
}
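/* e.g. ceildiv(65, 32) == 3: 65 rows require three blocks of 32 */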
template<typename T>
inline constexpr T align_to_power_of_2(T val) {
   T res = 1;
   while (res < val) {
      res *= 2;
   }
   return res;
}
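/* e.g. align_to_power_of_2(20) == 32 and align_to_power_of_2(32) == 32 */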
} // namespace aggregate_impl
template<
   template<typename> class Matrix,
   typename T,
   typename Aggregator,
   Require< DeviceGe<Matrix<T>> > = true
>
T aggregate(const Matrix<T>& A, Aggregator&& aggregator) {
   using namespace aggregate_impl;
   if (A.numRows() > 1 || A.numCols() > 1) {
      if (A.numRows() <= BLOCK_DIM && A.numCols() <= BLOCK_DIM) {
	 /* most efficiently handled by a trivial 2d aggregator */
	 DeviceGeMatrix<T> B(1, 1);
	 const auto source = col_major_view(A);
	 dim3 block_dim(align_to_power_of_2(source.numRows()),
	    align_to_power_of_2(source.numCols()));
	 auto target = B.view();
	 aggregate2d<<<1, block_dim>>>(source, target, aggregator);
	 return aggregate(target, std::move(aggregator));
      } else {
	 /* aggregate column- or row-wise */
	 if (A.numRows() > A.numCols()) {
	    auto source = row_major_view(A);
	    dim3 grid_dim(ceildiv(source.numRows(), BLOCK_DIM),
	       ceildiv(source.numCols(), BLOCK_DIM));
	    DeviceGeMatrix<T> B(ceildiv(source.numRows(), BLOCK_DIM),
	       source.numCols());
	    auto target = B.view();
	    dim3 block_dim(1, BLOCK_DIM);
	    aggregate_col_wise<<<grid_dim, block_dim>>>(source, target,
	       aggregator);
	    return aggregate(target, std::move(aggregator));
	 } else {
	    auto source = col_major_view(A);
	    dim3 grid_dim(ceildiv(source.numRows(), BLOCK_DIM),
	       ceildiv(source.numCols(), BLOCK_DIM));
	    DeviceGeMatrix<T> B(source.numRows(),
	       ceildiv(source.numCols(), BLOCK_DIM));
	    auto target = B.view();
	    dim3 block_dim(BLOCK_DIM, 1);
	    aggregate_row_wise<<<grid_dim, block_dim>>>(source, target,
	       aggregator);
	    return aggregate(target, std::move(aggregator));
	 }
      }
   } else {
      T result;
      CHECK_CUDA(cudaMemcpy, &result, A.data(), sizeof(T),
	 cudaMemcpyDeviceToHost);
      return result;
   }
}
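/* hypothetical usage sketch: given an already populated
   DeviceGeMatrix<double> A, the maximum of all its elements could be
   computed with a functor that is callable on the device:

      struct Max {
         __host__ __device__ double operator()(double x, double y) const {
            return x > y ? x : y;
         }
      };
      double maxval = hpc::cuda::aggregate(A, Max{});

   a __device__ lambda works as well provided nvcc is invoked with
   --expt-extended-lambda */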
} } // namespaces cuda and hpc
#else
#	error This CUDA source must be compiled using nvcc
#endif
#endif