Possible solution

Content

Source code

#ifndef HPC_MPI_GEMM_HPP
#define HPC_MPI_GEMM_HPP 1

#include <cassert>
#include <mpi.h>
#include <hpc/matvec/copy.hpp>
#include <hpc/matvec/gematrix.hpp>
#include <hpc/matvec/mm.hpp>
#include <hpc/mpi/matrix.hpp>
#include "gematrix.hpp"

namespace hpc { namespace mpi {

template <typename T>
void
mm(T alpha, const GeMatrix<T> &A, const GeMatrix<T> &B, T beta, GeMatrix<T> &C)
{
    auto grid        = A.grid;

    assert(A.numCols == B.numRows);
    assert(C.numRows == A.numRows);
    assert(C.numCols == B.numCols);

    std::size_t k    = A.numCols;
    std::size_t bs   = 1;

    assert(C.numLocalRows(grid.nodeRow) == A.numLocalRows(grid.nodeRow));
    assert(C.numLocalCols(grid.nodeCol) == B.numLocalCols(grid.nodeCol));
    assert(C.rowOffset(grid.nodeRow) == A.rowOffset(grid.nodeRow));
    assert(C.colOffset(grid.nodeCol) == B.colOffset(grid.nodeCol));

    auto i0 = A.rowOffset(grid.nodeRow);
    auto j0 = B.colOffset(grid.nodeCol);

    auto m0 = A.numLocalRows(grid.nodeRow);
    auto n0 = B.numLocalCols(grid.nodeCol);

    hpc::matvec::GeMatrix<T>    A_(m0, bs);
    hpc::matvec::GeMatrix<T>    B_(bs, n0);

    auto typeA_ = get_type(A_);
    auto typeB_ = get_type(B_);


    for (std::size_t l=0; l<k; ++l) {

        int rootA;
        for (int c=0; c<grid.numNodeCols; ++c) {
            auto l0 = A.colOffset(c);
            auto l1 = l0 + A.numLocalCols(c);

            if (l >= l0 && l < l1) {
                hpc::matvec::copy(A.buffer.block(0, l-l0).dim(m0, bs), A_);
                rootA = c;
            }
        }
        MPI_Bcast(&A_(0,0), 1, typeA_, rootA, grid.commRow);

        int rootB;
        for (int r=0; r<grid.numNodeRows; ++r) {
            auto l0 = B.rowOffset(r);
            auto l1 = l0 + B.numLocalRows(r);

            if (l >= l0 && l < l1) {
                hpc::matvec::copy(B.buffer.block(l-l0,0).dim(bs, n0), B_);
                rootB = r;
            }
        }
        MPI_Bcast(&B_(0,0), 1, typeB_, rootB, grid.commCol);

        auto beta_ = l==0 ? beta : 1;

        hpc::matvec::mm(alpha, A_, B_, beta_, C.buffer);
    }
}

} } // namespaces mpi, hpc

#endif // HPC_MPI_GEMM_HPP

Test run

theon$ mpic++ -g -std=c++17 -I. -I/home/numerik/pub/hpc/ws19/session27 -o test_gemm test_gemm.cpp
theon$ mpirun -np 4 test_gemm
A =
     0.00     1.00     2.00     3.00     4.00     5.00     6.00     7.00     8.00     9.00    10.00    11.00
   100.00   101.00   102.00   103.00   104.00   105.00   106.00   107.00   108.00   109.00   110.00   111.00
   200.00   201.00   202.00   203.00   204.00   205.00   206.00   207.00   208.00   209.00   210.00   211.00
   300.00   301.00   302.00   303.00   304.00   305.00   306.00   307.00   308.00   309.00   310.00   311.00
   400.00   401.00   402.00   403.00   404.00   405.00   406.00   407.00   408.00   409.00   410.00   411.00
   500.00   501.00   502.00   503.00   504.00   505.00   506.00   507.00   508.00   509.00   510.00   511.00
   600.00   601.00   602.00   603.00   604.00   605.00   606.00   607.00   608.00   609.00   610.00   611.00
   700.00   701.00   702.00   703.00   704.00   705.00   706.00   707.00   708.00   709.00   710.00   711.00

