#ifndef GEMM2_HPP
#define GEMM2_HPP

#include <cassert>
#include <cstddef>
#include <type_traits>
#include <hpc/aux/ceildiv.hpp>  /* assumed location of hpc::aux::ceildiv */

constexpr std::size_t BLOCK_DIM = 16;

/* compute C = alpha * A * B + beta * C */
template<typename Alpha, typename MA, typename MB, typename Beta, typename MC>
__global__ void gemm_kernel(const Alpha alpha, const MA A, const MB B,
      const Beta beta, MC C) {
   /* the template arguments of std::common_type were lost in extraction;
      the Index/Element member typedef names below are assumptions about
      the matrix-view interface */
   using Index = typename std::common_type<typename MA::Index,
      typename MB::Index, typename MC::Index>::type;
   using T = typename std::common_type<typename MA::Element,
      typename MB::Element>::type;

   Index i = threadIdx.y + blockIdx.y * BLOCK_DIM;
   Index j = threadIdx.x + blockIdx.x * BLOCK_DIM;

   __shared__ T ablock[BLOCK_DIM][BLOCK_DIM];
   __shared__ T bblock[BLOCK_DIM][BLOCK_DIM];

   Index K = A.numCols;
   Index rounds = (K + BLOCK_DIM - 1) / BLOCK_DIM;

   /* we assume A & B to be in row-major */
   T sum{};
   for (Index round = 0; round < rounds; ++round) {
      /* load one BLOCK_DIM x BLOCK_DIM tile of A and of B into shared
         memory; out-of-range elements are padded with zeros */
      T val;
      if (i < A.numRows && round*BLOCK_DIM + threadIdx.x < A.numCols) {
         val = A(i, round*BLOCK_DIM + threadIdx.x);
      } else {
         val = 0;
      }
      ablock[threadIdx.y][threadIdx.x] = val;
      /* bounds check must use threadIdx.y, matching the row that is
         actually read (the original checked threadIdx.x) */
      if (round*BLOCK_DIM + threadIdx.y < B.numRows && j < B.numCols) {
         val = B(round*BLOCK_DIM + threadIdx.y, j);
      } else {
         val = 0;
      }
      bblock[threadIdx.y][threadIdx.x] = val;
      __syncthreads();
      #pragma unroll
      for (Index k = 0; k < BLOCK_DIM; ++k) {
         sum += ablock[threadIdx.y][k] * bblock[k][threadIdx.x];
      }
      __syncthreads();
   }
   if (i < C.numRows && j < C.numCols) {
      /* apply the scaling factors (the original stored the raw product);
         skip the read of C when beta == 0 so that uninitialized values
         cannot propagate */
      if (beta == Beta(0)) {
         C(i, j) = alpha * sum;
      } else {
         C(i, j) = alpha * sum + beta * C(i, j);
      }
   }
}

template<typename Alpha, typename MA, typename MB, typename Beta, typename MC>
void cuda_gemm(const Alpha alpha, const MA& A, const MB& B,
      const Beta beta, MC& C) {
   using Index = typename std::common_type<typename MA::Index,
      typename MB::Index, typename MC::Index>::type;
   assert(A.numRows == C.numRows &&
      A.numCols == B.numRows &&
      B.numCols == C.numCols);
   dim3 block(BLOCK_DIM, BLOCK_DIM);
   Index M = C.numRows; Index N = C.numCols; Index K = A.numCols;
   using namespace hpc::aux;
   /* grid.x spans the columns (j, blockIdx.x), grid.y the rows
      (i, blockIdx.y); the original had the two dimensions swapped */
   dim3 grid(ceildiv(N, BLOCK_DIM), ceildiv(M, BLOCK_DIM));
   gemm_kernel<<<grid, block>>>(alpha,
      A(0, 0, M, K), B(0, 0, K, N), beta, C(0, 0, M, N));
}

#endif
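
/* Usage sketch (illustrative only, not part of this header): a minimal
   row-major device view that provides the interface the kernel assumes,
   i.e. Index/Element typedefs, numRows/numCols, element access, and a
   sub-view operator as used by cuda_gemm. All names below are made up
   for illustration.

      struct DeviceView {
         using Index = std::size_t;
         using Element = double;
         Index numRows, numCols, incRow;   // incRow: leading dimension
         Element* data;                    // device pointer
         // element access, called from the kernel
         __device__ Element& operator()(Index i, Index j) const {
            return data[i*incRow + j];
         }
         // m x n sub-view starting at (i, j), called on the host
         DeviceView operator()(Index i, Index j, Index m, Index n) const {
            return DeviceView{m, n, incRow, data + i*incRow + j};
         }
      };

   With such a view type, a call could look like

      DeviceView A{M, K, K, devA}, B{K, N, N, devB}, C{M, N, N, devC};
      cuda_gemm(1.0, A, B, 0.0, C);

   where devA/devB/devC point to matrices already copied to the GPU. */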