11#ifndef TLAPACK_STARPU_BLAS_CPU_HH
12#define TLAPACK_STARPU_BLAS_CPU_HH
14#include <starpu_cublas_v2.h>
15#include <starpu_cusolver.h>
17#include "tlapack/legacy_api/blas.hpp"
34 constexpr void gemm(
void** buffers,
void* args)
noexcept
36 using args_t = std::tuple<Op, Op, alpha_t, beta_t>;
39 const args_t& cl_args = *(args_t*)args;
40 const Op& transA = std::get<0>(cl_args);
41 const Op& transB = std::get<1>(cl_args);
42 const alpha_t& alpha = std::get<2>(cl_args);
43 const beta_t& beta = std::get<3>(cl_args);
46 const idx_t& m = STARPU_MATRIX_GET_NX(buffers[2]);
47 const idx_t& n = STARPU_MATRIX_GET_NY(buffers[2]);
48 const idx_t& k = (transA == Op::NoTrans)
49 ? STARPU_MATRIX_GET_NY(buffers[0])
50 : STARPU_MATRIX_GET_NX(buffers[0]);
51 const idx_t& lda = STARPU_MATRIX_GET_LD(buffers[0]);
52 const idx_t& ldb = STARPU_MATRIX_GET_LD(buffers[1]);
53 const idx_t& ldc = STARPU_MATRIX_GET_LD(buffers[2]);
56 const uintptr_t& A = STARPU_MATRIX_GET_PTR(buffers[0]);
57 const uintptr_t& B = STARPU_MATRIX_GET_PTR(buffers[1]);
58 const uintptr_t& C = STARPU_MATRIX_GET_PTR(buffers[2]);
61 if constexpr (mode == 0) {
62 using T = scalar_type<TC, beta_t>;
63 legacy::gemm(Layout::ColMajor, transA, transB, m, n, k, alpha,
64 (
const TA*)A, lda, (
const TB*)B, ldb, (T)beta,
68 else if constexpr (mode == 1) {
69 using T = scalar_type<TA, TB, TC, alpha_t, beta_t>;
71 const cublasOperation_t opA = cuda::op2cublas(transA);
72 const cublasOperation_t opB = cuda::op2cublas(transB);
73 const T alpha_ = (T)alpha;
74 const T beta_ = (T)beta;
75 cublasStatus_t status;
77 if constexpr (is_same_v<T, float>)
78 status = cublasSgemm(starpu_cublas_get_local_handle(), opA,
79 opB, m, n, k, &alpha_, (
const float*)A,
80 lda, (
const float*)B, ldb, &beta_,
82 else if constexpr (is_same_v<T, double>)
84 starpu_cublas_get_local_handle(), opA, opB, m, n, k,
85 &alpha_, (
const double*)A, lda, (
const double*)B, ldb,
86 &beta_, (
double*)C, ldc);
87 else if constexpr (is_same_v<real_type<T>,
float>)
89 starpu_cublas_get_local_handle(), opA, opB, m, n, k,
90 (
const cuFloatComplex*)&alpha_,
91 (
const cuFloatComplex*)A, lda, (
const cuFloatComplex*)B,
92 ldb, (
const cuFloatComplex*)&beta_, (cuFloatComplex*)C,
94 else if constexpr (is_same_v<real_type<T>,
double>)
96 cublasZgemm(starpu_cublas_get_local_handle(), opA, opB,
97 m, n, k, (
const cuDoubleComplex*)&alpha_,
98 (
const cuDoubleComplex*)A, lda,
99 (
const cuDoubleComplex*)B, ldb,
100 (
const cuDoubleComplex*)&beta_,
101 (cuDoubleComplex*)C, ldc);
103 static_assert(
sizeof(T) == 0,
104 "Type not supported in cuBLAS");
106 if (status != CUBLAS_STATUS_SUCCESS)
107 STARPU_CUBLAS_REPORT_ERROR(status);
111 static_assert(mode == 0 || mode == 1,
"Invalid mode");
120 constexpr void symm(
void** buffers,
void* args)
noexcept
122 using args_t = std::tuple<Side, Uplo, alpha_t, beta_t>;
125 const args_t& cl_args = *(args_t*)args;
126 const Side& side = std::get<0>(cl_args);
127 const Uplo& uplo = std::get<1>(cl_args);
128 const alpha_t& alpha = std::get<2>(cl_args);
129 const beta_t& beta = std::get<3>(cl_args);
132 const idx_t& m = STARPU_MATRIX_GET_NX(buffers[2]);
133 const idx_t& n = STARPU_MATRIX_GET_NY(buffers[2]);
134 const idx_t& lda = STARPU_MATRIX_GET_LD(buffers[0]);
135 const idx_t& ldb = STARPU_MATRIX_GET_LD(buffers[1]);
136 const idx_t& ldc = STARPU_MATRIX_GET_LD(buffers[2]);
139 const uintptr_t& A = STARPU_MATRIX_GET_PTR(buffers[0]);
140 const uintptr_t& B = STARPU_MATRIX_GET_PTR(buffers[1]);
141 const uintptr_t& C = STARPU_MATRIX_GET_PTR(buffers[2]);
144 using T = scalar_type<TC, beta_t>;
146 (
const TA*)A, lda, (
const TB*)B, ldb, (T)beta, (TC*)C,
156 constexpr void hemm(
void** buffers,
void* args)
noexcept
158 using args_t = std::tuple<Side, Uplo, alpha_t, beta_t>;
161 const args_t& cl_args = *(args_t*)args;
162 const Side& side = std::get<0>(cl_args);
163 const Uplo& uplo = std::get<1>(cl_args);
164 const alpha_t& alpha = std::get<2>(cl_args);
165 const beta_t& beta = std::get<3>(cl_args);
168 const idx_t& m = STARPU_MATRIX_GET_NX(buffers[2]);
169 const idx_t& n = STARPU_MATRIX_GET_NY(buffers[2]);
170 const idx_t& lda = STARPU_MATRIX_GET_LD(buffers[0]);
171 const idx_t& ldb = STARPU_MATRIX_GET_LD(buffers[1]);
172 const idx_t& ldc = STARPU_MATRIX_GET_LD(buffers[2]);
175 const uintptr_t& A = STARPU_MATRIX_GET_PTR(buffers[0]);
176 const uintptr_t& B = STARPU_MATRIX_GET_PTR(buffers[1]);
177 const uintptr_t& C = STARPU_MATRIX_GET_PTR(buffers[2]);
180 using T = scalar_type<TC, beta_t>;
182 (
const TA*)A, lda, (
const TB*)B, ldb, (T)beta, (TC*)C,
186 template <
class TA,
class TC,
class alpha_t,
class beta_t,
int mode = 0>
187 constexpr void syrk(
void** buffers,
void* args)
noexcept
189 using args_t = std::tuple<Uplo, Op, alpha_t, beta_t>;
192 const args_t& cl_args = *(args_t*)args;
193 const Uplo& uplo = std::get<0>(cl_args);
194 const Op& op = std::get<1>(cl_args);
195 const alpha_t& alpha = std::get<2>(cl_args);
196 const beta_t& beta = std::get<3>(cl_args);
199 const idx_t& n = STARPU_MATRIX_GET_NX(buffers[1]);
200 const idx_t& k = (op == Op::NoTrans)
201 ? STARPU_MATRIX_GET_NY(buffers[0])
202 : STARPU_MATRIX_GET_NX(buffers[0]);
203 const idx_t& lda = STARPU_MATRIX_GET_LD(buffers[0]);
204 const idx_t& ldc = STARPU_MATRIX_GET_LD(buffers[1]);
207 const uintptr_t& A = STARPU_MATRIX_GET_PTR(buffers[0]);
208 const uintptr_t& C = STARPU_MATRIX_GET_PTR(buffers[1]);
211 using T = scalar_type<TC, beta_t>;
212 legacy::syrk(Layout::ColMajor, uplo, op, n, k, alpha, (
const TA*)A,
213 lda, (T)beta, (TC*)C, ldc);
216 template <
class TA,
class TC,
class alpha_t,
class beta_t,
int mode = 0>
217 constexpr void herk(
void** buffers,
void* args)
noexcept
219 using args_t = std::tuple<Uplo, Op, alpha_t, beta_t>;
222 const args_t& cl_args = *(args_t*)args;
223 const Uplo& uplo = std::get<0>(cl_args);
224 const Op& op = std::get<1>(cl_args);
225 const alpha_t& alpha = std::get<2>(cl_args);
226 const beta_t& beta = std::get<3>(cl_args);
229 const idx_t& n = STARPU_MATRIX_GET_NX(buffers[1]);
230 const idx_t& k = (op == Op::NoTrans)
231 ? STARPU_MATRIX_GET_NY(buffers[0])
232 : STARPU_MATRIX_GET_NX(buffers[0]);
233 const idx_t& lda = STARPU_MATRIX_GET_LD(buffers[0]);
234 const idx_t& ldc = STARPU_MATRIX_GET_LD(buffers[1]);
237 const uintptr_t& A = STARPU_MATRIX_GET_PTR(buffers[0]);
238 const uintptr_t& C = STARPU_MATRIX_GET_PTR(buffers[1]);
241 if constexpr (mode == 0) {
242 using real_t = real_type<scalar_type<TC, beta_t>>;
244 (
const TA*)A, lda, (real_t)beta, (TC*)C, ldc);
246#ifdef STARPU_USE_CUDA
247 else if constexpr (mode == 1) {
248 using T = scalar_type<TA, TC, alpha_t, beta_t>;
249 using real_t = real_type<T>;
251 const cublasFillMode_t uplo_ = cuda::uplo2cublas(uplo);
252 const cublasOperation_t op_ = cuda::op2cublas(op);
253 const real_t alpha_ = (real_t)alpha;
254 const real_t beta_ = (real_t)beta;
255 cublasStatus_t status;
257 if constexpr (is_same_v<T, float>)
258 status = cublasSsyrk(
259 starpu_cublas_get_local_handle(), uplo_, op_, n, k,
260 &alpha_, (
const float*)A, lda, &beta_, (
float*)C, ldc);
261 else if constexpr (is_same_v<T, double>)
263 cublasDsyrk(starpu_cublas_get_local_handle(), uplo_,
264 op_, n, k, &alpha_, (
const double*)A, lda,
265 &beta_, (
double*)C, ldc);
266 else if constexpr (is_same_v<real_type<T>,
float>)
267 status = cublasCherk(starpu_cublas_get_local_handle(),
268 uplo_, op_, n, k, &alpha_,
269 (
const cuFloatComplex*)A, lda, &beta_,
270 (cuFloatComplex*)C, ldc);
271 else if constexpr (is_same_v<real_type<T>,
double>)
272 status = cublasZherk(starpu_cublas_get_local_handle(),
273 uplo_, op_, n, k, &alpha_,
274 (
const cuDoubleComplex*)A, lda, &beta_,
275 (cuDoubleComplex*)C, ldc);
277 static_assert(
sizeof(T) == 0,
278 "Type not supported in cuBLAS");
280 if (status != CUBLAS_STATUS_SUCCESS)
281 STARPU_CUBLAS_REPORT_ERROR(status);
285 static_assert(mode == 0 || mode == 1,
"Invalid mode");
294 constexpr void syr2k(
void** buffers,
void* args)
noexcept
296 using args_t = std::tuple<Uplo, Op, alpha_t, beta_t>;
299 const args_t& cl_args = *(args_t*)args;
300 const Uplo& uplo = std::get<0>(cl_args);
301 const Op& op = std::get<1>(cl_args);
302 const alpha_t& alpha = std::get<2>(cl_args);
303 const beta_t& beta = std::get<3>(cl_args);
306 const idx_t& n = STARPU_MATRIX_GET_NX(buffers[2]);
307 const idx_t& k = (op == Op::NoTrans)
308 ? STARPU_MATRIX_GET_NY(buffers[0])
309 : STARPU_MATRIX_GET_NX(buffers[0]);
310 const idx_t& lda = STARPU_MATRIX_GET_LD(buffers[0]);
311 const idx_t& ldb = STARPU_MATRIX_GET_LD(buffers[1]);
312 const idx_t& ldc = STARPU_MATRIX_GET_LD(buffers[2]);
315 const uintptr_t& A = STARPU_MATRIX_GET_PTR(buffers[0]);
316 const uintptr_t& B = STARPU_MATRIX_GET_PTR(buffers[1]);
317 const uintptr_t& C = STARPU_MATRIX_GET_PTR(buffers[2]);
320 using T = scalar_type<TC, beta_t>;
321 legacy::syr2k(Layout::ColMajor, uplo, op, n, k, alpha, (
const TA*)A,
322 lda, (
const TB*)B, ldb, (T)beta, (TC*)C, ldc);
331 constexpr void her2k(
void** buffers,
void* args)
noexcept
333 using args_t = std::tuple<Uplo, Op, alpha_t, beta_t>;
336 const args_t& cl_args = *(args_t*)args;
337 const Uplo& uplo = std::get<0>(cl_args);
338 const Op& op = std::get<1>(cl_args);
339 const alpha_t& alpha = std::get<2>(cl_args);
340 const beta_t& beta = std::get<3>(cl_args);
343 const idx_t& n = STARPU_MATRIX_GET_NX(buffers[2]);
344 const idx_t& k = (op == Op::NoTrans)
345 ? STARPU_MATRIX_GET_NY(buffers[0])
346 : STARPU_MATRIX_GET_NX(buffers[0]);
347 const idx_t& lda = STARPU_MATRIX_GET_LD(buffers[0]);
348 const idx_t& ldb = STARPU_MATRIX_GET_LD(buffers[1]);
349 const idx_t& ldc = STARPU_MATRIX_GET_LD(buffers[2]);
352 const uintptr_t& A = STARPU_MATRIX_GET_PTR(buffers[0]);
353 const uintptr_t& B = STARPU_MATRIX_GET_PTR(buffers[1]);
354 const uintptr_t& C = STARPU_MATRIX_GET_PTR(buffers[2]);
357 using real_t = real_type<scalar_type<TC, beta_t>>;
358 legacy::her2k(Layout::ColMajor, uplo, op, n, k, alpha, (
const TA*)A,
359 lda, (
const TB*)B, ldb, (real_t)beta, (TC*)C, ldc);
362 template <
class TA,
class TB,
class alpha_t,
int mode = 0>
363 constexpr void trmm(
void** buffers,
void* args)
noexcept
365 using args_t = std::tuple<Side, Uplo, Op, Diag, alpha_t>;
368 const args_t& cl_args = *(args_t*)args;
369 const Side& side = std::get<0>(cl_args);
370 const Uplo& uplo = std::get<1>(cl_args);
371 const Op& op = std::get<2>(cl_args);
372 const Diag&
diag = std::get<3>(cl_args);
373 const alpha_t& alpha = std::get<4>(cl_args);
376 const idx_t& m = STARPU_MATRIX_GET_NX(buffers[1]);
377 const idx_t& n = STARPU_MATRIX_GET_NY(buffers[1]);
378 const idx_t& lda = STARPU_MATRIX_GET_LD(buffers[0]);
379 const idx_t& ldb = STARPU_MATRIX_GET_LD(buffers[1]);
382 const uintptr_t& A = STARPU_MATRIX_GET_PTR(buffers[0]);
383 const uintptr_t& B = STARPU_MATRIX_GET_PTR(buffers[1]);
386 legacy::trmm(Layout::ColMajor, side, uplo, op, diag, m, n, alpha,
387 (
const TA*)A, lda, (TB*)B, ldb);
390 template <
class TA,
class TB,
class alpha_t,
int mode = 0>
391 constexpr void trsm(
void** buffers,
void* args)
noexcept
393 using args_t = std::tuple<Side, Uplo, Op, Diag, alpha_t>;
396 const args_t& cl_args = *(args_t*)args;
397 const Side& side = std::get<0>(cl_args);
398 const Uplo& uplo = std::get<1>(cl_args);
399 const Op& op = std::get<2>(cl_args);
400 const Diag&
diag = std::get<3>(cl_args);
401 const alpha_t& alpha = std::get<4>(cl_args);
404 const idx_t& m = STARPU_MATRIX_GET_NX(buffers[1]);
405 const idx_t& n = STARPU_MATRIX_GET_NY(buffers[1]);
406 const idx_t& lda = STARPU_MATRIX_GET_LD(buffers[0]);
407 const idx_t& ldb = STARPU_MATRIX_GET_LD(buffers[1]);
410 const uintptr_t& A = STARPU_MATRIX_GET_PTR(buffers[0]);
411 const uintptr_t& B = STARPU_MATRIX_GET_PTR(buffers[1]);
414 if constexpr (mode == 0)
415 legacy::trsm(Layout::ColMajor, side, uplo, op, diag, m, n,
416 alpha, (
const TA*)A, lda, (TB*)B, ldb);
417#ifdef STARPU_USE_CUDA
418 else if constexpr (mode == 1) {
419 using T = scalar_type<TA, TB, alpha_t>;
421 const cublasSideMode_t side_ = cuda::side2cublas(side);
422 const cublasFillMode_t uplo_ = cuda::uplo2cublas(uplo);
423 const cublasOperation_t op_ = cuda::op2cublas(op);
424 const cublasDiagType_t diag_ = cuda::diag2cublas(diag);
425 const T alpha_ = (T)alpha;
426 cublasStatus_t status;
428 if constexpr (is_same_v<T, float>)
430 cublasStrsm(starpu_cublas_get_local_handle(), side_,
431 uplo_, op_, diag_, m, n, &alpha_,
432 (
const float*)A, lda, (
float*)B, ldb);
433 else if constexpr (is_same_v<T, double>)
435 cublasDtrsm(starpu_cublas_get_local_handle(), side_,
436 uplo_, op_, diag_, m, n, &alpha_,
437 (
const double*)A, lda, (
double*)B, ldb);
438 else if constexpr (is_same_v<real_type<T>,
float>)
439 status = cublasCtrsm(
440 starpu_cublas_get_local_handle(), side_, uplo_, op_,
441 diag_, m, n, (
const cuFloatComplex*)&alpha_,
442 (
const cuFloatComplex*)A, lda, (cuFloatComplex*)B, ldb);
443 else if constexpr (is_same_v<real_type<T>,
double>)
444 status = cublasZtrsm(starpu_cublas_get_local_handle(),
445 side_, uplo_, op_, diag_, m, n,
446 (
const cuDoubleComplex*)&alpha_,
447 (
const cuDoubleComplex*)A, lda,
448 (cuDoubleComplex*)B, ldb);
450 static_assert(
sizeof(T) == 0,
451 "Type not supported in cuBLAS");
453 if (status != CUBLAS_STATUS_SUCCESS)
454 STARPU_CUBLAS_REPORT_ERROR(status);
458 static_assert(mode == 0 || mode == 1,
"Invalid mode");
464 template <
class uplo_t,
class T,
bool has_info,
int mode = 0>
465 constexpr void potrf(
void** buffers,
void* args)
467 using args_t = std::tuple<uplo_t>;
470 const args_t& cl_args = *(args_t*)args;
471 const uplo_t& uplo = std::get<0>(cl_args);
474 const idx_t& n = STARPU_MATRIX_GET_NX(buffers[0]);
475 const idx_t& lda = STARPU_MATRIX_GET_LD(buffers[0]);
478 const uintptr_t& A = STARPU_MATRIX_GET_PTR(buffers[0]);
481 int* info = (has_info) ? (
int*)STARPU_VARIABLE_GET_PTR(buffers[1])
485 if constexpr (mode == 0) {
486 if constexpr (has_info)
491#ifdef STARPU_HAVE_LIBCUSOLVER
492 if constexpr (mode == 1) {
494 STARPU_VARIABLE_GET_PTR(buffers[(has_info ? 2 : 1)]);
495 const size_t& lwork =
496 STARPU_VARIABLE_GET_ELEMSIZE(buffers[(has_info ? 2 : 1)]);
498 const cublasFillMode_t uplo_ = cuda::uplo2cublas(uplo);
499 cusolverStatus_t status = CUSOLVER_STATUS_SUCCESS;
501 if constexpr (is_same_v<T, float>)
502 status = cusolverDnSpotrf(
503 starpu_cusolverDn_get_local_handle(), uplo_, n,
504 (
float*)A, lda, (
float*)w, lwork /
sizeof(float), info);
505 else if constexpr (is_same_v<T, double>)
507 cusolverDnDpotrf(starpu_cusolverDn_get_local_handle(),
508 uplo_, n, (
double*)A, lda, (
double*)w,
509 lwork /
sizeof(double), info);
510 else if constexpr (is_same_v<real_type<T>,
float>)
511 status = cusolverDnCpotrf(
512 starpu_cusolverDn_get_local_handle(), uplo_, n,
513 (cuFloatComplex*)A, lda, (cuFloatComplex*)w,
514 lwork /
sizeof(cuFloatComplex), info);
515 else if constexpr (is_same_v<real_type<T>,
double>)
516 status = cusolverDnZpotrf(
517 starpu_cusolverDn_get_local_handle(), uplo_, n,
518 (cuDoubleComplex*)A, lda, (cuDoubleComplex*)w,
519 lwork /
sizeof(cuDoubleComplex), info);
521 static_assert(
sizeof(T) == 0,
522 "Type not supported in cuSolver");
524 if (status != CUSOLVER_STATUS_SUCCESS)
525 STARPU_CUBLAS_REPORT_ERROR(status);
529 static_assert(mode == 0,
"Invalid mode");
constexpr auto diag(T &A, int diagIdx=0) noexcept
Get the Diagonal of an Eigen Matrix.
Definition eigen.hpp:576
void herk(Layout layout, Uplo uplo, Op trans, idx_t n, idx_t k, real_type< TA, TC > alpha, TA const *A, idx_t lda, real_type< TA, TC > beta, TC *C, idx_t ldc)
Hermitian rank-k update:
Definition herk.hpp:87
void syrk(Layout layout, Uplo uplo, Op trans, idx_t n, idx_t k, scalar_type< TA, TC > alpha, TA const *A, idx_t lda, scalar_type< TA, TC > beta, TC *C, idx_t ldc)
Symmetric rank-k update:
Definition syrk.hpp:89
void hemm(Layout layout, Side side, Uplo uplo, idx_t m, idx_t n, scalar_type< TA, TB, TC > alpha, TA const *A, idx_t lda, TB const *B, idx_t ldb, scalar_type< TA, TB, TC > beta, TC *C, idx_t ldc)
Hermitian matrix-matrix multiply:
Definition hemm.hpp:90
void gemm(Layout layout, Op transA, Op transB, idx_t m, idx_t n, idx_t k, scalar_type< TA, TB, TC > alpha, TA const *A, idx_t lda, TB const *B, idx_t ldb, scalar_type< TA, TB, TC > beta, TC *C, idx_t ldc)
General matrix-matrix multiply:
Definition gemm.hpp:103
void syr2k(Layout layout, Uplo uplo, Op trans, idx_t n, idx_t k, scalar_type< TA, TB, TC > alpha, TA const *A, idx_t lda, TB const *B, idx_t ldb, scalar_type< TA, TB, TC > beta, TC *C, idx_t ldc)
Symmetric rank-2k update:
Definition syr2k.hpp:101
void her2k(Layout layout, Uplo uplo, Op trans, idx_t n, idx_t k, scalar_type< TA, TB, TC > alpha, TA const *A, idx_t lda, TB const *B, idx_t ldb, real_type< TA, TB, TC > beta, TC *C, idx_t ldc)
Hermitian rank-2k update:
Definition her2k.hpp:100
void symm(Layout layout, Side side, Uplo uplo, idx_t m, idx_t n, scalar_type< TA, TB, TC > alpha, TA const *A, idx_t lda, TB const *B, idx_t ldb, scalar_type< TA, TB, TC > beta, TC *C, idx_t ldc)
Symmetric matrix-matrix multiply:
Definition symm.hpp:84
void trsm(Layout layout, Side side, Uplo uplo, Op trans, Diag diag, idx_t m, idx_t n, scalar_type< TA, TB > alpha, TA const *A, idx_t lda, TB *B, idx_t ldb)
Solve the triangular matrix equation op(A) X = alpha B (or X op(A) = alpha B), overwriting B with X.
Definition trsm.hpp:102
void trmm(Layout layout, Side side, Uplo uplo, Op trans, Diag diag, idx_t m, idx_t n, scalar_type< TA, TB > alpha, TA const *A, idx_t lda, TB *B, idx_t ldb)
Triangular matrix-matrix multiply:
Definition trmm.hpp:98
int potrf(uplo_t uplo, idx_t n, T *A, idx_t lda)
Computes the Cholesky factorization of a Hermitian positive definite matrix A using a blocked algorithm.
Definition potrf.hpp:26
Concept for types that represent tlapack::Diag.
Concept for types that represent tlapack::Op.
Concept for types that represent tlapack::Side.
Concept for types that represent tlapack::Uplo.