B =
     0.00     1.00     2.00     3.00     4.00     5.00     6.00     7.00     8.00     9.00
    10.00    11.00    12.00    13.00    14.00    15.00    16.00    17.00    18.00    19.00
    20.00    21.00    22.00    23.00    24.00    25.00    26.00    27.00    28.00    29.00
    30.00    31.00    32.00    33.00    34.00    35.00    36.00    37.00    38.00    39.00
    40.00    41.00    42.00    43.00    44.00    45.00    46.00    47.00    48.00    49.00
    50.00    51.00    52.00    53.00    54.00    55.00    56.00    57.00    58.00    59.00
    60.00    61.00    62.00    63.00    64.00    65.00    66.00    67.00    68.00    69.00
    70.00    71.00    72.00    73.00    74.00    75.00    76.00    77.00    78.00    79.00
    80.00    81.00    82.00    83.00    84.00    85.00    86.00    87.00    88.00    89.00
    90.00    91.00    92.00    93.00    94.00    95.00    96.00    97.00    98.00    99.00
   100.00   101.00   102.00   103.00   104.00   105.00   106.00   107.00   108.00   109.00
   110.00   111.00   112.00   113.00   114.00   115.00   116.00   117.00   118.00   119.00

C = 
     0.00     1.00     2.00     3.00     4.00     5.00     6.00     7.00     8.00     9.00
   100.00   101.00   102.00   103.00   104.00   105.00   106.00   107.00   108.00   109.00
   200.00   201.00   202.00   203.00   204.00   205.00   206.00   207.00   208.00   209.00
   300.00   301.00   302.00   303.00   304.00   305.00   306.00   307.00   308.00   309.00
   400.00   401.00   402.00   403.00   404.00   405.00   406.00   407.00   408.00   409.00
   500.00   501.00   502.00   503.00   504.00   505.00   506.00   507.00   508.00   509.00
   600.00   601.00   602.00   603.00   604.00   605.00   606.00   607.00   608.00   609.00
   700.00   701.00   702.00   703.00   704.00   705.00   706.00   707.00   708.00   709.00

C <- 1.000000 * A*B + 0.000000*C
Cref =
  5060.00  5126.00  5192.00  5258.00  5324.00  5390.00  5456.00  5522.00  5588.00  5654.00
 71060.00 72326.00 73592.00 74858.00 76124.00 77390.00 78656.00 79922.00 81188.00 82454.00
 137060.00 139526.00 141992.00 144458.00 146924.00 149390.00 151856.00 154322.00 156788.00 159254.00
 203060.00 206726.00 210392.00 214058.00 217724.00 221390.00 225056.00 228722.00 232388.00 236054.00
 269060.00 273926.00 278792.00 283658.00 288524.00 293390.00 298256.00 303122.00 307988.00 312854.00
 335060.00 341126.00 347192.00 353258.00 359324.00 365390.00 371456.00 377522.00 383588.00 389654.00
 401060.00 408326.00 415592.00 422858.00 430124.00 437390.00 444656.00 451922.00 459188.00 466454.00
 467060.00 475526.00 483992.00 492458.00 500924.00 509390.00 517856.00 526322.00 534788.00 543254.00

Cmpi =
  5060.00  5126.00  5192.00  5258.00  5324.00  5390.00  5456.00  5522.00  5588.00  5654.00
 71060.00 72326.00 73592.00 74858.00 76124.00 77390.00 78656.00 79922.00 81188.00 82454.00
 137060.00 139526.00 141992.00 144458.00 146924.00 149390.00 151856.00 154322.00 156788.00 159254.00
 203060.00 206726.00 210392.00 214058.00 217724.00 221390.00 225056.00 228722.00 232388.00 236054.00
 269060.00 273926.00 278792.00 283658.00 288524.00 293390.00 298256.00 303122.00 307988.00 312854.00
 335060.00 341126.00 347192.00 353258.00 359324.00 365390.00 371456.00 377522.00 383588.00 389654.00
 401060.00 408326.00 415592.00 422858.00 430124.00 437390.00 444656.00 451922.00 459188.00 466454.00
 467060.00 475526.00 483992.00 492458.00 500924.00 509390.00 517856.00 526322.00 534788.00 543254.00

Cref - Cmpi =
     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00
     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00
     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00
     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00
     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00
     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00
     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00
     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00
theon$ 
heim$ OMPI_CXX=g++-8.3 mpic++ -g -std=c++17 -I. -I/home/numerik/pub/hpc/ws19/session27 -o test_gemm test_gemm.cpp
heim$ mpirun -np 4 test_gemm
A =
     0.00     1.00     2.00     3.00     4.00     5.00     6.00     7.00     8.00     9.00    10.00    11.00
   100.00   101.00   102.00   103.00   104.00   105.00   106.00   107.00   108.00   109.00   110.00   111.00
   200.00   201.00   202.00   203.00   204.00   205.00   206.00   207.00   208.00   209.00   210.00   211.00
   300.00   301.00   302.00   303.00   304.00   305.00   306.00   307.00   308.00   309.00   310.00   311.00
   400.00   401.00   402.00   403.00   404.00   405.00   406.00   407.00   408.00   409.00   410.00   411.00
   500.00   501.00   502.00   503.00   504.00   505.00   506.00   507.00   508.00   509.00   510.00   511.00
   600.00   601.00   602.00   603.00   604.00   605.00   606.00   607.00   608.00   609.00   610.00   611.00
   700.00   701.00   702.00   703.00   704.00   705.00   706.00   707.00   708.00   709.00   710.00   711.00

