<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
trmm.hpp
Go to the documentation of this file.
1
3//
4// Copyright (c) 2017-2021, University of Tennessee. All rights reserved.
5// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
6//
7// This file is part of <T>LAPACK.
8// <T>LAPACK is free software: you can redistribute it and/or modify it under
9// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
10
11#ifndef TLAPACK_BLAS_TRMM_HH
12#define TLAPACK_BLAS_TRMM_HH
13
15
16namespace tlapack {
17
65template <TLAPACK_MATRIX matrixA_t,
66 TLAPACK_MATRIX matrixB_t,
67 TLAPACK_SCALAR alpha_t,
68 class T = type_t<matrixB_t>,
69 disable_if_allow_optblas_t<pair<matrixA_t, T>,
70 pair<matrixB_t, T>,
71 pair<alpha_t, T> > = 0>
73 Uplo uplo,
74 Op trans,
75 Diag diag,
76 const alpha_t& alpha,
77 const matrixA_t& A,
78 matrixB_t& B)
79{
80 // data traits
81 using TA = type_t<matrixA_t>;
82 using TB = type_t<matrixB_t>;
83 using idx_t = size_type<matrixA_t>;
84
85 // constants
86 const idx_t m = nrows(B);
87 const idx_t n = ncols(B);
88
89 // check arguments
90 tlapack_check_false(side != Side::Left && side != Side::Right);
91 tlapack_check_false(uplo != Uplo::Lower && uplo != Uplo::Upper);
92 tlapack_check_false(trans != Op::NoTrans && trans != Op::Trans &&
93 trans != Op::ConjTrans);
94 tlapack_check_false(diag != Diag::NonUnit && diag != Diag::Unit);
95 tlapack_check_false(nrows(A) != ncols(A));
96 tlapack_check_false(nrows(A) != ((side == Side::Left) ? m : n));
97
98 if (side == Side::Left) {
99 if (trans == Op::NoTrans) {
101 if (uplo == Uplo::Upper) {
102 for (idx_t j = 0; j < n; ++j) {
103 for (idx_t k = 0; k < m; ++k) {
104 const scalar_t alphaBkj = alpha * B(k, j);
105 for (idx_t i = 0; i < k; ++i)
106 B(i, j) += A(i, k) * alphaBkj;
107 B(k, j) = (diag == Diag::NonUnit) ? A(k, k) * alphaBkj
108 : alphaBkj;
109 }
110 }
111 }
112 else { // uplo == Uplo::Lower
113 for (idx_t j = 0; j < n; ++j) {
114 for (idx_t k = m - 1; k != idx_t(-1); --k) {
115 const scalar_t alphaBkj = alpha * B(k, j);
116 B(k, j) = (diag == Diag::NonUnit) ? A(k, k) * alphaBkj
117 : alphaBkj;
118 for (idx_t i = k + 1; i < m; ++i)
119 B(i, j) += A(i, k) * alphaBkj;
120 }
121 }
122 }
123 }
124 else if (trans == Op::Trans) {
126 if (uplo == Uplo::Upper) {
127 for (idx_t j = 0; j < n; ++j) {
128 for (idx_t i = m - 1; i != idx_t(-1); --i) {
129 scalar_t sum = (diag == Diag::NonUnit)
130 ? A(i, i) * B(i, j)
131 : B(i, j);
132 for (idx_t k = 0; k < i; ++k)
133 sum += A(k, i) * B(k, j);
134 B(i, j) = alpha * sum;
135 }
136 }
137 }
138 else { // uplo == Uplo::Lower
139 for (idx_t j = 0; j < n; ++j) {
140 for (idx_t i = 0; i < m; ++i) {
141 scalar_t sum = (diag == Diag::NonUnit)
142 ? A(i, i) * B(i, j)
143 : B(i, j);
144 for (idx_t k = i + 1; k < m; ++k)
145 sum += A(k, i) * B(k, j);
146 B(i, j) = alpha * sum;
147 }
148 }
149 }
150 }
151 else { // trans == Op::ConjTrans
153 if (uplo == Uplo::Upper) {
154 for (idx_t j = 0; j < n; ++j) {
155 for (idx_t i = m - 1; i != idx_t(-1); --i) {
156 scalar_t sum = (diag == Diag::NonUnit)
157 ? conj(A(i, i)) * B(i, j)
158 : B(i, j);
159 for (idx_t k = 0; k < i; ++k)
160 sum += conj(A(k, i)) * B(k, j);
161 B(i, j) = alpha * sum;
162 }
163 }
164 }
165 else { // uplo == Uplo::Lower
166 for (idx_t j = 0; j < n; ++j) {
167 for (idx_t i = 0; i < m; ++i) {
168 scalar_t sum = (diag == Diag::NonUnit)
169 ? conj(A(i, i)) * B(i, j)
170 : B(i, j);
171 for (idx_t k = i + 1; k < m; ++k)
172 sum += conj(A(k, i)) * B(k, j);
173 B(i, j) = alpha * sum;
174 }
175 }
176 }
177 }
178 }
179 else { // side == Side::Right
181 if (trans == Op::NoTrans) {
182 if (uplo == Uplo::Upper) {
183 for (idx_t j = n - 1; j != idx_t(-1); --j) {
184 {
185 const scalar_t alphaAjj =
186 (diag == Diag::NonUnit) ? alpha * A(j, j) : alpha;
187 for (idx_t i = 0; i < m; ++i)
188 B(i, j) *= alphaAjj;
189 }
190 for (idx_t k = 0; k < j; ++k) {
191 const scalar_t alphaAkj = alpha * A(k, j);
192 for (idx_t i = 0; i < m; ++i)
193 B(i, j) += B(i, k) * alphaAkj;
194 }
195 }
196 }
197 else { // uplo == Uplo::Lower
198 for (idx_t j = 0; j < n; ++j) {
199 {
200 const scalar_t alphaAjj =
201 (diag == Diag::NonUnit) ? alpha * A(j, j) : alpha;
202 for (idx_t i = 0; i < m; ++i)
203 B(i, j) *= alphaAjj;
204 }
205 for (idx_t k = j + 1; k < n; ++k) {
206 const scalar_t alphaAkj = alpha * A(k, j);
207 for (idx_t i = 0; i < m; ++i)
208 B(i, j) += B(i, k) * alphaAkj;
209 }
210 }
211 }
212 }
213 else if (trans == Op::Trans) {
214 if (uplo == Uplo::Upper) {
215 for (idx_t k = 0; k < n; ++k) {
216 for (idx_t j = 0; j < k; ++j) {
217 const scalar_t alphaAjk = alpha * A(j, k);
218 for (idx_t i = 0; i < m; ++i)
219 B(i, j) += B(i, k) * alphaAjk;
220 }
221 {
222 const scalar_t alphaAkk =
223 (diag == Diag::NonUnit) ? alpha * A(k, k) : alpha;
224 for (idx_t i = 0; i < m; ++i)
225 B(i, k) *= alphaAkk;
226 }
227 }
228 }
229 else { // uplo == Uplo::Lower
230 for (idx_t k = n - 1; k != idx_t(-1); --k) {
231 for (idx_t j = k + 1; j < n; ++j) {
232 const scalar_t alphaAjk = alpha * A(j, k);
233 for (idx_t i = 0; i < m; ++i)
234 B(i, j) += B(i, k) * alphaAjk;
235 }
236 {
237 const scalar_t alphaAkk =
238 (diag == Diag::NonUnit) ? alpha * A(k, k) : alpha;
239 for (idx_t i = 0; i < m; ++i)
240 B(i, k) *= alphaAkk;
241 }
242 }
243 }
244 }
245 else { // trans == Op::ConjTrans
246 if (uplo == Uplo::Upper) {
247 for (idx_t k = 0; k < n; ++k) {
248 for (idx_t j = 0; j < k; ++j) {
249 const scalar_t alphaAjk = alpha * conj(A(j, k));
250 for (idx_t i = 0; i < m; ++i)
251 B(i, j) += B(i, k) * alphaAjk;
252 }
253 {
254 const scalar_t alphaAkk = (diag == Diag::NonUnit)
255 ? alpha * conj(A(k, k))
256 : alpha;
257 for (idx_t i = 0; i < m; ++i)
258 B(i, k) *= alphaAkk;
259 }
260 }
261 }
262 else { // uplo == Uplo::Lower
263 for (idx_t k = n - 1; k != idx_t(-1); --k) {
264 for (idx_t j = k + 1; j < n; ++j) {
265 const scalar_t alphaAjk = alpha * conj(A(j, k));
266 for (idx_t i = 0; i < m; ++i)
267 B(i, j) += B(i, k) * alphaAjk;
268 }
269 {
270 const scalar_t alphaAkk = (diag == Diag::NonUnit)
271 ? alpha * conj(A(k, k))
272 : alpha;
273 for (idx_t i = 0; i < m; ++i)
274 B(i, k) *= alphaAkk;
275 }
276 }
277 }
278 }
279 }
280}
281
282#ifdef TLAPACK_USE_LAPACKPP
283
300template <TLAPACK_LEGACY_MATRIX matrixA_t,
301 TLAPACK_LEGACY_MATRIX matrixB_t,
302 TLAPACK_SCALAR alpha_t,
303 class T = type_t<matrixB_t>,
304 enable_if_allow_optblas_t<pair<matrixA_t, T>,
305 pair<matrixB_t, T>,
306 pair<alpha_t, T> > = 0>
307void trmm(Side side,
308 Uplo uplo,
309 Op trans,
310 Diag diag,
311 const alpha_t alpha,
312 const matrixA_t& A,
313 matrixB_t& B)
314{
315 // Legacy objects
316 auto A_ = legacy_matrix(A);
317 auto B_ = legacy_matrix(B);
318
319 // Constants to forward
320 constexpr Layout L = layout<matrixB_t>;
321 const auto& m = B_.m;
322 const auto& n = B_.n;
323
324 // Warnings for NaNs and Infs
325 if (alpha == alpha_t(0))
327 -5, "Infs and NaNs in A or B will not propagate to B on output");
328
329 return ::blas::trmm((::blas::Layout)L, (::blas::Side)side,
330 (::blas::Uplo)uplo, (::blas::Op)trans,
331 (::blas::Diag)diag, m, n, alpha, A_.ptr, A_.ldim,
332 B_.ptr, B_.ldim);
333}
334
335#endif
336
337} // namespace tlapack
338
339#endif // #ifndef TLAPACK_BLAS_TRMM_HH
Diag
Definition types.hpp:192
Side
Definition types.hpp:266
Op
Definition types.hpp:222
Uplo
Definition types.hpp:45
constexpr T conj(const T &x) noexcept
Extends std::conj() to real datatypes.
Definition utils.hpp:100
#define TLAPACK_SCALAR
Macro for tlapack::concepts::Scalar compatible with C++17.
Definition concepts.hpp:915
#define TLAPACK_LEGACY_MATRIX
Macro for tlapack::concepts::LegacyMatrix compatible with C++17.
Definition concepts.hpp:951
#define TLAPACK_MATRIX
Macro for tlapack::concepts::Matrix compatible with C++17.
Definition concepts.hpp:896
constexpr auto diag(T &A, int diagIdx=0) noexcept
Get the Diagonal of an Eigen Matrix.
Definition eigen.hpp:576
void trmm(Side side, Uplo uplo, Op trans, Diag diag, const alpha_t &alpha, const matrixA_t &A, matrixB_t &B)
Triangular matrix-matrix multiply: B := alpha*op(A)*B or B := alpha*B*op(A), where op(A) is A, A^T or A^H.
Definition trmm.hpp:72
#define tlapack_check_false(cond)
Throw an error if cond is true.
Definition exceptionHandling.hpp:113
#define tlapack_warning(info, detailedInfo)
Warning handler.
Definition exceptionHandling.hpp:156
Concept for types that represent tlapack::Diag.
Concept for types that represent tlapack::Op.
Concept for types that represent tlapack::Side.
Concept for types that represent tlapack::Uplo.
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113