<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
codelets.hpp
Go to the documentation of this file.
1
4//
5// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
6//
7// This file is part of <T>LAPACK.
8// <T>LAPACK is free software: you can redistribute it and/or modify it under
9// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
10
11#ifndef TLAPACK_STARPU_CODELETS_HH
12#define TLAPACK_STARPU_CODELETS_HH
13
16
17namespace tlapack {
18namespace starpu {
19 namespace internal {
20
21 // ---------------------------------------------------------------------
22 // Functions to generate codelets for BLAS routines
23
24 template <class TA, class TB, class TC, class alpha_t, class beta_t>
25 constexpr struct starpu_codelet gen_cl_gemm() noexcept
26 {
27 struct starpu_codelet cl = codelet_init();
28 constexpr bool use_cublas =
29 cuda::is_cublas_v<TA, TB, TC, alpha_t, beta_t>;
30
31 cl.cpu_funcs[0] = func::gemm<TA, TB, TC, alpha_t, beta_t>;
32 if constexpr (use_cublas) {
33 cl.cuda_funcs[0] = func::gemm<TA, TB, TC, alpha_t, beta_t, 1>;
34 cl.cuda_flags[0] = STARPU_CUDA_ASYNC;
35 }
36 cl.nbuffers = 3;
37 cl.modes[0] = STARPU_R;
38 cl.modes[1] = STARPU_R;
39 cl.modes[2] = is_same_v<beta_t, StrongZero> ? STARPU_W : STARPU_RW;
40 cl.name = "tlapack::starpu::gemm";
41
42 // The following lines are needed to make the codelet const
43 // See _starpu_codelet_check_deprecated_fields() in StarPU:
44 cl.where |= STARPU_CPU;
45 if constexpr (use_cublas) cl.where |= STARPU_CUDA;
46 cl.checked = 1;
47
48 return cl;
49 }
50
51 template <class TA, class TB, class TC, class alpha_t, class beta_t>
52 constexpr struct starpu_codelet gen_cl_symm() noexcept
53 {
54 struct starpu_codelet cl = codelet_init();
55
56 cl.cpu_funcs[0] = func::symm<TA, TB, TC, alpha_t, beta_t>;
57 cl.nbuffers = 3;
58 cl.modes[0] = STARPU_R;
59 cl.modes[1] = STARPU_R;
60 cl.modes[2] = is_same_v<beta_t, StrongZero> ? STARPU_W : STARPU_RW;
61 cl.name = "tlapack::starpu::symm";
62
63 // The following lines are needed to make the codelet const
64 // See _starpu_codelet_check_deprecated_fields() in StarPU:
65 cl.where |= STARPU_CPU;
66 cl.checked = 1;
67
68 return cl;
69 }
70
71 template <class TA, class TB, class TC, class alpha_t, class beta_t>
72 constexpr struct starpu_codelet gen_cl_hemm() noexcept
73 {
74 struct starpu_codelet cl = codelet_init();
75
76 cl.cpu_funcs[0] = func::hemm<TA, TB, TC, alpha_t, beta_t>;
77 cl.nbuffers = 3;
78 cl.modes[0] = STARPU_R;
79 cl.modes[1] = STARPU_R;
80 cl.modes[2] = is_same_v<beta_t, StrongZero> ? STARPU_W : STARPU_RW;
81 cl.name = "tlapack::starpu::hemm";
82
83 // The following lines are needed to make the codelet const
84 // See _starpu_codelet_check_deprecated_fields() in StarPU:
85 cl.where |= STARPU_CPU;
86 cl.checked = 1;
87
88 return cl;
89 }
90
91 template <class TA, class TC, class alpha_t, class beta_t>
92 constexpr struct starpu_codelet gen_cl_syrk() noexcept
93 {
94 struct starpu_codelet cl = codelet_init();
95
96 cl.cpu_funcs[0] = func::syrk<TA, TC, alpha_t, beta_t>;
97 cl.nbuffers = 2;
98 cl.modes[0] = STARPU_R;
99 cl.modes[1] = is_same_v<beta_t, StrongZero> ? STARPU_W : STARPU_RW;
100 cl.name = "tlapack::starpu::syrk";
101
102 // The following lines are needed to make the codelet const
103 // See _starpu_codelet_check_deprecated_fields() in StarPU:
104 cl.where |= STARPU_CPU;
105 cl.checked = 1;
106
107 return cl;
108 }
109
110 template <class TA, class TC, class alpha_t, class beta_t>
111 constexpr struct starpu_codelet gen_cl_herk() noexcept
112 {
113 struct starpu_codelet cl = codelet_init();
114 constexpr bool use_cublas =
115 cuda::is_cublas_v<TA, TC, alpha_t, beta_t>;
116
117 cl.cpu_funcs[0] = func::herk<TA, TC, alpha_t, beta_t>;
118 if constexpr (use_cublas) {
119 cl.cuda_funcs[0] = func::herk<TA, TC, alpha_t, beta_t, 1>;
120 cl.cuda_flags[0] = STARPU_CUDA_ASYNC;
121 }
122 cl.nbuffers = 2;
123 cl.modes[0] = STARPU_R;
124 cl.modes[1] = is_same_v<beta_t, StrongZero> ? STARPU_W : STARPU_RW;
125 cl.name = "tlapack::starpu::herk";
126
127 // The following lines are needed to make the codelet const
128 // See _starpu_codelet_check_deprecated_fields() in StarPU:
129 cl.where |= STARPU_CPU;
130 if constexpr (use_cublas) cl.where |= STARPU_CUDA;
131 cl.checked = 1;
132
133 return cl;
134 }
135
136 template <class TA, class TB, class TC, class alpha_t, class beta_t>
137 constexpr struct starpu_codelet gen_cl_syr2k() noexcept
138 {
139 struct starpu_codelet cl = codelet_init();
140
141 cl.cpu_funcs[0] = func::syr2k<TA, TB, TC, alpha_t, beta_t>;
142 cl.nbuffers = 3;
143 cl.modes[0] = STARPU_R;
144 cl.modes[1] = STARPU_R;
145 cl.modes[2] = is_same_v<beta_t, StrongZero> ? STARPU_W : STARPU_RW;
146 cl.name = "tlapack::starpu::syr2k";
147
148 // The following lines are needed to make the codelet const
149 // See _starpu_codelet_check_deprecated_fields() in StarPU:
150 cl.where |= STARPU_CPU;
151 cl.checked = 1;
152
153 return cl;
154 }
155
156 template <class TA, class TB, class TC, class alpha_t, class beta_t>
157 constexpr struct starpu_codelet gen_cl_her2k() noexcept
158 {
159 struct starpu_codelet cl = codelet_init();
160
161 cl.cpu_funcs[0] = func::her2k<TA, TB, TC, alpha_t, beta_t>;
162 cl.nbuffers = 3;
163 cl.modes[0] = STARPU_R;
164 cl.modes[1] = STARPU_R;
165 cl.modes[2] = is_same_v<beta_t, StrongZero> ? STARPU_W : STARPU_RW;
166 cl.name = "tlapack::starpu::her2k";
167
168 // The following lines are needed to make the codelet const
169 // See _starpu_codelet_check_deprecated_fields() in StarPU:
170 cl.where |= STARPU_CPU;
171 cl.checked = 1;
172
173 return cl;
174 }
175
176 template <class TA, class TB, class alpha_t>
177 constexpr struct starpu_codelet gen_cl_trmm() noexcept
178 {
179 struct starpu_codelet cl = codelet_init();
180
181 cl.cpu_funcs[0] = func::trmm<TA, TB, alpha_t>;
182 cl.nbuffers = 2;
183 cl.modes[0] = STARPU_R;
184 cl.modes[1] = STARPU_RW;
185 cl.name = "tlapack::starpu::trmm";
186
187 // The following lines are needed to make the codelet const
188 // See _starpu_codelet_check_deprecated_fields() in StarPU:
189 cl.where |= STARPU_CPU;
190 cl.checked = 1;
191
192 return cl;
193 }
194
195 template <class TA, class TB, class alpha_t>
196 constexpr struct starpu_codelet gen_cl_trsm() noexcept
197 {
198 struct starpu_codelet cl = codelet_init();
199 constexpr bool use_cublas = cuda::is_cublas_v<TA, TB, alpha_t>;
200
201 cl.cpu_funcs[0] = func::trsm<TA, TB, alpha_t>;
202 if constexpr (use_cublas) {
203 cl.cuda_funcs[0] = func::trsm<TA, TB, alpha_t, 1>;
204 cl.cuda_flags[0] = STARPU_CUDA_ASYNC;
205 }
206 cl.nbuffers = 2;
207 cl.modes[0] = STARPU_R;
208 cl.modes[1] = STARPU_RW;
209 cl.name = "tlapack::starpu::trsm";
210
211 // The following lines are needed to make the codelet const
212 // See _starpu_codelet_check_deprecated_fields() in StarPU:
213 cl.where |= STARPU_CPU;
214 if constexpr (use_cublas) cl.where |= STARPU_CUDA;
215 cl.checked = 1;
216
217 return cl;
218 }
219
220 // ---------------------------------------------------------------------
221 // Functions to generate codelets for LAPACK routines
222
223 template <class uplo_t, class T, bool has_info>
224 constexpr struct starpu_codelet gen_cl_potrf() noexcept
225 {
226 struct starpu_codelet cl = codelet_init();
227 constexpr bool use_cusolver = cuda::is_cusolver_v<T>;
228
229 cl.cpu_funcs[0] = func::potrf<uplo_t, T, has_info>;
230 if constexpr (use_cusolver) {
231 cl.cuda_funcs[0] = func::potrf<uplo_t, T, has_info, 1>;
232 cl.cuda_flags[0] = STARPU_CUDA_ASYNC;
233 cl.nbuffers = 2 + (has_info ? 1 : 0);
234 cl.modes[1 + (has_info ? 1 : 0)] = starpu_data_access_mode(
235 (int)STARPU_SCRATCH | (int)STARPU_NOFOOTPRINT);
236 }
237 else {
238 cl.nbuffers = 1 + (has_info ? 1 : 0);
239 }
240 cl.modes[0] = STARPU_RW;
241 if constexpr (has_info) cl.modes[1] = STARPU_W;
242 cl.name = "tlapack::starpu::potrf";
243
244 // The following lines are needed to make the codelet const
245 // See _starpu_codelet_check_deprecated_fields() in StarPU:
246 cl.where |= STARPU_CPU;
247 if constexpr (use_cusolver) cl.where |= STARPU_CUDA;
248 cl.checked = 1;
249
250 return cl;
251 }
252 } // namespace internal
253
254 // ---------------------------------------------------------------------
255 // Codelets
256
257 namespace cl {
258
259 template <class TA, class TB, class TC, class alpha_t, class beta_t>
260 constexpr const struct starpu_codelet gemm =
261 internal::gen_cl_gemm<TA, TB, TC, alpha_t, beta_t>();
262
263 template <class TA, class TB, class TC, class alpha_t, class beta_t>
264 constexpr const struct starpu_codelet symm =
265 internal::gen_cl_symm<TA, TB, TC, alpha_t, beta_t>();
266
267 template <class TA, class TB, class TC, class alpha_t, class beta_t>
268 constexpr const struct starpu_codelet hemm =
269 internal::gen_cl_hemm<TA, TB, TC, alpha_t, beta_t>();
270
271 template <class TA, class TC, class alpha_t, class beta_t>
272 constexpr const struct starpu_codelet syrk =
273 internal::gen_cl_syrk<TA, TC, alpha_t, beta_t>();
274
275 template <class TA, class TC, class alpha_t, class beta_t>
276 constexpr const struct starpu_codelet herk =
277 internal::gen_cl_herk<TA, TC, alpha_t, beta_t>();
278
279 template <class TA, class TB, class TC, class alpha_t, class beta_t>
280 constexpr const struct starpu_codelet syr2k =
281 internal::gen_cl_syr2k<TA, TB, TC, alpha_t, beta_t>();
282
283 template <class TA, class TB, class TC, class alpha_t, class beta_t>
284 constexpr const struct starpu_codelet her2k =
285 internal::gen_cl_her2k<TA, TB, TC, alpha_t, beta_t>();
286
287 template <class TA, class TB, class alpha_t>
288 constexpr const struct starpu_codelet trmm =
289 internal::gen_cl_trmm<TA, TB, alpha_t>();
290
291 template <class TA, class TB, class alpha_t>
292 constexpr const struct starpu_codelet trsm =
293 internal::gen_cl_trsm<TA, TB, alpha_t>();
294
295 template <class uplo_t, class T>
296 constexpr const struct starpu_codelet potrf =
297 internal::gen_cl_potrf<uplo_t, T, true>();
298
299 template <class uplo_t, class T>
300 constexpr const struct starpu_codelet potrf_noinfo =
301 internal::gen_cl_potrf<uplo_t, T, false>();
302
303 } // namespace cl
304
305} // namespace starpu
306
307} // namespace tlapack
308
309#endif // TLAPACK_STARPU_CODELETS_HH
constexpr struct starpu_codelet codelet_init() noexcept
Return an empty starpu_codelet struct.
Definition MatrixEntry.hpp:34
StarPU functions for BLAS and LAPACK tasks.