B =
     0.00     1.00     2.00     3.00     4.00     5.00     6.00     7.00     8.00     9.00
    10.00    11.00    12.00    13.00    14.00    15.00    16.00    17.00    18.00    19.00
    20.00    21.00    22.00    23.00    24.00    25.00    26.00    27.00    28.00    29.00
    30.00    31.00    32.00    33.00    34.00    35.00    36.00    37.00    38.00    39.00
    40.00    41.00    42.00    43.00    44.00    45.00    46.00    47.00    48.00    49.00
    50.00    51.00    52.00    53.00    54.00    55.00    56.00    57.00    58.00    59.00
    60.00    61.00    62.00    63.00    64.00    65.00    66.00    67.00    68.00    69.00
    70.00    71.00    72.00    73.00    74.00    75.00    76.00    77.00    78.00    79.00
    80.00    81.00    82.00    83.00    84.00    85.00    86.00    87.00    88.00    89.00
    90.00    91.00    92.00    93.00    94.00    95.00    96.00    97.00    98.00    99.00
   100.00   101.00   102.00   103.00   104.00   105.00   106.00   107.00   108.00   109.00
   110.00   111.00   112.00   113.00   114.00   115.00   116.00   117.00   118.00   119.00

C = 
     0.00     1.00     2.00     3.00     4.00     5.00     6.00     7.00     8.00     9.00
   100.00   101.00   102.00   103.00   104.00   105.00   106.00   107.00   108.00   109.00
   200.00   201.00   202.00   203.00   204.00   205.00   206.00   207.00   208.00   209.00
   300.00   301.00   302.00   303.00   304.00   305.00   306.00   307.00   308.00   309.00
   400.00   401.00   402.00   403.00   404.00   405.00   406.00   407.00   408.00   409.00
   500.00   501.00   502.00   503.00   504.00   505.00   506.00   507.00   508.00   509.00
   600.00   601.00   602.00   603.00   604.00   605.00   606.00   607.00   608.00   609.00
   700.00   701.00   702.00   703.00   704.00   705.00   706.00   707.00   708.00   709.00

C <- 1.000000 * A*B + 0.000000*C
Cref =
  5060.00  5126.00  5192.00  5258.00  5324.00  5390.00  5456.00  5522.00  5588.00  5654.00
 71060.00 72326.00 73592.00 74858.00 76124.00 77390.00 78656.00 79922.00 81188.00 82454.00
 137060.00 139526.00 141992.00 144458.00 146924.00 149390.00 151856.00 154322.00 156788.00 159254.00
 203060.00 206726.00 210392.00 214058.00 217724.00 221390.00 225056.00 228722.00 232388.00 236054.00
 269060.00 273926.00 278792.00 283658.00 288524.00 293390.00 298256.00 303122.00 307988.00 312854.00
 335060.00 341126.00 347192.00 353258.00 359324.00 365390.00 371456.00 377522.00 383588.00 389654.00
 401060.00 408326.00 415592.00 422858.00 430124.00 437390.00 444656.00 451922.00 459188.00 466454.00
 467060.00 475526.00 483992.00 492458.00 500924.00 509390.00 517856.00 526322.00 534788.00 543254.00

Cmpi =
  5060.00  5126.00  5192.00  5258.00  5324.00  5390.00  5456.00  5522.00  5588.00  5654.00
 71060.00 72326.00 73592.00 74858.00 76124.00 77390.00 78656.00 79922.00 81188.00 82454.00
 137060.00 139526.00 141992.00 144458.00 146924.00 149390.00 151856.00 154322.00 156788.00 159254.00
 203060.00 206726.00 210392.00 214058.00 217724.00 221390.00 225056.00 228722.00 232388.00 236054.00
 269060.00 273926.00 278792.00 283658.00 288524.00 293390.00 298256.00 303122.00 307988.00 312854.00
 335060.00 341126.00 347192.00 353258.00 359324.00 365390.00 371456.00 377522.00 383588.00 389654.00
 401060.00 408326.00 415592.00 422858.00 430124.00 437390.00 444656.00 451922.00 459188.00 466454.00
 467060.00 475526.00 483992.00 492458.00 500924.00 509390.00 517856.00 526322.00 534788.00 543254.00

Cref - Cmpi =
     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00
     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00
     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00
     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00
     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00
     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00
     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00
     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00     0.00
heim$