<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
trsm.hpp
Go to the documentation of this file.
1
3//
4// Copyright (c) 2017-2021, University of Tennessee. All rights reserved.
5// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
6//
7// This file is part of <T>LAPACK.
8// <T>LAPACK is free software: you can redistribute it and/or modify it under
9// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
10
11#ifndef TLAPACK_BLAS_TRSM_HH
12#define TLAPACK_BLAS_TRSM_HH

#include "tlapack/base/utils.hpp"

16namespace tlapack {
17
69template <TLAPACK_MATRIX matrixA_t,
70 TLAPACK_MATRIX matrixB_t,
71 TLAPACK_SCALAR alpha_t,
72 class T = type_t<matrixB_t>,
73 disable_if_allow_optblas_t<pair<matrixA_t, T>,
74 pair<matrixB_t, T>,
75 pair<alpha_t, T> > = 0>
77 Uplo uplo,
78 Op trans,
79 Diag diag,
80 const alpha_t& alpha,
81 const matrixA_t& A,
82 matrixB_t& B)
83{
84 // data traits
85 using idx_t = size_type<matrixA_t>;
86 using TB = type_t<matrixB_t>;
87
88 // constants
89 const idx_t m = nrows(B);
90 const idx_t n = ncols(B);
91
92 // check arguments
93 tlapack_check_false(side != Side::Left && side != Side::Right);
94 tlapack_check_false(uplo != Uplo::Lower && uplo != Uplo::Upper);
95 tlapack_check_false(trans != Op::NoTrans && trans != Op::Trans &&
96 trans != Op::ConjTrans);
97 tlapack_check_false(diag != Diag::NonUnit && diag != Diag::Unit);
98 tlapack_check_false(nrows(A) != ncols(A));
99 tlapack_check_false(nrows(A) != ((side == Side::Left) ? m : n));
100
101 if (side == Side::Left) {
103 if (trans == Op::NoTrans) {
104 if (uplo == Uplo::Upper) {
105 for (idx_t j = 0; j < n; ++j) {
106 for (idx_t i = 0; i < m; ++i)
107 B(i, j) *= alpha;
108 for (idx_t k = m - 1; k != idx_t(-1); --k) {
109 if (diag == Diag::NonUnit) B(k, j) /= A(k, k);
110 for (idx_t i = 0; i < k; ++i)
111 B(i, j) -= A(i, k) * B(k, j);
112 }
113 }
114 }
115 else { // uplo == Uplo::Lower
116 for (idx_t j = 0; j < n; ++j) {
117 for (idx_t i = 0; i < m; ++i)
118 B(i, j) *= alpha;
119 for (idx_t k = 0; k < m; ++k) {
120 if (diag == Diag::NonUnit) B(k, j) /= A(k, k);
121 for (idx_t i = k + 1; i < m; ++i)
122 B(i, j) -= A(i, k) * B(k, j);
123 }
124 }
125 }
126 }
127 else if (trans == Op::Trans) {
128 if (uplo == Uplo::Upper) {
129 for (idx_t j = 0; j < n; ++j) {
130 for (idx_t i = 0; i < m; ++i) {
131 scalar_t sum = alpha * B(i, j);
132 for (idx_t k = 0; k < i; ++k)
133 sum -= A(k, i) * B(k, j);
134 B(i, j) = (diag == Diag::NonUnit) ? sum / A(i, i) : sum;
135 }
136 }
137 }
138 else { // uplo == Uplo::Lower
139 for (idx_t j = 0; j < n; ++j) {
140 for (idx_t i = m - 1; i != idx_t(-1); --i) {
141 scalar_t sum = alpha * B(i, j);
142 for (idx_t k = i + 1; k < m; ++k)
143 sum -= A(k, i) * B(k, j);
144 B(i, j) = (diag == Diag::NonUnit) ? sum / A(i, i) : sum;
145 }
146 }
147 }
148 }
149 else { // trans == Op::ConjTrans
150 if (uplo == Uplo::Upper) {
151 for (idx_t j = 0; j < n; ++j) {
152 for (idx_t i = 0; i < m; ++i) {
153 scalar_t sum = alpha * B(i, j);
154 for (idx_t k = 0; k < i; ++k)
155 sum -= conj(A(k, i)) * B(k, j);
156 B(i, j) =
157 (diag == Diag::NonUnit) ? sum / conj(A(i, i)) : sum;
158 }
159 }
160 }
161 else { // uplo == Uplo::Lower
162 for (idx_t j = 0; j < n; ++j) {
163 for (idx_t i = m - 1; i != idx_t(-1); --i) {
164 scalar_t sum = alpha * B(i, j);
165 for (idx_t k = i + 1; k < m; ++k)
166 sum -= conj(A(k, i)) * B(k, j);
167 B(i, j) =
168 (diag == Diag::NonUnit) ? sum / conj(A(i, i)) : sum;
169 }
170 }
171 }
172 }
173 }
174 else { // side == Side::Right
175 if (trans == Op::NoTrans) {
176 if (uplo == Uplo::Upper) {
177 for (idx_t j = 0; j < n; ++j) {
178 for (idx_t i = 0; i < m; ++i)
179 B(i, j) *= alpha;
180 for (idx_t k = 0; k < j; ++k) {
181 for (idx_t i = 0; i < m; ++i)
182 B(i, j) -= B(i, k) * A(k, j);
183 }
184 if (diag == Diag::NonUnit) {
185 for (idx_t i = 0; i < m; ++i)
186 B(i, j) /= A(j, j);
187 }
188 }
189 }
190 else { // uplo == Uplo::Lower
191 for (idx_t j = n - 1; j != idx_t(-1); --j) {
192 for (idx_t i = 0; i < m; ++i)
193 B(i, j) *= alpha;
194 for (idx_t k = j + 1; k < n; ++k) {
195 for (idx_t i = 0; i < m; ++i)
196 B(i, j) -= B(i, k) * A(k, j);
197 }
198 if (diag == Diag::NonUnit) {
199 for (idx_t i = 0; i < m; ++i)
200 B(i, j) /= A(j, j);
201 }
202 }
203 }
204 }
205 else if (trans == Op::Trans) {
206 if (uplo == Uplo::Upper) {
207 for (idx_t k = n - 1; k != idx_t(-1); --k) {
208 if (diag == Diag::NonUnit) {
209 for (idx_t i = 0; i < m; ++i)
210 B(i, k) /= A(k, k);
211 }
212 for (idx_t j = 0; j < k; ++j) {
213 for (idx_t i = 0; i < m; ++i)
214 B(i, j) -= B(i, k) * A(j, k);
215 }
216 for (idx_t i = 0; i < m; ++i)
217 B(i, k) *= alpha;
218 }
219 }
220 else { // uplo == Uplo::Lower
221 for (idx_t k = 0; k < n; ++k) {
222 if (diag == Diag::NonUnit) {
223 for (idx_t i = 0; i < m; ++i)
224 B(i, k) /= A(k, k);
225 }
226 for (idx_t j = k + 1; j < n; ++j) {
227 for (idx_t i = 0; i < m; ++i)
228 B(i, j) -= B(i, k) * A(j, k);
229 }
230 for (idx_t i = 0; i < m; ++i)
231 B(i, k) *= alpha;
232 }
233 }
234 }
235 else { // trans == Op::ConjTrans
236 if (uplo == Uplo::Upper) {
237 for (idx_t k = n - 1; k != idx_t(-1); --k) {
238 if (diag == Diag::NonUnit) {
239 for (idx_t i = 0; i < m; ++i)
240 B(i, k) /= conj(A(k, k));
241 }
242 for (idx_t j = 0; j < k; ++j) {
243 for (idx_t i = 0; i < m; ++i)
244 B(i, j) -= B(i, k) * conj(A(j, k));
245 }
246 for (idx_t i = 0; i < m; ++i)
247 B(i, k) *= alpha;
248 }
249 }
250 else { // uplo == Uplo::Lower
251 for (idx_t k = 0; k < n; ++k) {
252 if (diag == Diag::NonUnit) {
253 for (idx_t i = 0; i < m; ++i)
254 B(i, k) /= conj(A(k, k));
255 }
256 for (idx_t j = k + 1; j < n; ++j) {
257 for (idx_t i = 0; i < m; ++i)
258 B(i, j) -= B(i, k) * conj(A(j, k));
259 }
260 for (idx_t i = 0; i < m; ++i)
261 B(i, k) *= alpha;
262 }
263 }
264 }
265 }
266}
267
#ifdef TLAPACK_USE_LAPACKPP

