<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
herk.hpp
Go to the documentation of this file.
1
3//
4// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
5//
6// This file is part of <T>LAPACK.
7// <T>LAPACK is free software: you can redistribute it and/or modify it under
8// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
9
10#ifndef TLAPACK_STARPU_HERK_HH
11#define TLAPACK_STARPU_HERK_HH
12
17
18namespace tlapack {
19
// Tiled Hermitian rank-k update for starpu::Matrix operands.
//
// For every diagonal tile C(ix, ix) the routine calls
// starpu::insert_task_herk once per tile of A that contributes to it, and then
// builds the tile views (Ai, Bi, Ci) for the off-diagonal part of C selected by
// `uplo`.
//
// NOTE(review): this listing is incomplete. Judging by the gaps in the embedded
// line numbers, original lines 22, 27, 69, 73, 91 and 95 are missing. `uplo`
// and `C` are used below but their declarations are not visible here, and the
// statements that consume Ai/Bi/Ci are not visible either. Confirm against the
// real herk.hpp before relying on this text.
21template <class TA, class TC, class alpha_t, class beta_t>
23 Op trans,
24 const alpha_t& alpha,
25 const starpu::Matrix<TA>& A,
26 const beta_t& beta,
28
29{
30 using starpu::idx_t;
31
32 // constants
// one: unit scaling factor; used in the quick-return test and in place of
//      beta when accumulating further updates into a diagonal tile.
// n:   order of C (checked against C.nrows() and C.ncols() below).
// k:   the other dimension of A; only used by the quick-return test.
// nx:  tile count of A matching n; C must have this count in both directions.
// ny:  tile count of A matching k.
// NOTE(review): presumably get_nx()/get_ny() count tiles along rows/columns,
// mirroring nrows()/ncols() - confirm against starpu::Matrix.
33 const real_type<TC> one(1);
34 const idx_t n = (trans == Op::NoTrans) ? A.nrows() : A.ncols();
35 const idx_t k = (trans == Op::NoTrans) ? A.ncols() : A.nrows();
36 const idx_t nx = (trans == Op::NoTrans) ? A.get_nx() : A.get_ny();
37 const idx_t ny = (trans == Op::NoTrans) ? A.get_ny() : A.get_nx();
38
39 // quick return: C is empty, or A has no inner dimension and beta equals one
// Note that these returns happen before the argument checks below.
40 if (n == 0) return;
41 if (k == 0 && beta == one) return;
42
43 // check arguments: each condition below describes an invalid input
// (tlapack_check_false throws an error if its condition is true).
44 tlapack_check_false(uplo != Uplo::Lower && uplo != Uplo::Upper &&
45 uplo != Uplo::General);
46 tlapack_check_false(trans != Op::NoTrans && trans != Op::ConjTrans);
47 tlapack_check_false(C.nrows() != n);
48 tlapack_check_false(C.ncols() != n);
49 tlapack_check_false(C.get_nx() != nx);
50 tlapack_check_false(C.get_ny() != nx);
51
52 // Remove const type from A (there is no B operand in this routine).
// A_ is the handle whose tiles are passed to insert_task_herk below.
// NOTE(review): presumably needed because tile() is not const - confirm.
53 auto& A_ = const_cast<starpu::Matrix<TA>&>(A);
54
55 if (trans == Op::NoTrans) {
56 for (idx_t ix = 0; ix < nx; ++ix) {
57 // Update diagonal tile of C: the first task applies beta, using
// tile (ix, 0) of A ...
58 starpu::insert_task_herk<TA, TC>(uplo, trans, alpha, A_.tile(ix, 0),
59 beta, C.tile(ix, ix));
// ... and the remaining tiles (ix, 1..ny-1) are accumulated with `one`
// in place of beta.
60 for (idx_t iy = 1; iy < ny; ++iy)
61 starpu::insert_task_herk<TA, TC>(
62 uplo, trans, alpha, A_.tile(ix, iy), one, C.tile(ix, ix));
63
64 // Update off-diagonal tiles of C
// Ai starts at tile (ix, 0) of A; Bi starts at tile (ix + 1, 0).
// NOTE(review): presumably the arguments of get_const_tiles/get_tiles
// are (first tile row, first tile column, tile rows, tile columns),
// making Ai tile row ix and Bi all tile rows after it - confirm.
65 auto Ai = A.get_const_tiles(ix, 0, 1, ny);
66 auto Bi = A.get_const_tiles(ix + 1, 0, nx - ix - 1, ny);
67 if (uplo == Uplo::Upper || uplo == Uplo::General) {
// Ci starts at tile (ix, ix + 1) of C.
68 auto Ci = C.get_tiles(ix, ix + 1, 1, nx - ix - 1);
70 }
71 if (uplo == Uplo::Lower || uplo == Uplo::General) {
// Ci starts at tile (ix + 1, ix) of C.
72 auto Ci = C.get_tiles(ix + 1, ix, nx - ix - 1, 1);
74 }
75 }
76 }
77 else { // trans == Op::ConjTrans
// Same scheme as above with the tile indices of A swapped.
78 for (idx_t ix = 0; ix < nx; ++ix) {
79 // Update diagonal tile of C: the first task applies beta, using
// tile (0, ix) of A ...
80 starpu::insert_task_herk<TA, TC>(uplo, trans, alpha, A_.tile(0, ix),
81 beta, C.tile(ix, ix));
// ... and the remaining tiles (1..ny-1, ix) are accumulated with `one`
// in place of beta.
82 for (idx_t iy = 1; iy < ny; ++iy)
83 starpu::insert_task_herk<TA, TC>(
84 uplo, trans, alpha, A_.tile(iy, ix), one, C.tile(ix, ix));
85
86 // Update off-diagonal tiles of C
// Ai starts at tile (0, ix) of A; Bi starts at tile (0, ix + 1).
87 auto Ai = A.get_const_tiles(0, ix, ny, 1);
88 auto Bi = A.get_const_tiles(0, ix + 1, ny, nx - ix - 1);
89 if (uplo == Uplo::Upper || uplo == Uplo::General) {
// Ci starts at tile (ix, ix + 1) of C.
90 auto Ci = C.get_tiles(ix, ix + 1, 1, nx - ix - 1);
92 }
93 if (uplo == Uplo::Lower || uplo == Uplo::General) {
// Ci starts at tile (ix + 1, ix) of C.
94 auto Ci = C.get_tiles(ix + 1, ix, nx - ix - 1, 1);
96 }
97 }
98 }
99}
100} // namespace tlapack
101
102#endif // TLAPACK_STARPU_HERK_HH
Op
Definition types.hpp:222
constexpr internal::ConjTranspose CONJ_TRANS
conjugate transpose
Definition types.hpp:259
Uplo
Definition types.hpp:45
constexpr internal::NoTranspose NO_TRANS
no transpose
Definition types.hpp:255
Class for representing a matrix in StarPU that is split into tiles.
Definition Matrix.hpp:133
void herk(Uplo uplo, Op trans, const alpha_t &alpha, const matrixA_t &A, const beta_t &beta, matrixC_t &C)
Hermitian rank-k update: $C := \alpha A A^H + \beta C$ or $C := \alpha A^H A + \beta C$.
Definition herk.hpp:68
void gemm(Op transA, Op transB, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
General matrix-matrix multiply: $C := \alpha\,\mathrm{op}(A)\,\mathrm{op}(B) + \beta C$.
Definition gemm.hpp:61
#define tlapack_check_false(cond)
Throw an error if cond is true.
Definition exceptionHandling.hpp:113
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
Task insertion functions.