#ifndef GEMM1_HPP
#define GEMM1_HPP

#include <cassert>
#include <type_traits>

/* compute C <- alpha A B + beta C */

/*
 * Device kernel: each thread computes at most one element C(i, j).
 * The row index i comes from the x dimension of the launch and the column
 * index j from the y dimension.
 *
 * MA, MB, MC are matrix view types taken by value. They must provide
 * numRows/numCols data members and element access through operator()(i, j).
 *
 * When beta is zero, C(i, j) is not read (BLAS convention). C may therefore
 * hold uninitialised values, including NaN/Inf, on entry.
 */
template <typename Alpha, typename MA, typename MB, typename Beta, typename MC>
__global__ void gemm_kernel(const Alpha alpha, const MA A, const MB B,
                            const Beta beta, MC C)
{
    using Index = typename std::common_type<decltype(A.numRows),
                                            decltype(B.numRows),
                                            decltype(C.numRows)>::type;
    // Accumulator type: common (decayed) type of the three element types.
    using T = typename std::common_type<decltype(A(0, 0)),
                                        decltype(B(0, 0)),
                                        decltype(C(0, 0))>::type;

    Index i = threadIdx.x + blockIdx.x * blockDim.x;
    Index j = threadIdx.y + blockIdx.y * blockDim.y;
    // The grid is rounded up to whole blocks, so edge threads may fall
    // outside C and must do nothing.
    if (i < C.numRows && j < C.numCols) {
        T sum{};
        for (Index k = 0; k < A.numCols; ++k) {
            sum += A(i, k) * B(k, j);
        }
        sum *= alpha;
        // Skip reading C when beta == 0 so garbage/NaN in C cannot leak
        // into the result (and one global load is saved).
        if (beta != Beta(0)) {
            sum += beta * C(i, j);
        }
        C(i, j) = sum;
    }
}

/*
 * Host entry point: computes C <- alpha * A * B + beta * C on the device.
 *
 * A is M x K, B is K x N and C is M x N. Dimension agreement is checked with
 * assert only. The kernel is launched with 16 x 16 thread blocks and a grid
 * rounded up to cover all of C.
 *
 * This function only enqueues the kernel. It neither synchronises nor checks
 * for launch errors; the caller is responsible for both.
 */
template <typename Alpha, typename MA, typename MB, typename Beta, typename MC>
void cuda_gemm(const Alpha alpha, const MA& A, const MB& B, const Beta beta,
               MC& C)
{
    using Index = typename std::common_type<decltype(A.numRows),
                                            decltype(B.numRows),
                                            decltype(C.numRows)>::type;
    assert(A.numRows == C.numRows && A.numCols == B.numRows
           && B.numCols == C.numCols);

    Index M = C.numRows;
    Index N = C.numCols;
    Index K = A.numCols;

    // An empty C needs no work. Launching anyway would produce a zero grid
    // dimension, which the CUDA runtime rejects as an invalid configuration.
    if (M == 0 || N == 0) {
        return;
    }

    constexpr Index blockdim = 16;
    dim3 block(blockdim, blockdim);
    // Round up so partially filled edge blocks still cover all of C.
    dim3 grid((M + blockdim - 1) / blockdim, (N + blockdim - 1) / blockdim);
    gemm_kernel<<<grid, block>>>(alpha, A(0, 0, M, K), B(0, 0, K, N),
                                 beta, C(0, 0, M, N));
}

#endif // GEMM1_HPP