<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
trsm.hpp
Go to the documentation of this file.
1
3//
4// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
5//
6// This file is part of <T>LAPACK.
7// <T>LAPACK is free software: you can redistribute it and/or modify it under
8// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
9
10#ifndef TLAPACK_STARPU_TRSM_HH
11#define TLAPACK_STARPU_TRSM_HH
12
17
18namespace tlapack {
19
// Tiled triangular solve for StarPU tile matrices.
//
// Walks the diagonal tiles of A and, for each one, submits one
// starpu::insert_task_trsm task per tile of the matching tile row (Side::Left)
// or tile column (Side::Right) of B. After each diagonal step a trailing
// update is applied to the tiles of B that later steps will still visit.
//
// NOTE(review): this listing lacks the lines that were hyperlinks in the
// rendered page: the function name with its first parameter (original line
// 22), the B parameter (original line 28), and the head of every trailing
// update call (original lines 67, 81, 102, 117, 135, 150, 167, 181).
// Presumably these are `void trsm(Side side,`, `starpu::Matrix<TB>& B)` and
// `gemm(<opA>, <opB>, -one,` respectively -- confirm against the repository.
//
// Parameters visible here:
//   uplo, trans, diag  forwarded unchanged, together with side, to every
//                      insert_task_trsm call; side, trans and uplo also pick
//                      one of the eight traversal orders below.
//   alpha              scalar used only in the first step of the sweep
//                      (ix == 0 or iy == 0); every later step uses `one`.
//   A                  must be square with a square tile grid; its order must
//                      equal m for Side::Left and n for Side::Right, and its
//                      tile grid must equal nx resp. ny of B (checked below).
21template <class TA, class TB, class alpha_t>
23 Uplo uplo,
24 Op trans,
25 Diag diag,
26 const alpha_t& alpha,
27 const starpu::Matrix<TA>& A,
29{
30 using starpu::idx_t;
31
32 // constants
   // `one` is the scalar passed instead of alpha after the first step.
   // nx and ny bound the tile indices used with B.tile() / B.get_tiles()
   // below (presumably the tile counts along rows and columns of B).
33 const TB one(1);
34 const idx_t m = B.nrows();
35 const idx_t n = B.ncols();
36 const idx_t nx = B.get_nx();
37 const idx_t ny = B.get_ny();
38
39 // quick return
   // Note: this returns before the argument checks, so invalid arguments are
   // not reported when B has no rows or no columns.
40 if (m == 0 || n == 0) return;
41
42 // check arguments
   // Enum arguments must hold one of the listed values; A must be square,
   // conform with B on the chosen side, and have a matching square tile grid.
43 tlapack_check_false(side != Side::Left && side != Side::Right);
44 tlapack_check_false(uplo != Uplo::Lower && uplo != Uplo::Upper);
45 tlapack_check_false(trans != Op::NoTrans && trans != Op::Trans &&
46 trans != Op::ConjTrans);
47 tlapack_check_false(diag != Diag::NonUnit && diag != Diag::Unit);
48 tlapack_check(A.nrows() == A.ncols());
49 tlapack_check(A.nrows() == (side == Side::Left ? m : n));
50 tlapack_check(A.get_nx() == A.get_ny());
51 tlapack_check(A.get_nx() == (side == Side::Left ? nx : ny));
52
53 // Remove const type from A
   // A_ is only used to fetch the diagonal tiles handed to insert_task_trsm;
   // presumably tile() is not callable on a const Matrix -- TODO confirm.
54 auto& A_ = const_cast<starpu::Matrix<TA>&>(A);
55
56 if (side == Side::Left) {
57 if (trans == Op::NoTrans) {
58 if (uplo == Uplo::Upper) {
   // Sweep the diagonal tiles from the last (nx - 1) to the first (0).
59 for (idx_t ix = 0; ix < nx; ++ix) {
60 for (idx_t iy = 0; iy < ny; ++iy) {
61 starpu::insert_task_trsm<TA, TB>(
62 side, uplo, trans, diag, ((ix == 0) ? alpha : one),
63 A_.tile(nx - ix - 1, nx - ix - 1),
64 B.tile(nx - ix - 1, iy));
65 }
   // C is the submatrix of B made of the tile rows above the current one.
   // The call head of the update that follows (original line 67) is
   // missing from this listing; its visible operands are the tile column
   // of A above the diagonal tile and the current tile row of B.
66 auto C = B.get_tiles(0, 0, nx - ix - 1, ny);
68 A.get_const_tiles(0, nx - ix - 1, nx - ix - 1, 1),
69 B.get_const_tiles(nx - ix - 1, 0, 1, ny),
70 ((ix == 0) ? alpha : one), C);
71 }
72 }
73 else { // uplo == Uplo::Lower
   // Sweep the diagonal tiles from the first to the last; C covers the
   // tile rows below the current one.
74 for (idx_t ix = 0; ix < nx; ++ix) {
75 for (idx_t iy = 0; iy < ny; ++iy) {
76 starpu::insert_task_trsm<TA, TB>(
77 side, uplo, trans, diag, ((ix == 0) ? alpha : one),
78 A_.tile(ix, ix), B.tile(ix, iy));
79 }
80 auto C = B.get_tiles(ix + 1, 0, nx - ix - 1, ny);
82 A.get_const_tiles(ix + 1, ix, nx - ix - 1, 1),
83 B.get_const_tiles(ix, 0, 1, ny),
84 ((ix == 0) ? alpha : one), C);
85 }
86 }
87 }
88 else { // trans == Op::Trans or Op::ConjTrans
89 if (uplo == Uplo::Upper) {
   // Sweep from the first diagonal tile to the last; the update reads the
   // tile row of A to the right of the diagonal tile and C covers the
   // tile rows of B below the current one.
90 for (idx_t ix = 0; ix < nx; ++ix) {
91 // auto Aii = A.get_const_tiles(ix, ix, 1, 1);
92 for (idx_t iy = 0; iy < ny; ++iy) {
93 starpu::insert_task_trsm<TA, TB>(
94 side, uplo, trans, diag, ((ix == 0) ? alpha : one),
95 A_.tile(ix, ix), B.tile(ix, iy));
96 // auto B_ = B.get_tiles(ix, iy, 1, 1);
97 // trsm<starpu::Matrix<const TA>>(
98 // side, uplo, trans, diag, ((ix == 0) ? alpha :
99 // ((ix == 0) ? alpha : one)), Aii, B_);
100 }
101 auto C = B.get_tiles(ix + 1, 0, nx - ix - 1, ny);
103 A.get_const_tiles(ix, ix + 1, 1, nx - ix - 1),
104 B.get_const_tiles(ix, 0, 1, ny),
105 ((ix == 0) ? alpha : one), C);
106 }
107 }
108 else { // uplo == Uplo::Lower
   // Sweep from the last diagonal tile to the first; the update reads the
   // tile row of A to the left of the diagonal tile and C covers the tile
   // rows of B above the current one.
109 for (idx_t ix = 0; ix < nx; ++ix) {
110 for (idx_t iy = 0; iy < ny; ++iy) {
111 starpu::insert_task_trsm<TA, TB>(
112 side, uplo, trans, diag, ((ix == 0) ? alpha : one),
113 A_.tile(nx - ix - 1, nx - ix - 1),
114 B.tile(nx - ix - 1, iy));
115 }
116 auto C = B.get_tiles(0, 0, nx - ix - 1, ny);
118 A.get_const_tiles(nx - ix - 1, 0, 1, nx - ix - 1),
119 B.get_const_tiles(nx - ix - 1, 0, 1, ny),
120 ((ix == 0) ? alpha : one), C);
121 }
122 }
123 }
124 }
125 else { // side == Side::Right
   // Same scheme as above with the roles of rows and columns swapped: the
   // outer loop walks tile columns of B (iy), and in the update the tiles
   // of B come first and the tiles of A second.
126 if (trans == Op::NoTrans) {
127 if (uplo == Uplo::Upper) {
   // First tile column to last; C covers the tile columns to the right.
128 for (idx_t iy = 0; iy < ny; ++iy) {
129 for (idx_t ix = 0; ix < nx; ++ix) {
130 starpu::insert_task_trsm<TA, TB>(
131 side, uplo, trans, diag, ((iy == 0) ? alpha : one),
132 A_.tile(iy, iy), B.tile(ix, iy));
133 }
134 auto C = B.get_tiles(0, iy + 1, nx, ny - iy - 1);
136 B.get_const_tiles(0, iy, nx, 1),
137 A.get_const_tiles(iy, iy + 1, 1, ny - iy - 1),
138 ((iy == 0) ? alpha : one), C);
139 }
140 }
141 else { // uplo == Uplo::Lower
   // Last tile column to first; C covers the tile columns to the left.
142 for (idx_t iy = 0; iy < ny; ++iy) {
143 for (idx_t ix = 0; ix < nx; ++ix) {
144 starpu::insert_task_trsm<TA, TB>(
145 side, uplo, trans, diag, ((iy == 0) ? alpha : one),
146 A_.tile(ny - iy - 1, ny - iy - 1),
147 B.tile(ix, ny - iy - 1));
148 }
149 auto C = B.get_tiles(0, 0, nx, ny - iy - 1);
151 B.get_const_tiles(0, ny - iy - 1, nx, 1),
152 A.get_const_tiles(ny - iy - 1, 0, 1, ny - iy - 1),
153 ((iy == 0) ? alpha : one), C);
154 }
155 }
156 }
157 else { // trans == Op::Trans or Op::ConjTrans
158 if (uplo == Uplo::Upper) {
   // Last tile column to first; the update reads the tile column of A
   // above the diagonal tile and C covers the tile columns to the left.
159 for (idx_t iy = 0; iy < ny; ++iy) {
160 for (idx_t ix = 0; ix < nx; ++ix) {
161 starpu::insert_task_trsm<TA, TB>(
162 side, uplo, trans, diag, ((iy == 0) ? alpha : one),
163 A_.tile(ny - iy - 1, ny - iy - 1),
164 B.tile(ix, ny - iy - 1));
165 }
166 auto C = B.get_tiles(0, 0, nx, ny - iy - 1);
168 B.get_const_tiles(0, ny - iy - 1, nx, 1),
169 A.get_const_tiles(0, ny - iy - 1, ny - iy - 1, 1),
170 ((iy == 0) ? alpha : one), C);
171 }
172 } // uplo == Uplo::Lower
173 else {
   // First tile column to last; the update reads the tile column of A
   // below the diagonal tile and C covers the tile columns to the right.
174 for (idx_t iy = 0; iy < ny; ++iy) {
175 for (idx_t ix = 0; ix < nx; ++ix) {
176 starpu::insert_task_trsm<TA, TB>(
177 side, uplo, trans, diag, ((iy == 0) ? alpha : one),
178 A_.tile(iy, iy), B.tile(ix, iy));
179 }
180 auto C = B.get_tiles(0, iy + 1, nx, ny - iy - 1);
182 B.get_const_tiles(0, iy, nx, 1),
183 A.get_const_tiles(iy + 1, iy, ny - iy - 1, 1),
184 ((iy == 0) ? alpha : one), C);
185 }
186 }
187 }
188 }
189}
190
191} // namespace tlapack
192
193#endif // TLAPACK_STARPU_TRSM_HH
Diag
Definition types.hpp:192
Side
Definition types.hpp:266
Op
Definition types.hpp:222
Uplo
Definition types.hpp:45
constexpr internal::NoTranspose NO_TRANS
no transpose
Definition types.hpp:255
Class for representing a matrix in StarPU that is split into tiles.
Definition Matrix.hpp:133
Matrix< T > get_tiles(idx_t ix, idx_t iy, idx_t nx, idx_t ny) noexcept
Create a submatrix from a list of tiles.
Definition Matrix.hpp:298
constexpr auto diag(T &A, int diagIdx=0) noexcept
Get the Diagonal of an Eigen Matrix.
Definition eigen.hpp:576
void gemm(Op transA, Op transB, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
General matrix-matrix multiply: C := alpha * op(A) * op(B) + beta * C.
Definition gemm.hpp:61
void trsm(Side side, Uplo uplo, Op trans, Diag diag, const alpha_t &alpha, const matrixA_t &A, matrixB_t &B)
Solve the triangular matrix-matrix equation op(A) X = alpha B or X op(A) = alpha B, where A is triangular.
Definition trsm.hpp:76
#define tlapack_check(cond)
Throw an error if cond is false.
Definition exceptionHandling.hpp:98
#define tlapack_check_false(cond)
Throw an error if cond is true.
Definition exceptionHandling.hpp:113
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
Task insertion functions.