#ifndef HPC_MPI_MATRIX_HPP
#define HPC_MPI_MATRIX_HPP 1

#include <cassert>
#include <memory>
#include <type_traits>
#include <vector>
#include <mpi.h>
#include <hpc/aux/slices.hpp>
#include <hpc/matvec/gematrix.hpp>
#include <hpc/mpi/fundamental.hpp>

namespace hpc { namespace mpi {

/* construct an MPI data type for a row of matrix A where the extent
   is adapted such that consecutive rows can be combined even if
   they overlap */
template<typename T, template<typename> class Matrix,
         Require<Ge<Matrix<T>>> = true>
MPI_Datatype get_row_type(const Matrix<T>& A) {
   MPI_Datatype rowtype;
   MPI_Type_vector(
      /* count = */ A.numCols(),
      /* blocklength = */ 1,
      /* stride = */ A.incCol(),
      /* element type = */ get_type(A(0, 0)),
      /* newly created type = */ &rowtype);

   /* in case of a contiguous row-major layout we are finished */
   if (A.incRow() == A.numCols()) {
      MPI_Type_commit(&rowtype);
      return rowtype;
   }

   /* the extent of the MPI data type does not match the offset of
      subsequent rows -- this is a problem whenever we want to handle
      more than one row; to fix this we need MPI_Type_create_resized
      which allows us to adapt the extent to A.incRow() */
   MPI_Datatype resized_rowtype;
   MPI_Type_create_resized(rowtype,
      /* lb remains unchanged = */ 0,
      /* new extent = */ A.incRow() * sizeof(T),
      &resized_rowtype);
   MPI_Type_commit(&resized_rowtype);
   MPI_Type_free(&rowtype);
   return resized_rowtype;
}

/* create an MPI data type for the entire matrix A */
template<typename T, template<typename> class Matrix,
         Require<Ge<Matrix<T>>> = true>
MPI_Datatype get_type(const Matrix<T>& A) {
   MPI_Datatype datatype;
   if (A.incCol() == 1) {
      MPI_Type_vector(
         /* count = */ A.numRows(),
         /* blocklength = */ A.numCols(),
         /* stride = */ A.incRow(),
         /* element type = */ get_type(A(0, 0)),
         /* newly created type = */ &datatype);
   } else {
      /* vector of row vectors */
      MPI_Datatype rowtype = get_row_type(A);
      MPI_Type_contiguous(A.numRows(), rowtype, &datatype);
      MPI_Type_free(&rowtype);
   }
   MPI_Type_commit(&datatype);
   return datatype;
}

template<typename T,
         template<typename> class MA,
         template<typename> class MB,
         Require<Ge<MA<T>>> = true,
         Require<Ge<MB<T>>> = true>
int scatter_by_row(const MA<T>& A, MB<T>& B, int root, MPI_Comm comm) {
   int nof_processes; MPI_Comm_size(comm, &nof_processes);
   int rank; MPI_Comm_rank(comm, &rank);

   MPI_Datatype rowtype_B = get_row_type(B);
   int rval;
   if (rank == root) {
      assert(A.numCols() == B.numCols());

      hpc::aux::UniformSlices<int> slices(nof_processes, A.numRows());
      std::vector<int> counts(nof_processes);
      std::vector<int> offsets(nof_processes);

      MPI_Datatype rowtype_A = get_row_type(A);
      for (int i = 0; i < nof_processes; ++i) {
         if (i < A.numRows()) {
            counts[i] = slices.size(i);
            offsets[i] = slices.offset(i);
         } else {
            counts[i] = 0; offsets[i] = 0;
         }
      }
      int recvcount = counts[rank];
      assert(B.numRows() == recvcount);

      /* aged OpenMPI implementations of Debian wheezy and jessie
         expect void* instead of const void*; hence we need to
         cast the const away */
      rval = MPI_Scatterv(
         const_cast<void*>(reinterpret_cast<const void*>(&A(0, 0))),
         &counts[0], &offsets[0], rowtype_A,
         &B(0, 0), recvcount, rowtype_B, root, comm);
      MPI_Type_free(&rowtype_A);
   } else {
      int recvcount = B.numRows();
      rval = MPI_Scatterv(nullptr, nullptr, nullptr, MPI_DATATYPE_NULL,
         &B(0, 0), recvcount, rowtype_B, root, comm);
   }
   MPI_Type_free(&rowtype_B);
   return rval;
}
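
/* Usage sketch (illustrative only, not part of this header): the root
   process holds the full matrix A, every process receives its block of
   rows into a local matrix B, operates on it, and the root collects the
   results again. Matrix dimensions and the choice of GeMatrix<double>
   are just examples; the local slice must be sized with the same
   UniformSlices partitioning that scatter_by_row and gather_by_row
   use internally.

      #include <mpi.h>
      #include <hpc/matvec/gematrix.hpp>
      #include <hpc/mpi/matrix.hpp>

      int main(int argc, char** argv) {
         MPI_Init(&argc, &argv);
         int rank; MPI_Comm_rank(MPI_COMM_WORLD, &rank);
         int nof_processes; MPI_Comm_size(MPI_COMM_WORLD, &nof_processes);

         hpc::matvec::GeMatrix<double> A(100, 80); // significant on root only
         hpc::aux::UniformSlices<int> slices(nof_processes, A.numRows());
         hpc::matvec::GeMatrix<double> B(slices.size(rank), A.numCols());

         hpc::mpi::scatter_by_row(A, B, 0, MPI_COMM_WORLD);
         // ... work on the local slice B ...
         hpc::mpi::gather_by_row(B, A, 0, MPI_COMM_WORLD);

         MPI_Finalize();
      }
*/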

template<typename T,
         template<typename> class MA,
         template<typename> class MB,
         Require<Ge<MA<T>>> = true,
         Require<Ge<MB<T>>> = true>
int gather_by_row(const MA<T>& A, MB<T>& B, int root, MPI_Comm comm) {
   int nof_processes; MPI_Comm_size(comm, &nof_processes);
   int rank; MPI_Comm_rank(comm, &rank);

   MPI_Datatype rowtype_A = get_row_type(A);
   int rval;
   if (rank == root) {
      assert(A.numCols() == B.numCols());

      hpc::aux::UniformSlices<int> slices(nof_processes, B.numRows());
      std::vector<int> counts(nof_processes);
      std::vector<int> offsets(nof_processes);

      for (int i = 0; i < nof_processes; ++i) {
         if (i < B.numRows()) {
            counts[i] = slices.size(i);
            offsets[i] = slices.offset(i);
         } else {
            counts[i] = 0; offsets[i] = 0;
         }
      }
      int sendcount = counts[rank];
      assert(A.numRows() == sendcount);

      MPI_Datatype rowtype_B = get_row_type(B);
      /* aged OpenMPI implementations of Debian wheezy and jessie
         expect void* instead of const void*; hence we need to
         cast the const away */
      rval = MPI_Gatherv(
         const_cast<void*>(reinterpret_cast<const void*>(&A(0, 0))),
         sendcount, rowtype_A,
         &B(0, 0), &counts[0], &offsets[0], rowtype_B,
         root, comm);
      MPI_Type_free(&rowtype_B);
   } else {
      int sendcount = A.numRows();
      rval = MPI_Gatherv((void*) &A(0, 0), sendcount, rowtype_A,
         nullptr, nullptr, nullptr, MPI_DATATYPE_NULL,
         root, comm);
   }
   MPI_Type_free(&rowtype_A);
   return rval;
}

} } // namespaces mpi, hpc

#endif // HPC_MPI_MATRIX_HPP