#include <cstdlib> #include <deque> #include <mutex> #include <random> #include <thread> #include <vector> #include <hpc/aux/slices.hpp> #include <hpc/matvec/gematrix.hpp> #include <hpc/matvec/iterators.hpp> #include <hpc/matvec/print.hpp> using namespace hpc; using namespace hpc::aux; using namespace hpc::matvec; template<typename T> struct RandomEnginePool { using EngineType = T; T get() { /* check if we have a free engine in the unused deque */ { std::lock_guard<std::mutex> lock(mutex); if (unused.size() > 0) { T rg = std::move(unused.front()); unused.pop_front(); return rg; } } /* prepare new random generator */ return T(r()); } void free(T&& engine) { std::lock_guard<std::mutex> lock(mutex); unused.push_back(engine); } private: std::mutex mutex; std::random_device r; std::deque<T> unused; }; template<typename T> struct RandomEngineGuard { using EngineType = T; RandomEngineGuard(RandomEnginePool<T>& pool) : pool(pool), engine(pool.get()) { } ~RandomEngineGuard() { pool.free(std::move(engine)); } T& get() { return engine; } RandomEnginePool<T>& pool; T engine; }; template < template<typename> class MatrixA, typename T, typename POOL, Require< Ge<MatrixA<T>> > = true > void randomInit(MatrixA<T>& A, POOL& pool) { using EngineType = typename POOL::EngineType; RandomEngineGuard<EngineType> guard(pool); std::uniform_real_distribution<double> uniform(-100, 100); auto& engine = guard.get(); for (auto [i, j, Aij]: A) { Aij = uniform(engine); (void) i; (void) j; // suppress gcc warning } } int main() { RandomEnginePool<std::mt19937> pool; GeMatrix<double> A(51, 7); std::size_t nof_threads = std::thread::hardware_concurrency(); std::vector<std::thread> threads(nof_threads); UniformSlices<std::size_t> slices(nof_threads, A.numRows()); for (std::size_t index = 0; index < nof_threads; ++index) { auto firstRow = slices.offset(index); auto numRows = slices.size(index); threads[index] = std::thread([ A_ = A.view(firstRow, 0, numRows, A.numCols()), &pool ]() mutable { randomInit(A_, pool); }); } for (std::size_t index = 0; index < nof_threads; ++index) { threads[index].join(); } print(A, " %7.2f"); } |