<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
tasks.hpp
Go to the documentation of this file.
1
4//
5// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
6//
7// This file is part of <T>LAPACK.
8// <T>LAPACK is free software: you can redistribute it and/or modify it under
9// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
10
11#ifndef TLAPACK_STARPU_TASKS_HH
12#define TLAPACK_STARPU_TASKS_HH
13
15
namespace tlapack {
/// Closed-form floating-point operation counts, one helper per kernel.
/// Every helper takes and returns `double` and is usable in constant
/// expressions.
namespace flops {

    /// @return 2 * m * n * k.
    constexpr double gemm(double m, double n, double k)
    {
        const double twice_m = 2 * m;
        return twice_m * n * k;
    }

    /// @return m * m * n.
    constexpr double trsm(double m, double n)
    {
        const double m_squared = m * m;
        return m_squared * n;
    }

    /// @return (n + 1) * n * k.
    constexpr double herk(double n, double k)
    {
        const double n_plus_one_times_n = (n + 1) * n;
        return n_plus_one_times_n * k;
    }

    /// @return (n / 3) * n * n.
    constexpr double chol(double n)
    {
        const double third_of_n = n / 3;
        return third_of_n * n * n;
    }

}  // namespace flops
}  // namespace tlapack
27
28namespace tlapack {
29namespace starpu {
30
31 template <class TA, class TB, class TC, class alpha_t, class beta_t>
32 void insert_task_gemm(Op transA,
33 Op transB,
34 const alpha_t& alpha,
35 const Tile& A,
36 const Tile& B,
37 const beta_t& beta,
38 const Tile& C)
39 {
40 using args_t = std::tuple<Op, Op, alpha_t, beta_t>;
41
42 // check sizes
43 tlapack_check(C.m == (transA == Op::NoTrans ? A.m : A.n));
44 tlapack_check(C.n == (transB == Op::NoTrans ? B.n : B.m));
45
46 // Allocate space for the task
47 struct starpu_task* task = starpu_task_create();
48
49 // Allocate space for the arguments
50 args_t* args_ptr = new args_t;
51
52 // Initialize arguments
53 std::get<0>(*args_ptr) = transA;
54 std::get<1>(*args_ptr) = transB;
55 std::get<2>(*args_ptr) = alpha;
56 std::get<3>(*args_ptr) = beta;
57
58 // Handles
59 starpu_data_handle_t handle[3];
60 C.create_compatible_inout_handles(handle, A, B);
61
62 // Initialize task
63 task->cl =
64 (struct starpu_codelet*)&(cl::gemm<TA, TB, TC, alpha_t, beta_t>);
65 task->handles[0] = handle[1];
66 task->handles[1] = handle[2];
67 task->handles[2] = handle[0];
68 task->cl_arg = (void*)args_ptr;
69 task->cl_arg_size = sizeof(args_t);
70 task->callback_func = [](void* args) noexcept { delete (args_t*)args; };
71 task->callback_arg = (void*)args_ptr;
72 task->flops =
73 flops::gemm(C.m, C.n, (transA == Op::NoTrans ? A.n : A.m));
74
75 // Submit task
76 const int ret = starpu_task_submit(task);
77 STARPU_CHECK_RETURN_VALUE(ret, "starpu_task_submit");
78
79 // Clean partition plan
80 C.clean_compatible_inout_handles(handle, A, B);
81 }
82
83 template <class TA, class TC, class alpha_t, class beta_t>
84 void insert_task_herk(Uplo uplo,
85 Op trans,
86 const alpha_t& alpha,
87 const Tile& A,
88 const beta_t& beta,
89 const Tile& C)
90 {
91 using args_t = std::tuple<Uplo, Op, alpha_t, beta_t>;
92
93 // check sizes
94 tlapack_check(C.m == C.n);
95 tlapack_check(C.m == (trans == Op::NoTrans ? A.m : A.n));
96
97 // Allocate space for the task
98 struct starpu_task* task = starpu_task_create();
99
100 // Allocate space for the arguments
101 args_t* args_ptr = new args_t;
102
103 // Initialize arguments
104 std::get<0>(*args_ptr) = uplo;
105 std::get<1>(*args_ptr) = trans;
106 std::get<2>(*args_ptr) = alpha;
107 std::get<3>(*args_ptr) = beta;
108
109 // Handles
110 starpu_data_handle_t handle[2];
112
113 // Initialize task
114 task->cl = (struct starpu_codelet*)&(cl::herk<TA, TC, alpha_t, beta_t>);
115 task->handles[0] = handle[0];
116 task->handles[1] = handle[1];
117 task->cl_arg = (void*)args_ptr;
118 task->cl_arg_size = sizeof(args_t);
119 task->callback_func = [](void* args) noexcept { delete (args_t*)args; };
120 task->callback_arg = (void*)args_ptr;
121 task->flops = flops::herk(C.m, (trans == Op::NoTrans ? A.n : A.m));
122
123 // Submit task
124 const int ret = starpu_task_submit(task);
125 STARPU_CHECK_RETURN_VALUE(ret, "starpu_task_submit");
126
127 // Clean partition plan
128 Tile::clean_compatible_handles(handle, A, C);
129 }
130
131 template <class TA, class TB, class alpha_t>
132 void insert_task_trsm(Side side,
133 Uplo uplo,
134 Op trans,
135 Diag diag,
136 const alpha_t& alpha,
137 const Tile& A,
138 const Tile& B)
139 {
140 using args_t = std::tuple<Side, Uplo, Op, Diag, alpha_t>;
141
142 // check sizes
143 tlapack_check(A.m == A.n);
144 tlapack_check(A.m == ((side == Side::Left) ? B.m : B.n));
145
146 // Allocate space for the task
147 struct starpu_task* task = starpu_task_create();
148
149 // Allocate space for the arguments
150 args_t* args_ptr = new args_t;
151
152 // Initialize arguments
153 std::get<0>(*args_ptr) = side;
154 std::get<1>(*args_ptr) = uplo;
155 std::get<2>(*args_ptr) = trans;
156 std::get<3>(*args_ptr) = diag;
157 std::get<4>(*args_ptr) = alpha;
158
159 // Handles
160 starpu_data_handle_t handle[2];
162
163 // Initialize task
164 task->cl = (struct starpu_codelet*)&(cl::trsm<TA, TB, alpha_t>);
165 task->handles[0] = handle[0];
166 task->handles[1] = handle[1];
167 task->cl_arg = (void*)args_ptr;
168 task->cl_arg_size = sizeof(args_t);
169 task->callback_func = [](void* args) noexcept { delete (args_t*)args; };
170 task->callback_arg = (void*)args_ptr;
171 task->flops = flops::trsm(A.m, ((side == Side::Left) ? B.n : B.m));
172
173 // Submit task
174 const int ret = starpu_task_submit(task);
175 STARPU_CHECK_RETURN_VALUE(ret, "starpu_task_submit");
176
177 // Clean partition plan
178 Tile::clean_compatible_handles(handle, A, B);
179 }
180
181 template <class uplo_t, class T>
182 void insert_task_potrf(uplo_t uplo,
183 const Tile& A,
184 starpu_data_handle_t info = nullptr)
185 {
186 using args_t = std::tuple<uplo_t>;
187 constexpr bool use_cusolver = cuda::is_cusolver_v<T>;
188
189 // check sizes
190 tlapack_check(A.m == A.n);
191
192 // constants
193 const bool has_info = (info != nullptr);
194
195 // Allocate space for the task
196 struct starpu_task* task = starpu_task_create();
197
198 // Allocate space for the arguments
199 args_t* args_ptr = new args_t;
200
201 // Initialize arguments
202 std::get<0>(*args_ptr) = uplo;
203
204 // Initialize task
205 task->cl = (struct starpu_codelet*)&(
206 has_info ? cl::potrf<uplo_t, T> : cl::potrf_noinfo<uplo_t, T>);
207 task->handles[0] = A.handle;
208 if (has_info) task->handles[1] = info;
209 task->cl_arg = (void*)args_ptr;
210 task->cl_arg_size = sizeof(args_t);
211 task->callback_func = [](void* args) noexcept { delete (args_t*)args; };
212 task->callback_arg = (void*)args_ptr;
213 task->flops = flops::chol(A.m);
214
215 if constexpr (use_cusolver) {
216 int lwork = 0;
217 if (starpu_cuda_worker_get_count() > 0) {
218#ifdef STARPU_HAVE_LIBCUSOLVER
219 const cublasFillMode_t uplo_ = cuda::uplo2cublas(uplo);
220 const int n = starpu_matrix_get_nx(A.handle);
221
222 if constexpr (is_same_v<T, float>) {
223 cusolverDnSpotrf_bufferSize(
224 starpu_cusolverDn_get_local_handle(), uplo_, n, nullptr,
225 n, &lwork);
226 lwork *= sizeof(float);
227 }
228 else if constexpr (is_same_v<T, double>) {
229 cusolverDnDpotrf_bufferSize(
230 starpu_cusolverDn_get_local_handle(), uplo_, n, nullptr,
231 n, &lwork);
232 lwork *= sizeof(double);
233 }
234 else if constexpr (is_same_v<real_type<T>, float>) {
235 cusolverDnCpotrf_bufferSize(
236 starpu_cusolverDn_get_local_handle(), uplo_, n, nullptr,
237 n, &lwork);
238 lwork *= sizeof(cuFloatComplex);
239 }
240 else if constexpr (is_same_v<real_type<T>, double>) {
241 cusolverDnZpotrf_bufferSize(
242 starpu_cusolverDn_get_local_handle(), uplo_, n, nullptr,
243 n, &lwork);
244 lwork *= sizeof(cuDoubleComplex);
245 }
246 else
247 static_assert(sizeof(T) == 0,
248 "Type not supported in cuSolver");
249#endif
250 }
251 starpu_variable_data_register(&(task->handles[(has_info ? 2 : 1)]),
252 -1, 0, lwork);
253 }
254
255 // Submit task
256 const int ret = starpu_task_submit(task);
257 STARPU_CHECK_RETURN_VALUE(ret, "starpu_task_submit");
258
259 if constexpr (use_cusolver)
260 starpu_data_unregister_submit(task->handles[(has_info ? 2 : 1)]);
261 }
262
263} // namespace starpu
264} // namespace tlapack
265
266#endif // TLAPACK_STARPU_TASKS_HH
Codelets for StarPU tasks.
constexpr auto diag(T &A, int diagIdx=0) noexcept
Get the Diagonal of an Eigen Matrix.
Definition eigen.hpp:576
#define tlapack_check(cond)
Throw an error if cond is false.
Definition exceptionHandling.hpp:98
Concept for types that represent tlapack::Diag.
Concept for types that represent tlapack::Op.
Concept for types that represent tlapack::Side.
Concept for types that represent tlapack::Uplo.
static void clean_compatible_handles(starpu_data_handle_t handles[2], const Tile &A, const Tile &B) noexcept
Clean the partition created by create_compatible_handles()
Definition Tile.hpp:132
static void create_compatible_handles(starpu_data_handle_t handles[2], const Tile &A, const Tile &B) noexcept
Create compatible handles between two tiles.
Definition Tile.hpp:97