11#ifndef TLAPACK_STARPU_CODELETS_HH
12#define TLAPACK_STARPU_CODELETS_HH
24 template <
class TA,
class TB,
class TC,
class alpha_t,
class beta_t>
25 constexpr struct starpu_codelet gen_cl_gemm() noexcept
28 constexpr bool use_cublas =
29 cuda::is_cublas_v<TA, TB, TC, alpha_t, beta_t>;
31 cl.cpu_funcs[0] = func::gemm<TA, TB, TC, alpha_t, beta_t>;
32 if constexpr (use_cublas) {
33 cl.cuda_funcs[0] = func::gemm<TA, TB, TC, alpha_t, beta_t, 1>;
34 cl.cuda_flags[0] = STARPU_CUDA_ASYNC;
37 cl.modes[0] = STARPU_R;
38 cl.modes[1] = STARPU_R;
39 cl.modes[2] = is_same_v<beta_t, StrongZero> ? STARPU_W : STARPU_RW;
40 cl.name =
"tlapack::starpu::gemm";
44 cl.where |= STARPU_CPU;
45 if constexpr (use_cublas) cl.where |= STARPU_CUDA;
51 template <
class TA,
class TB,
class TC,
class alpha_t,
class beta_t>
52 constexpr struct starpu_codelet gen_cl_symm() noexcept
56 cl.cpu_funcs[0] = func::symm<TA, TB, TC, alpha_t, beta_t>;
58 cl.modes[0] = STARPU_R;
59 cl.modes[1] = STARPU_R;
60 cl.modes[2] = is_same_v<beta_t, StrongZero> ? STARPU_W : STARPU_RW;
61 cl.name =
"tlapack::starpu::symm";
65 cl.where |= STARPU_CPU;
71 template <
class TA,
class TB,
class TC,
class alpha_t,
class beta_t>
72 constexpr struct starpu_codelet gen_cl_hemm() noexcept
76 cl.cpu_funcs[0] = func::hemm<TA, TB, TC, alpha_t, beta_t>;
78 cl.modes[0] = STARPU_R;
79 cl.modes[1] = STARPU_R;
80 cl.modes[2] = is_same_v<beta_t, StrongZero> ? STARPU_W : STARPU_RW;
81 cl.name =
"tlapack::starpu::hemm";
85 cl.where |= STARPU_CPU;
91 template <
class TA,
class TC,
class alpha_t,
class beta_t>
92 constexpr struct starpu_codelet gen_cl_syrk() noexcept
96 cl.cpu_funcs[0] = func::syrk<TA, TC, alpha_t, beta_t>;
98 cl.modes[0] = STARPU_R;
99 cl.modes[1] = is_same_v<beta_t, StrongZero> ? STARPU_W : STARPU_RW;
100 cl.name =
"tlapack::starpu::syrk";
104 cl.where |= STARPU_CPU;
110 template <
class TA,
class TC,
class alpha_t,
class beta_t>
111 constexpr struct starpu_codelet gen_cl_herk() noexcept
114 constexpr bool use_cublas =
115 cuda::is_cublas_v<TA, TC, alpha_t, beta_t>;
117 cl.cpu_funcs[0] = func::herk<TA, TC, alpha_t, beta_t>;
118 if constexpr (use_cublas) {
119 cl.cuda_funcs[0] = func::herk<TA, TC, alpha_t, beta_t, 1>;
120 cl.cuda_flags[0] = STARPU_CUDA_ASYNC;
123 cl.modes[0] = STARPU_R;
124 cl.modes[1] = is_same_v<beta_t, StrongZero> ? STARPU_W : STARPU_RW;
125 cl.name =
"tlapack::starpu::herk";
129 cl.where |= STARPU_CPU;
130 if constexpr (use_cublas) cl.where |= STARPU_CUDA;
136 template <
class TA,
class TB,
class TC,
class alpha_t,
class beta_t>
137 constexpr struct starpu_codelet gen_cl_syr2k() noexcept
141 cl.cpu_funcs[0] = func::syr2k<TA, TB, TC, alpha_t, beta_t>;
143 cl.modes[0] = STARPU_R;
144 cl.modes[1] = STARPU_R;
145 cl.modes[2] = is_same_v<beta_t, StrongZero> ? STARPU_W : STARPU_RW;
146 cl.name =
"tlapack::starpu::syr2k";
150 cl.where |= STARPU_CPU;
156 template <
class TA,
class TB,
class TC,
class alpha_t,
class beta_t>
157 constexpr struct starpu_codelet gen_cl_her2k() noexcept
161 cl.cpu_funcs[0] = func::her2k<TA, TB, TC, alpha_t, beta_t>;
163 cl.modes[0] = STARPU_R;
164 cl.modes[1] = STARPU_R;
165 cl.modes[2] = is_same_v<beta_t, StrongZero> ? STARPU_W : STARPU_RW;
166 cl.name =
"tlapack::starpu::her2k";
170 cl.where |= STARPU_CPU;
176 template <
class TA,
class TB,
class alpha_t>
177 constexpr struct starpu_codelet gen_cl_trmm() noexcept
181 cl.cpu_funcs[0] = func::trmm<TA, TB, alpha_t>;
183 cl.modes[0] = STARPU_R;
184 cl.modes[1] = STARPU_RW;
185 cl.name =
"tlapack::starpu::trmm";
189 cl.where |= STARPU_CPU;
195 template <
class TA,
class TB,
class alpha_t>
196 constexpr struct starpu_codelet gen_cl_trsm() noexcept
199 constexpr bool use_cublas = cuda::is_cublas_v<TA, TB, alpha_t>;
201 cl.cpu_funcs[0] = func::trsm<TA, TB, alpha_t>;
202 if constexpr (use_cublas) {
203 cl.cuda_funcs[0] = func::trsm<TA, TB, alpha_t, 1>;
204 cl.cuda_flags[0] = STARPU_CUDA_ASYNC;
207 cl.modes[0] = STARPU_R;
208 cl.modes[1] = STARPU_RW;
209 cl.name =
"tlapack::starpu::trsm";
213 cl.where |= STARPU_CPU;
214 if constexpr (use_cublas) cl.where |= STARPU_CUDA;
223 template <
class uplo_t,
class T,
bool has_info>
224 constexpr struct starpu_codelet gen_cl_potrf() noexcept
227 constexpr bool use_cusolver = cuda::is_cusolver_v<T>;
229 cl.cpu_funcs[0] = func::potrf<uplo_t, T, has_info>;
230 if constexpr (use_cusolver) {
231 cl.cuda_funcs[0] = func::potrf<uplo_t, T, has_info, 1>;
232 cl.cuda_flags[0] = STARPU_CUDA_ASYNC;
233 cl.nbuffers = 2 + (has_info ? 1 : 0);
234 cl.modes[1 + (has_info ? 1 : 0)] = starpu_data_access_mode(
235 (
int)STARPU_SCRATCH | (int)STARPU_NOFOOTPRINT);
238 cl.nbuffers = 1 + (has_info ? 1 : 0);
240 cl.modes[0] = STARPU_RW;
241 if constexpr (has_info) cl.modes[1] = STARPU_W;
242 cl.name =
"tlapack::starpu::potrf";
246 cl.where |= STARPU_CPU;
247 if constexpr (use_cusolver) cl.where |= STARPU_CUDA;
259 template <
class TA,
class TB,
class TC,
class alpha_t,
class beta_t>
260 constexpr const struct starpu_codelet gemm =
261 internal::gen_cl_gemm<TA, TB, TC, alpha_t, beta_t>();
263 template <
class TA,
class TB,
class TC,
class alpha_t,
class beta_t>
264 constexpr const struct starpu_codelet symm =
265 internal::gen_cl_symm<TA, TB, TC, alpha_t, beta_t>();
267 template <
class TA,
class TB,
class TC,
class alpha_t,
class beta_t>
268 constexpr const struct starpu_codelet hemm =
269 internal::gen_cl_hemm<TA, TB, TC, alpha_t, beta_t>();
271 template <
class TA,
class TC,
class alpha_t,
class beta_t>
272 constexpr const struct starpu_codelet syrk =
273 internal::gen_cl_syrk<TA, TC, alpha_t, beta_t>();
275 template <
class TA,
class TC,
class alpha_t,
class beta_t>
276 constexpr const struct starpu_codelet herk =
277 internal::gen_cl_herk<TA, TC, alpha_t, beta_t>();
279 template <
class TA,
class TB,
class TC,
class alpha_t,
class beta_t>
280 constexpr const struct starpu_codelet syr2k =
281 internal::gen_cl_syr2k<TA, TB, TC, alpha_t, beta_t>();
283 template <
class TA,
class TB,
class TC,
class alpha_t,
class beta_t>
284 constexpr const struct starpu_codelet her2k =
285 internal::gen_cl_her2k<TA, TB, TC, alpha_t, beta_t>();
287 template <
class TA,
class TB,
class alpha_t>
288 constexpr const struct starpu_codelet trmm =
289 internal::gen_cl_trmm<TA, TB, alpha_t>();
291 template <
class TA,
class TB,
class alpha_t>
292 constexpr const struct starpu_codelet trsm =
293 internal::gen_cl_trsm<TA, TB, alpha_t>();
295 template <
class uplo_t,
class T>
296 constexpr const struct starpu_codelet potrf =
297 internal::gen_cl_potrf<uplo_t, T, true>();
299 template <
class uplo_t,
class T>
300 constexpr const struct starpu_codelet potrf_noinfo =
301 internal::gen_cl_potrf<uplo_t, T, false>();
// Cross-reference notes for names used above (declared elsewhere):
//   constexpr struct starpu_codelet codelet_init() noexcept
//       Return an empty starpu_codelet struct.
//   Definition MatrixEntry.hpp:34
//   StarPU functions for BLAS and LAPACK tasks.