/**
 * Distributes the rows of matrix A, held by the process with rank `root`,
 * among all processes of `comm` by means of MPI_Scatterv. Every process
 * (including the root) receives its share of rows into its local matrix B.
 *
 * On the root, the number of rows and the row offset for each rank are
 * taken from hpc::aux::UniformSlices over A.numRows(); ranks whose index
 * is not below A.numRows() are assigned zero rows.
 *
 * Checked by assert on the root only: A and B have the same number of
 * columns, and B.numRows() equals the number of rows assigned to the root.
 *
 * @param A    source matrix; only accessed on the root process
 * @param B    local destination matrix; B.numRows() is the receive count
 * @param root rank of the distributing process
 * @param comm communicator over which the rows are distributed
 * @return     the return value of MPI_Scatterv
 *
 * NOTE(review): the handles obtained from get_row_type are never passed to
 * MPI_Type_free in this function. If get_row_type creates a new derived
 * datatype on each call (TODO confirm), they leak.
 */
template<typename T,
   template<typename> class MA,
   template<typename> class MB,
   Require<Ge<MA<T>>, Ge<MB<T>>> = true>
int scatter_by_row(const MA<T>& A, MB<T>& B, int root, MPI_Comm comm) {
   int nof_processes; MPI_Comm_size(comm, &nof_processes);
   int rank; MPI_Comm_rank(comm, &rank);
   MPI_Datatype rowtype_B = get_row_type(B);

   if (rank != root) {
      /* The send-side arguments are significant at the root only. The
         send type must nevertheless be an MPI_Datatype value: nullptr is
         not convertible to it where the handle is an integer type (e.g.
         MPICH), hence the ignored slot is filled with rowtype_B. */
      int recvcount = B.numRows();
      return MPI_Scatterv(nullptr, nullptr, nullptr, rowtype_B,
         &B(0, 0), recvcount, rowtype_B, root, comm);
   }

   assert(A.numCols() == B.numCols());
   hpc::aux::UniformSlices<int> slices(nof_processes, A.numRows());
   std::vector<int> counts(nof_processes);
   std::vector<int> offsets(nof_processes);
   MPI_Datatype rowtype_A = get_row_type(A);
   for (int i = 0; i < nof_processes; ++i) {
      if (i < A.numRows()) {
         counts[i] = slices.size(i);
         offsets[i] = slices.offset(i);
      } else {
         /* more processes than rows: this rank receives nothing */
         counts[i] = 0; offsets[i] = 0;
      }
   }
   int recvcount = counts[rank];
   assert(B.numRows() == recvcount);
   /* counts and offsets are expressed in units of rowtype_A */
   return MPI_Scatterv(
      &A(0, 0), counts.data(), offsets.data(), rowtype_A,
      &B(0, 0), recvcount, rowtype_B, root, comm);
}
/**
 * Collects the rows of the local matrices A of all processes of `comm`
 * into matrix B on the process with rank `root` by means of MPI_Gatherv.
 *
 * On the root, the number of rows and the row offset expected from each
 * rank are taken from hpc::aux::UniformSlices over B.numRows(); ranks
 * whose index is not below B.numRows() are expected to contribute zero
 * rows.
 *
 * Checked by assert on the root only: A and B have the same number of
 * columns, and A.numRows() equals the number of rows assigned to the root.
 *
 * @param A    local source matrix; A.numRows() is the send count
 * @param B    destination matrix; only accessed on the root process
 * @param root rank of the collecting process
 * @param comm communicator over which the rows are collected
 * @return     the return value of MPI_Gatherv
 *
 * NOTE(review): the handles obtained from get_row_type are never passed to
 * MPI_Type_free in this function. If get_row_type creates a new derived
 * datatype on each call (TODO confirm), they leak.
 */
template<typename T,
   template<typename> class MA,
   template<typename> class MB,
   Require<Ge<MA<T>>, Ge<MB<T>>> = true>
int gather_by_row(const MA<T>& A, MB<T>& B, int root, MPI_Comm comm) {
   int nof_processes; MPI_Comm_size(comm, &nof_processes);
   int rank; MPI_Comm_rank(comm, &rank);
   MPI_Datatype rowtype_A = get_row_type(A);

   if (rank != root) {
      /* The receive-side arguments are significant at the root only. The
         receive type must nevertheless be an MPI_Datatype value: nullptr
         is not convertible to it where the handle is an integer type
         (e.g. MPICH), hence the ignored slot is filled with rowtype_A.
         The send buffer is passed without a cast, exactly as in the root
         branch below. */
      int sendcount = A.numRows();
      return MPI_Gatherv(&A(0, 0), sendcount, rowtype_A,
         nullptr, nullptr, nullptr, rowtype_A, root, comm);
   }

   assert(A.numCols() == B.numCols());
   hpc::aux::UniformSlices<int> slices(nof_processes, B.numRows());
   std::vector<int> counts(nof_processes);
   std::vector<int> offsets(nof_processes);
   for (int i = 0; i < nof_processes; ++i) {
      if (i < B.numRows()) {
         counts[i] = slices.size(i);
         offsets[i] = slices.offset(i);
      } else {
         /* more processes than rows: this rank contributes nothing */
         counts[i] = 0; offsets[i] = 0;
      }
   }
   int sendcount = counts[rank];
   assert(A.numRows() == sendcount);
   MPI_Datatype rowtype_B = get_row_type(B);
   /* counts and offsets are expressed in units of rowtype_B */
   return MPI_Gatherv(
      &A(0, 0), sendcount, rowtype_A,
      &B(0, 0), counts.data(), offsets.data(), rowtype_B, root, comm);
}