11#ifndef TLAPACK_STARPU_TASKS_HH
12#define TLAPACK_STARPU_TASKS_HH
/// Flop-count estimate used by insert_task_gemm.
///
/// insert_task_gemm calls it as gemm(C.m, C.n, k), with k = A.n when A is
/// not transposed (Op::NoTrans) and k = A.m otherwise.
/// NOTE(review): the function body is not visible in this listing, so the
/// exact formula is not documented here.
///
/// @param m Number of rows of C.
/// @param n Number of columns of C.
/// @param k Inner dimension of the product.
18 constexpr double gemm(
double m,
double n,
double k)
/// Flop-count estimate for a triangular solve (TRSM).
///
/// Computed as m * m * n: the triangular factor is m-by-m and n is the
/// number of right-hand sides. insert_task_trsm passes A.m as m, and B.n
/// (Side::Left) or B.m (Side::Right) as n.
///
/// @param m Order of the triangular matrix.
/// @param n Number of right-hand sides.
/// @return m^2 * n.
constexpr double trsm(double m, double n) { return m * m * n; }
/// Flop-count estimate for a Hermitian rank-k update (HERK).
///
/// Computed as (n + 1) * n * k. insert_task_herk passes C.m as n, and
/// A.n (Op::NoTrans) or A.m (otherwise) as k.
///
/// @param n Order of the square matrix C.
/// @param k Inner dimension of op(A).
/// @return (n + 1) * n * k.
constexpr double herk(double n, double k) { return (n + 1) * n * k; }
/// Flop-count estimate for a Cholesky factorization (POTRF) of an
/// n-by-n matrix; insert_task_potrf passes A.m as n.
///
/// Computed as (n / 3) * n * n, i.e. n^3 / 3. The division is kept first so
/// the floating-point result matches the original expression exactly.
///
/// @param n Order of the matrix.
/// @return n^3 / 3.
constexpr double chol(double n) { return (n / 3) * n * n; }
/// Create and submit a StarPU task that runs the cl::gemm codelet on the
/// tiles A, B and C (presumably C := alpha * op(A) * op(B) + beta * C).
///
/// transA, transB, alpha and beta are copied into a heap-allocated
/// std::tuple. The tuple is handed to the codelet through cl_arg and is
/// deleted by the task's completion callback.
///
/// NOTE(review): this listing elides several lines. The remaining
/// parameters (transB, alpha, A, B, beta, C, ...) are used by the body, but
/// their declarations are not shown.
31 template <
class TA,
class TB,
class TC,
class alpha_t,
class beta_t>
32 void insert_task_gemm(
Op transA,
40 using args_t = std::tuple<Op, Op, alpha_t, beta_t>;
47 struct starpu_task* task = starpu_task_create();
// Argument tuple; ownership passes to the completion callback below.
50 args_t* args_ptr =
new args_t;
53 std::get<0>(*args_ptr) = transA;
54 std::get<1>(*args_ptr) = transB;
55 std::get<2>(*args_ptr) = alpha;
56 std::get<3>(*args_ptr) = beta;
// Compatible handles for C, A and B, filled by create_compatible_inout_handles.
// Presumably handle[0] belongs to C, handle[1] to A and handle[2] to B;
// confirm in Tile.hpp.
59 starpu_data_handle_t handle[3];
60 C.create_compatible_inout_handles(handle, A, B);
// Select the codelet instantiated for these tile and scalar types.
// NOTE(review): the left-hand side of this assignment (presumably
// `task->cl =`, as in the other insert_task_* functions) is elided here.
64 (
struct starpu_codelet*)&(cl::gemm<TA, TB, TC, alpha_t, beta_t>);
// Task buffer order is handle[1], handle[2], handle[0].
65 task->handles[0] = handle[1];
66 task->handles[1] = handle[2];
67 task->handles[2] = handle[0];
// Pass the packed tuple (and its size) to the codelet.
68 task->cl_arg = (
void*)args_ptr;
69 task->cl_arg_size =
sizeof(args_t);
// Completion callback frees the argument tuple. The lambda captures
// nothing, so it converts to the plain function pointer StarPU expects.
70 task->callback_func = [](
void* args)
noexcept {
delete (args_t*)args; };
71 task->callback_arg = (
void*)args_ptr;
// Flop estimate: k is A.n for Op::NoTrans and A.m otherwise.
// NOTE(review): unlike herk/trsm/potrf, no `task->flops =` is visible on
// this statement. Confirm the assignment is on the elided line above;
// otherwise the result is silently discarded.
73 flops::gemm(C.m, C.n, (transA == Op::NoTrans ? A.n : A.m));
// Submit the task and check the StarPU return code.
76 const int ret = starpu_task_submit(task);
77 STARPU_CHECK_RETURN_VALUE(ret,
"starpu_task_submit");
// Undo create_compatible_inout_handles() once the task has been submitted.
80 C.clean_compatible_inout_handles(handle, A, B);
/// Create and submit a StarPU task that runs the cl::herk codelet on the
/// tiles A and C (presumably a Hermitian rank-k update of C).
///
/// uplo, trans, alpha and beta are copied into a heap-allocated std::tuple.
/// The tuple is passed to the codelet through cl_arg and deleted by the
/// task's completion callback.
///
/// NOTE(review): this listing elides several lines, including the remaining
/// parameter declarations (trans, alpha, A, beta, C, ...), the code that
/// fills `handle`, and whatever follows the submission.
83 template <
class TA,
class TC,
class alpha_t,
class beta_t>
84 void insert_task_herk(
Uplo uplo,
91 using args_t = std::tuple<Uplo, Op, alpha_t, beta_t>;
98 struct starpu_task* task = starpu_task_create();
// Argument tuple; ownership passes to the completion callback below.
101 args_t* args_ptr =
new args_t;
104 std::get<0>(*args_ptr) = uplo;
105 std::get<1>(*args_ptr) = trans;
106 std::get<2>(*args_ptr) = alpha;
107 std::get<3>(*args_ptr) = beta;
// Data handles for the two task buffers. They are filled on elided lines,
// presumably via compatible handles for A and C; confirm in the full source.
110 starpu_data_handle_t handle[2];
114 task->cl = (
struct starpu_codelet*)&(cl::herk<TA, TC, alpha_t, beta_t>);
115 task->handles[0] = handle[0];
116 task->handles[1] = handle[1];
// Pass the packed tuple (and its size) to the codelet.
117 task->cl_arg = (
void*)args_ptr;
118 task->cl_arg_size =
sizeof(args_t);
// Completion callback frees the argument tuple (captureless lambda ->
// function pointer).
119 task->callback_func = [](
void* args)
noexcept {
delete (args_t*)args; };
120 task->callback_arg = (
void*)args_ptr;
// Flop estimate: n = C.m; k = A.n for Op::NoTrans, A.m otherwise.
121 task->flops = flops::herk(C.m, (trans == Op::NoTrans ? A.n : A.m));
// Submit the task and check the StarPU return code.
124 const int ret = starpu_task_submit(task);
125 STARPU_CHECK_RETURN_VALUE(ret,
"starpu_task_submit");
/// Create and submit a StarPU task that runs the cl::trsm codelet on the
/// tiles A and B (presumably a triangular solve with A applied to B).
///
/// side, uplo, trans, diag and alpha are copied into a heap-allocated
/// std::tuple. The tuple is passed to the codelet through cl_arg and deleted
/// by the task's completion callback.
///
/// NOTE(review): this listing elides several lines, including the
/// declarations of uplo, trans, diag, A and B, the code that fills `handle`,
/// and whatever follows the submission.
131 template <
class TA,
class TB,
class alpha_t>
132 void insert_task_trsm(
Side side,
136 const alpha_t& alpha,
140 using args_t = std::tuple<Side, Uplo, Op, Diag, alpha_t>;
147 struct starpu_task* task = starpu_task_create();
// Argument tuple; ownership passes to the completion callback below.
150 args_t* args_ptr =
new args_t;
153 std::get<0>(*args_ptr) = side;
154 std::get<1>(*args_ptr) = uplo;
155 std::get<2>(*args_ptr) = trans;
156 std::get<3>(*args_ptr) =
diag;
157 std::get<4>(*args_ptr) = alpha;
// Data handles for the two task buffers. They are filled on elided lines;
// confirm their origin in the full source.
160 starpu_data_handle_t handle[2];
164 task->cl = (
struct starpu_codelet*)&(cl::trsm<TA, TB, alpha_t>);
165 task->handles[0] = handle[0];
166 task->handles[1] = handle[1];
// Pass the packed tuple (and its size) to the codelet.
167 task->cl_arg = (
void*)args_ptr;
168 task->cl_arg_size =
sizeof(args_t);
// Completion callback frees the argument tuple (captureless lambda ->
// function pointer).
169 task->callback_func = [](
void* args)
noexcept {
delete (args_t*)args; };
170 task->callback_arg = (
void*)args_ptr;
// Flop estimate: m = A.m; right-hand sides are B.n (Side::Left) or B.m.
171 task->flops = flops::trsm(A.m, ((side == Side::Left) ? B.n : B.m));
// Submit the task and check the StarPU return code.
174 const int ret = starpu_task_submit(task);
175 STARPU_CHECK_RETURN_VALUE(ret,
"starpu_task_submit");
/// Create and submit a StarPU task that runs the potrf codelet (Cholesky,
/// per its flops::chol cost estimate) on tile A.
///
/// @param uplo  Stored in the codelet argument tuple (presumably selects the
///              triangle to factor).
/// @param info  Optional StarPU data handle. When non-null it becomes the
///              task's second buffer and cl::potrf is used. When null,
///              cl::potrf_noinfo is used instead.
///
/// When cuda::is_cusolver_v<T> holds, CUDA workers exist and StarPU was
/// built with cuSOLVER, a workspace sized by cusolverDn?potrf_bufferSize is
/// registered as the task buffer after A (and info). It is released with
/// starpu_data_unregister_submit after submission.
///
/// NOTE(review): this listing elides several lines of the definition (other
/// parameters such as A, the opening brace, lwork's declaration, the
/// remaining workspace-registration arguments, ...).
181 template <
class uplo_t,
class T>
182 void insert_task_potrf(uplo_t uplo,
184 starpu_data_handle_t info =
nullptr)
186 using args_t = std::tuple<uplo_t>;
// True when T is handled by cuSOLVER; decides whether a workspace is attached.
187 constexpr bool use_cusolver = cuda::is_cusolver_v<T>;
193 const bool has_info = (info !=
nullptr);
196 struct starpu_task* task = starpu_task_create();
// Argument tuple; ownership passes to the completion callback below.
199 args_t* args_ptr =
new args_t;
202 std::get<0>(*args_ptr) = uplo;
// Pick the codelet variant depending on whether an info handle is attached.
205 task->cl = (
struct starpu_codelet*)&(
206 has_info ? cl::potrf<uplo_t, T> : cl::potrf_noinfo<uplo_t, T>);
207 task->handles[0] = A.handle;
208 if (has_info) task->handles[1] = info;
// Pass the packed tuple (and its size) to the codelet.
209 task->cl_arg = (
void*)args_ptr;
210 task->cl_arg_size =
sizeof(args_t);
// Completion callback frees the argument tuple (captureless lambda ->
// function pointer).
211 task->callback_func = [](
void* args)
noexcept {
delete (args_t*)args; };
212 task->callback_arg = (
void*)args_ptr;
213 task->flops = flops::chol(A.m);
// cuSOLVER path: attach a workspace buffer only when CUDA workers exist.
215 if constexpr (use_cusolver) {
217 if (starpu_cuda_worker_get_count() > 0) {
218#ifdef STARPU_HAVE_LIBCUSOLVER
219 const cublasFillMode_t uplo_ = cuda::uplo2cublas(uplo);
220 const int n = starpu_matrix_get_nx(A.handle);
// Query the workspace size for the matching precision, then scale it by
// the element size to obtain a size in bytes.
222 if constexpr (is_same_v<T, float>) {
223 cusolverDnSpotrf_bufferSize(
224 starpu_cusolverDn_get_local_handle(), uplo_, n,
nullptr,
226 lwork *=
sizeof(float);
228 else if constexpr (is_same_v<T, double>) {
229 cusolverDnDpotrf_bufferSize(
230 starpu_cusolverDn_get_local_handle(), uplo_, n,
nullptr,
232 lwork *=
sizeof(double);
234 else if constexpr (is_same_v<real_type<T>,
float>) {
235 cusolverDnCpotrf_bufferSize(
236 starpu_cusolverDn_get_local_handle(), uplo_, n,
nullptr,
238 lwork *=
sizeof(cuFloatComplex);
240 else if constexpr (is_same_v<real_type<T>,
double>) {
241 cusolverDnZpotrf_bufferSize(
242 starpu_cusolverDn_get_local_handle(), uplo_, n,
nullptr,
244 lwork *=
sizeof(cuDoubleComplex);
// Any other T is rejected at compile time.
247 static_assert(
sizeof(T) == 0,
248 "Type not supported in cuSolver");
// Register the workspace as the buffer after A (and after info, if present).
251 starpu_variable_data_register(&(task->handles[(has_info ? 2 : 1)]),
256 const int ret = starpu_task_submit(task);
257 STARPU_CHECK_RETURN_VALUE(ret,
"starpu_task_submit");
// NOTE(review): `task` is read again below, after starpu_task_submit().
// Tasks from starpu_task_create() have destroy = 1 and are freed
// automatically once they complete, so this read can race with task
// destruction. Safer: copy task->handles[(has_info ? 2 : 1)] into a local
// before submitting and unregister that local here.
// NOTE(review): the workspace is registered only when CUDA workers exist and
// STARPU_HAVE_LIBCUSOLVER is defined, but it is unregistered whenever
// use_cusolver is true. Confirm that the elided lines guard this case.
259 if constexpr (use_cusolver)
260 starpu_data_unregister_submit(task->handles[(has_info ? 2 : 1)]);
Codelets for StarPU tasks.
constexpr auto diag(T &A, int diagIdx=0) noexcept
Get the Diagonal of an Eigen Matrix.
Definition eigen.hpp:576
#define tlapack_check(cond)
Throw an error if cond is false.
Definition exceptionHandling.hpp:98
Concept for types that represent tlapack::Diag.
Concept for types that represent tlapack::Op.
Concept for types that represent tlapack::Side.
Concept for types that represent tlapack::Uplo.
static void clean_compatible_handles(starpu_data_handle_t handles[2], const Tile &A, const Tile &B) noexcept
Clean the partition created by create_compatible_handles()
Definition Tile.hpp:132
static void create_compatible_handles(starpu_data_handle_t handles[2], const Tile &A, const Tile &B) noexcept
Create compatible handles between two tiles.
Definition Tile.hpp:97