/**
 * @brief Solve op(A) X = alpha B or X op(A) = alpha B by forwarding to
 *        ::blas::trsm.
 *
 * Selected (via enable_if_allow_optblas_t) for legacy matrix types whose
 * scalar types allow an optimized BLAS call. A and B are converted to legacy
 * views (pointer, leading dimension, sizes), and the enums are cast to their
 * ::blas counterparts. The layout is taken from matrixB_t.
 *
 * @param[in] side   Side::Left or Side::Right.
 * @param[in] uplo   Uplo::Upper or Uplo::Lower.
 * @param[in] trans  Op::NoTrans, Op::Trans or Op::ConjTrans.
 * @param[in] diag   Diag::NonUnit or Diag::Unit.
 * @param[in] alpha  Scalar multiplying B.
 * @param[in] A      Triangular matrix.
 * @param[in,out] B  On entry the right-hand side; on exit the solution X.
 *
 * When alpha == 0, a warning (code -5) is emitted: on this path Infs and NaNs
 * in A or B do not propagate to the output, unlike the generic overload.
 */
template <TLAPACK_LEGACY_MATRIX matrixA_t,
          TLAPACK_LEGACY_MATRIX matrixB_t,
          TLAPACK_SCALAR alpha_t,
          class T = type_t<matrixB_t>,
          enable_if_allow_optblas_t<pair<matrixA_t, T>,
                                    pair<matrixB_t, T>,
                                    pair<alpha_t, T> > = 0>
