monolish  0.14.0
MONOlithic LInear equation Solvers for Highly-parallel architecture
dense-dense_matmul.cpp
Go to the documentation of this file.
1 #include "../../../../include/monolish_blas.hpp"
2 #include "../../../internal/monolish_internal.hpp"
3 
4 namespace monolish {
5 
6 // double ///////////////////
9  Logger &logger = Logger::get_instance();
10  logger.func_in(monolish_func);
11 
12  // err
13  assert(A.get_col() == B.get_row());
14  assert(A.get_row() == C.get_row());
15  assert(B.get_col() == C.get_col());
16  assert(util::is_same_device_mem_stat(A, B, C));
17 
18  const double *Ad = A.val.data();
19  const double *Bd = B.val.data();
20  double *Cd = C.val.data();
21 
22  // MN = MK * KN
23  const size_t m = A.get_row();
24  const size_t n = B.get_col();
25  const size_t k = A.get_col();
26  const double alpha = 1.0;
27  const double beta = 0.0;
28 
29  if (A.get_device_mem_stat() == true) {
30 #if MONOLISH_USE_GPU
31  cublasHandle_t h;
32  internal::check_CUDA(cublasCreate(&h));
33 #pragma omp target data use_device_ptr(Ad, Bd, Cd)
34  {
35  // cublas is col major
36  internal::check_CUDA(cublasDgemm(h, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k,
37  &alpha, Bd, n, Ad, k, &beta, Cd, n));
38  }
39  cublasDestroy(h);
40 #else
41  throw std::runtime_error("error USE_GPU is false, but gpu_status == true");
42 #endif
43  } else {
44  cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha, Ad,
45  k, Bd, n, beta, Cd, n);
46  }
47  logger.func_out();
48 }
49 
50 // float ///////////////////
53  Logger &logger = Logger::get_instance();
54  logger.func_in(monolish_func);
55 
56  // err
57  assert(A.get_col() == B.get_row());
58  assert(A.get_row() == C.get_row());
59  assert(B.get_col() == C.get_col());
60  assert(util::is_same_device_mem_stat(A, B, C));
61 
62  const float *Ad = A.val.data();
63  const float *Bd = B.val.data();
64  float *Cd = C.val.data();
65 
66  // MN = MK * KN
67  const size_t m = A.get_row();
68  const size_t n = B.get_col();
69  const size_t k = A.get_col();
70  const float alpha = 1.0;
71  const float beta = 0.0;
72 
73  if (A.get_device_mem_stat() == true) {
74 #if MONOLISH_USE_GPU
75  cublasHandle_t h;
76  internal::check_CUDA(cublasCreate(&h));
77 #pragma omp target data use_device_ptr(Ad, Bd, Cd)
78  {
79  // cublas is col major
80  internal::check_CUDA(cublasSgemm(h, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k,
81  &alpha, Bd, n, Ad, k, &beta, Cd, n));
82  }
83  cublasDestroy(h);
84 #else
85  throw std::runtime_error("error USE_GPU is false, but gpu_status == true");
86 #endif
87  } else {
88  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha, Ad,
89  k, Bd, n, beta, Cd, n);
90  }
91 
92  logger.func_out();
93 }
94 } // namespace monolish
monolish_func
#define monolish_func
Definition: monolish_logger.hpp:9
monolish::Logger
logger class (singleton, for developer use)
Definition: monolish_logger.hpp:19
monolish::Logger::func_out
void func_out()
Definition: logger_utils.cpp:80
monolish::matrix::Dense::get_device_mem_stat
bool get_device_mem_stat() const
true: sent to the device, false: not sent
Definition: monolish_dense.hpp:377
monolish::matrix::Dense::get_row
size_t get_row() const
get # of rows
Definition: monolish_dense.hpp:199
monolish::matrix::Dense::val
std::vector< Float > val
Dense format values (size M x N)
Definition: monolish_dense.hpp:47
monolish::matrix::Dense
Dense format Matrix.
Definition: monolish_coo.hpp:28
monolish::util::is_same_device_mem_stat
bool is_same_device_mem_stat(const T &arg1, const U &arg2)
check whether the arguments have the same device memory status
Definition: monolish_common.hpp:431
monolish
Definition: monolish_matrix_blas.hpp:9
monolish::matrix::Dense::get_col
size_t get_col() const
get # of columns
Definition: monolish_dense.hpp:208
monolish::blas::matmul
void matmul(const matrix::Dense< double > &A, const matrix::Dense< double > &B, matrix::Dense< double > &C)
Dense matrix multiplication: C = AB.
Definition: dense-dense_matmul.cpp:7
monolish::Logger::get_instance
static Logger & get_instance()
Definition: monolish_logger.hpp:42
monolish::Logger::func_in
void func_in(const std::string func_name)
Definition: logger_utils.cpp:69