11#ifndef TLAPACK_STARPU_CODELETS_HH
12#define TLAPACK_STARPU_CODELETS_HH
24 template <
class TA,
class TB,
class TC,
class alpha_t,
class beta_t>
28 constexpr bool use_cublas =
29 cuda::is_cublas_v<TA, TB, TC, alpha_t, beta_t>;
31 cl.cpu_funcs[0] = func::gemm<TA, TB, TC, alpha_t, beta_t>;
32 if constexpr (use_cublas) {
33 cl.cuda_funcs[0] = func::gemm<TA, TB, TC, alpha_t, beta_t, 1>;
34 cl.cuda_flags[0] = STARPU_CUDA_ASYNC;
37 cl.modes[0] = STARPU_R;
38 cl.modes[1] = STARPU_R;
39 cl.modes[2] = is_same_v<beta_t, StrongZero> ? STARPU_W : STARPU_RW;
40 cl.name =
"tlapack::starpu::gemm";
44 cl.where |= STARPU_CPU;
45 if constexpr (use_cublas) cl.where |= STARPU_CUDA;
51 template <
class TA,
class TB,
class TC,
class alpha_t,
class beta_t>
56 cl.cpu_funcs[0] = func::symm<TA, TB, TC, alpha_t, beta_t>;
58 cl.modes[0] = STARPU_R;
59 cl.modes[1] = STARPU_R;
60 cl.modes[2] = is_same_v<beta_t, StrongZero> ? STARPU_W : STARPU_RW;
61 cl.name =
"tlapack::starpu::symm";
65 cl.where |= STARPU_CPU;
71 template <
class TA,
class TB,
class TC,
class alpha_t,
class beta_t>
76 cl.cpu_funcs[0] = func::hemm<TA, TB, TC, alpha_t, beta_t>;
78 cl.modes[0] = STARPU_R;
79 cl.modes[1] = STARPU_R;
80 cl.modes[2] = is_same_v<beta_t, StrongZero> ? STARPU_W : STARPU_RW;
81 cl.name =
"tlapack::starpu::hemm";
85 cl.where |= STARPU_CPU;
91 template <
class TA,
class TC,
class alpha_t,
class beta_t>
96 cl.cpu_funcs[0] = func::syrk<TA, TC, alpha_t, beta_t>;
98 cl.modes[0] = STARPU_R;
99 cl.modes[1] = is_same_v<beta_t, StrongZero> ? STARPU_W : STARPU_RW;
100 cl.name =
"tlapack::starpu::syrk";
104 cl.where |= STARPU_CPU;
110 template <
class TA,
class TC,
class alpha_t,
class beta_t>
114 constexpr bool use_cublas =
115 cuda::is_cublas_v<TA, TC, alpha_t, beta_t>;
117 cl.cpu_funcs[0] = func::herk<TA, TC, alpha_t, beta_t>;
118 if constexpr (use_cublas) {
119 cl.cuda_funcs[0] = func::herk<TA, TC, alpha_t, beta_t, 1>;
120 cl.cuda_flags[0] = STARPU_CUDA_ASYNC;
123 cl.modes[0] = STARPU_R;
124 cl.modes[1] = is_same_v<beta_t, StrongZero> ? STARPU_W : STARPU_RW;
125 cl.name =
"tlapack::starpu::herk";
129 cl.where |= STARPU_CPU;
130 if constexpr (use_cublas) cl.where |= STARPU_CUDA;
136 template <
class TA,
class TB,
class TC,
class alpha_t,
class beta_t>
141 cl.cpu_funcs[0] = func::syr2k<TA, TB, TC, alpha_t, beta_t>;
143 cl.modes[0] = STARPU_R;
144 cl.modes[1] = STARPU_R;
145 cl.modes[2] = is_same_v<beta_t, StrongZero> ? STARPU_W : STARPU_RW;
146 cl.name =
"tlapack::starpu::syr2k";
150 cl.where |= STARPU_CPU;
156 template <
class TA,
class TB,
class TC,
class alpha_t,
class beta_t>
161 cl.cpu_funcs[0] = func::her2k<TA, TB, TC, alpha_t, beta_t>;
163 cl.modes[0] = STARPU_R;
164 cl.modes[1] = STARPU_R;
165 cl.modes[2] = is_same_v<beta_t, StrongZero> ? STARPU_W : STARPU_RW;
166 cl.name =
"tlapack::starpu::her2k";
170 cl.where |= STARPU_CPU;
176 template <
class TA,
class TB,
class alpha_t>
181 cl.cpu_funcs[0] = func::trmm<TA, TB, alpha_t>;
183 cl.modes[0] = STARPU_R;
184 cl.modes[1] = STARPU_RW;
185 cl.name =
"tlapack::starpu::trmm";
189 cl.where |= STARPU_CPU;
195 template <
class TA,
class TB,
class alpha_t>
199 constexpr bool use_cublas = cuda::is_cublas_v<TA, TB, alpha_t>;
201 cl.cpu_funcs[0] = func::trsm<TA, TB, alpha_t>;
202 if constexpr (use_cublas) {
203 cl.cuda_funcs[0] = func::trsm<TA, TB, alpha_t, 1>;
204 cl.cuda_flags[0] = STARPU_CUDA_ASYNC;
207 cl.modes[0] = STARPU_R;
208 cl.modes[1] = STARPU_RW;
209 cl.name =
"tlapack::starpu::trsm";
213 cl.where |= STARPU_CPU;
214 if constexpr (use_cublas) cl.where |= STARPU_CUDA;
223 template <
class uplo_t,
class T,
bool has_info>
227 constexpr bool use_cusolver = cuda::is_cusolver_v<T>;
229 cl.cpu_funcs[0] = func::potrf<uplo_t, T, has_info>;
230 if constexpr (use_cusolver) {
231 cl.cuda_funcs[0] = func::potrf<uplo_t, T, has_info, 1>;
232 cl.cuda_flags[0] = STARPU_CUDA_ASYNC;
233 cl.nbuffers = 2 + (has_info ? 1 : 0);
234 cl.modes[1 + (has_info ? 1 : 0)] = starpu_data_access_mode(
235 (
int)STARPU_SCRATCH | (int)STARPU_NOFOOTPRINT);
238 cl.nbuffers = 1 + (has_info ? 1 : 0);
240 cl.modes[0] = STARPU_RW;
241 if constexpr (has_info) cl.modes[1] = STARPU_W;
242 cl.name =
"tlapack::starpu::potrf";
246 cl.where |= STARPU_CPU;
247 if constexpr (use_cusolver) cl.where |= STARPU_CUDA;
259 template <
class TA,
class TB,
class TC,
class alpha_t,
class beta_t>
261 internal::gen_cl_gemm<TA, TB, TC, alpha_t, beta_t>();
263 template <
class TA,
class TB,
class TC,
class alpha_t,
class beta_t>
265 internal::gen_cl_symm<TA, TB, TC, alpha_t, beta_t>();
267 template <
class TA,
class TB,
class TC,
class alpha_t,
class beta_t>
269 internal::gen_cl_hemm<TA, TB, TC, alpha_t, beta_t>();
271 template <
class TA,
class TC,
class alpha_t,
class beta_t>
273 internal::gen_cl_syrk<TA, TC, alpha_t, beta_t>();
275 template <
class TA,
class TC,
class alpha_t,
class beta_t>
277 internal::gen_cl_herk<TA, TC, alpha_t, beta_t>();
279 template <
class TA,
class TB,
class TC,
class alpha_t,
class beta_t>
281 internal::gen_cl_syr2k<TA, TB, TC, alpha_t, beta_t>();
283 template <
class TA,
class TB,
class TC,
class alpha_t,
class beta_t>
285 internal::gen_cl_her2k<TA, TB, TC, alpha_t, beta_t>();
287 template <
class TA,
class TB,
class alpha_t>
289 internal::gen_cl_trmm<TA, TB, alpha_t>();
291 template <
class TA,
class TB,
class alpha_t>
293 internal::gen_cl_trsm<TA, TB, alpha_t>();
295 template <
class uplo_t,
class T>
297 internal::gen_cl_potrf<uplo_t, T, true>();
299 template <
class uplo_t,
class T>
301 internal::gen_cl_potrf<uplo_t, T, false>();
constexpr struct starpu_codelet codelet_init() noexcept
Return an empty starpu_codelet struct.
Definition MatrixEntry.hpp:34
StarPU functions for BLAS and LAPACK tasks.
Sort the numbers in D in increasing order (if ID = 'I') or in decreasing order (if ID = 'D' ).
Definition arrayTraits.hpp:15
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113