void trsm(Side side,
          Uplo uplo,
          Op trans,
          Diag diag,
          const alpha_t alpha,
          const matrixA_t& A,
          matrixB_t& B)
{
    // Legacy objects
    auto A_ = legacy_matrix(A);
    auto B_ = legacy_matrix(B);

    // Constants to forward
    constexpr Layout L = layout<matrixB_t>;
    const auto& m = B_.m;
    const auto& n = B_.n;

    // Warnings for NaNs and Infs
    if (alpha == alpha_t(0))
        tlapack_warning(
            -5, "Infs and NaNs in A or B will not propagate to B on output");

    return ::blas::trsm((::blas::Layout)L, (::blas::Side)side,
                        (::blas::Uplo)uplo, (::blas::Op)trans,
                        (::blas::Diag)diag, m, n, alpha, A_.ptr, A_.ldim,
                        B_.ptr, B_.ldim);
}

#endif
306
307} // namespace tlapack
308
309#endif // #ifndef TLAPACK_BLAS_TRSM_HH
Diag
Definition types.hpp:192
Side
Definition types.hpp:266
Op
Definition types.hpp:222
Uplo
Definition types.hpp:45
constexpr T conj(const T &x) noexcept
Extends std::conj() to real datatypes.
Definition utils.hpp:100
#define TLAPACK_SCALAR
Macro for tlapack::concepts::Scalar compatible with C++17.
Definition concepts.hpp:915
#define TLAPACK_LEGACY_MATRIX
Macro for tlapack::concepts::LegacyMatrix compatible with C++17.
Definition concepts.hpp:951
#define TLAPACK_MATRIX
Macro for tlapack::concepts::Matrix compatible with C++17.
Definition concepts.hpp:896
constexpr auto diag(T &A, int diagIdx=0) noexcept
Get the Diagonal of an Eigen Matrix.
Definition eigen.hpp:576
void trsm(Side side, Uplo uplo, Op trans, Diag diag, const alpha_t &alpha, const matrixA_t &A, matrixB_t &B)
Solve the triangular matrix equation op(A) X = alpha B or X op(A) = alpha B.
Definition trsm.hpp:76
#define tlapack_check_false(cond)
Throw an error if cond is true.
Definition exceptionHandling.hpp:113
#define tlapack_warning(info, detailedInfo)
Warning handler.
Definition exceptionHandling.hpp:156
Concept for types that represent tlapack::Diag.
Concept for types that represent tlapack::Op.
Concept for types that represent tlapack::Side.
Concept for types that represent tlapack::Uplo.
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113