<T>LAPACK 0.1.2
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
hemm2.hpp
Go to the documentation of this file.
1
3//
4// Copyright (c) 2017-2021, University of Tennessee. All rights reserved.
5// Copyright (c) 2025, University of Colorado Denver. All rights reserved.
6//
7// This file is part of <T>LAPACK.
8// <T>LAPACK is free software: you can redistribute it and/or modify it under
9// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
10
11#ifndef TLAPACK_BLAS_HEMM_2_HH
12#define TLAPACK_BLAS_HEMM_2_HH
13
15
16namespace tlapack {
17
68template <TLAPACK_MATRIX matrixA_t,
69 TLAPACK_MATRIX matrixB_t,
70 TLAPACK_MATRIX matrixC_t,
71 TLAPACK_SCALAR alpha_t,
72 TLAPACK_SCALAR beta_t,
73 class T = type_t<matrixC_t>>
74
76 Uplo uplo,
77 Op transB,
78 const alpha_t& alpha,
79 const matrixA_t& A,
80 const matrixB_t& B,
81 const beta_t& beta,
82 matrixC_t& C)
83{
84 // data traits
85 using TA = type_t<matrixA_t>;
86 using TB = type_t<matrixB_t>;
87 using idx_t = size_type<matrixB_t>;
88
89 // constants
90 const idx_t m = nrows(B);
91 const idx_t n = ncols(B);
92
93 // // check arguments
96 uplo != GENERAL);
97 tlapack_check_false(nrows(A) != ncols(A));
98 if ((side == LEFT_SIDE && transB == NO_TRANS) ||
99 (side == RIGHT_SIDE && (transB == TRANSPOSE || transB == CONJ_TRANS))) {
100 tlapack_check_false(ncols(A) != m);
101 }
102 else {
103 tlapack_check_false(nrows(A) != n);
104 }
105 if (transB == NO_TRANS) {
106 tlapack_check_false(nrows(C) != m);
107 tlapack_check_false(ncols(C) != n);
108 }
109 else {
110 tlapack_check_false(nrows(C) != n);
111 tlapack_check_false(ncols(C) != m);
112 }
113
114 if (side == LEFT_SIDE) {
115 if (transB == NO_TRANS) {
116 if (uplo == UPPER_TRIANGLE) {
117 // or uplo == Uplo::General
118 for (idx_t j = 0; j < n; ++j) {
119 for (idx_t i = 0; i < m; ++i) {
121 alpha * B(i, j);
123
124 for (idx_t k = 0; k < i; ++k) {
125 C(k, j) += A(k, i) * alphaTimesBij;
126 sum += conj(A(k, i)) * B(k, j);
127 }
128 C(i, j) = beta * C(i, j) +
129 real(A(i, i)) * alphaTimesBij + alpha * sum;
130 }
131 }
132 }
133 else {
134 // uplo == LOWER_TRIANGLE
135 for (idx_t j = 0; j < n; ++j) {
136 for (idx_t i = m - 1; i != idx_t(-1); --i) {
138 alpha * B(i, j);
140
141 for (idx_t k = i + 1; k < m; ++k) {
142 C(k, j) += A(k, i) * alphaTimesBij;
143 sum += conj(A(k, i)) * B(k, j);
144 }
145 C(i, j) = beta * C(i, j) +
146 real(A(i, i)) * alphaTimesBij + alpha * sum;
147 }
148 }
149 }
150 }
151 else if (transB == TRANSPOSE) {
152 // Trans
153 if (uplo == UPPER_TRIANGLE) {
154 // or uplo == Uplo::General
155 for (idx_t j = 0; j < n; j++) {
156 for (idx_t k = 0; k < m; k++) {
157 T sum(0);
158 for (idx_t i = 0; i < j; i++) {
159 sum += conj(A(i, j)) * B(k, i);
160 }
161 sum += real(A(j, j)) * B(k, j);
162 for (idx_t i = j + 1; i < n; i++) {
163 sum += A(j, i) * B(k, i);
164 }
165 C(j, k) = alpha * sum + beta * C(j, k);
166 }
167 }
168 }
169 else {
170 // uplo == LOWER_TRIANGLE
171 for (idx_t j = 0; j < n; j++) {
172 for (idx_t k = 0; k < m; k++) {
173 T sum(0);
174 for (idx_t i = 0; i < j; i++) {
175 sum += A(j, i) * B(k, i);
176 }
177 sum += real(A(j, j)) * B(k, j);
178 for (idx_t i = j + 1; i < n; i++) {
179 sum += conj(A(i, j)) * B(k, i);
180 }
181 C(j, k) = alpha * sum + beta * C(j, k);
182 }
183 }
184 }
185 }
186 else {
187 // TransConj
188 if (uplo == UPPER_TRIANGLE) {
189 // or uplo == Uplo::General
190 for (idx_t j = 0; j < n; j++) {
191 for (idx_t k = 0; k < m; k++) {
192 T sum(0);
193 for (idx_t i = 0; i < j; i++) {
194 sum += conj(A(i, j)) * conj(B(k, i));
195 }
196 sum += real(A(j, j)) * conj(B(k, j));
197 for (idx_t i = j + 1; i < n; i++) {
198 sum += A(j, i) * conj(B(k, i));
199 }
200 C(j, k) = alpha * sum + beta * C(j, k);
201 }
202 }
203 }
204 else {
205 // uplo == LOWER_TRIANGLE
206 for (idx_t j = 0; j < n; j++) {
207 for (idx_t k = 0; k < m; k++) {
208 T sum(0);
209 for (idx_t i = 0; i < j; i++) {
210 sum += A(j, i) * conj(B(k, i));
211 }
212 sum += real(A(j, j)) * conj(B(k, j));
213 for (idx_t i = j + 1; i < n; i++) {
214 sum += conj(A(i, j)) * conj(B(k, i));
215 }
216 C(j, k) = alpha * sum + beta * C(j, k);
217 }
218 }
219 }
220 }
221 }
222 else { // side == RIGHT_SIDE
224
225 if (transB == NO_TRANS) {
226 if (uplo != LOWER_TRIANGLE) {
227 // uplo == UPPER_TRIANGLE or uplo == Uplo::General
228 for (idx_t j = 0; j < n; ++j) {
229 {
230 const scalar_t alphaTimesAjj = alpha * real(A(j, j));
231 for (idx_t i = 0; i < m; ++i)
232 C(i, j) = beta * C(i, j) + B(i, j) * alphaTimesAjj;
233 }
234
235 for (idx_t k = 0; k < j; ++k) {
236 const scalar_t alphaTimesAkj = alpha * A(k, j);
237 for (idx_t i = 0; i < m; ++i)
238 C(i, j) += B(i, k) * alphaTimesAkj;
239 }
240
241 for (idx_t k = j + 1; k < n; ++k) {
242 const scalar_t alphaTimesAjk = alpha * conj(A(j, k));
243 for (idx_t i = 0; i < m; ++i)
244 C(i, j) += B(i, k) * alphaTimesAjk;
245 }
246 }
247 }
248 else {
249 // uplo == LOWER_TRIANGLE
250 for (idx_t j = 0; j < n; ++j) {
251 {
252 const scalar_t alphaTimesAjj = alpha * real(A(j, j));
253 for (idx_t i = 0; i < m; ++i)
254 C(i, j) = beta * C(i, j) + B(i, j) * alphaTimesAjj;
255 }
256
257 for (idx_t k = 0; k < j; ++k) {
258 const scalar_t alphaTimesAjk = alpha * conj(A(j, k));
259 for (idx_t i = 0; i < m; ++i)
260 C(i, j) += B(i, k) * alphaTimesAjk;
261 }
262
263 for (idx_t k = j + 1; k < n; ++k) {
264 const scalar_t alphaTimesAkj = alpha * A(k, j);
265 for (idx_t i = 0; i < m; ++i)
266 C(i, j) += B(i, k) * alphaTimesAkj;
267 }
268 }
269 }
270 }
271 else if (transB == TRANSPOSE) {
272 // Trans
273 if (uplo == UPPER_TRIANGLE) {
274 // or uplo == Uplo::General
275 for (idx_t j = 0; j < n; j++) {
276 for (idx_t k = 0; k < m; k++) {
277 T sum(0);
278 for (idx_t i = 0; i < k; i++) {
279 sum += B(i, j) * A(i, k);
280 }
281 sum += B(k, j) * real(conj(A(k, k)));
282 for (idx_t i = k + 1; i < m; i++) {
283 sum += B(i, j) * conj(A(k, i));
284 }
285 C(j, k) = alpha * sum + beta * C(j, k);
286 }
287 }
288 }
289 else {
290 // uplo == LOWER_TRIANGLE
291 for (idx_t j = 0; j < n; j++) {
292 for (idx_t k = 0; k < m; k++) {
293 T sum(0);
294 for (idx_t i = 0; i < k; i++) {
295 sum += B(i, j) * conj(A(k, i));
296 }
297 sum += B(k, j) * real(A(k, k));
298 for (idx_t i = k + 1; i < m; i++) {
299 sum += B(i, j) * A(i, k);
300 }
301 C(j, k) = alpha * sum + beta * C(j, k);
302 }
303 }
304 }
305 }
306 else {
307 // TransConj
308 if (uplo == UPPER_TRIANGLE) {
309 // or uplo == Uplo::General
310 for (idx_t j = 0; j < n; j++) {
311 for (idx_t k = 0; k < m; k++) {
312 T sum(0);
313 for (idx_t i = 0; i < k; i++) {
314 sum += conj(B(i, j)) * A(i, k);
315 }
316 sum += conj(B(k, j)) * real(conj(A(k, k)));
317 for (idx_t i = k + 1; i < m; i++) {
318 sum += conj(B(i, j)) * conj(A(k, i));
319 }
320 C(j, k) = alpha * sum + beta * C(j, k);
321 }
322 }
323 }
324 else {
325 // uplo == LOWER_TRIANGLE
326 for (idx_t j = 0; j < n; j++) {
327 for (idx_t k = 0; k < m; k++) {
328 T sum(0);
329 for (idx_t i = 0; i < k; i++) {
330 sum += conj(B(i, j)) * conj(A(k, i));
331 }
332 sum += conj(B(k, j)) * real(A(k, k));
333 for (idx_t i = k + 1; i < m; i++) {
334 sum += conj(B(i, j)) * A(i, k);
335 }
336 C(j, k) = alpha * sum + beta * C(j, k);
337 }
338 }
339 }
340 }
341 }
342}
343
394template <TLAPACK_MATRIX matrixA_t,
395 TLAPACK_MATRIX matrixB_t,
396 TLAPACK_MATRIX matrixC_t,
397 TLAPACK_SCALAR alpha_t,
398 class T = type_t<matrixC_t>,
399 disable_if_allow_optblas_t<pair<matrixA_t, T>,
400 pair<matrixB_t, T>,
401 pair<matrixC_t, T>,
402 pair<alpha_t, T>>>
404 Uplo uplo,
405 Op transB,
406 const alpha_t& alpha,
407 const matrixA_t& A,
408 const matrixB_t& B,
409 matrixC_t& C)
410{
411 return hemm2(side, uplo, alpha, A, B, StrongZero(), C);
412}
413
414} // namespace tlapack
415
416#endif // #ifndef TLAPACK_BLAS_HEMM_2_HH
constexpr internal::LowerTriangle LOWER_TRIANGLE
Lower Triangle access.
Definition types.hpp:188
constexpr internal::UpperTriangle UPPER_TRIANGLE
Upper Triangle access.
Definition types.hpp:186
Side
Definition types.hpp:271
constexpr internal::RightSide RIGHT_SIDE
right side
Definition types.hpp:296
constexpr internal::Transpose TRANSPOSE
transpose
Definition types.hpp:262
constexpr internal::GeneralAccess GENERAL
General access.
Definition types.hpp:180
Op
Definition types.hpp:227
constexpr internal::ConjTranspose CONJ_TRANS
conjugate transpose
Definition types.hpp:264
Uplo
Definition types.hpp:50
constexpr internal::NoTranspose NO_TRANS
no transpose
Definition types.hpp:260
constexpr internal::LeftSide LEFT_SIDE
left side
Definition types.hpp:294
constexpr real_type< T > real(const T &x) noexcept
Extends std::real() to real datatypes.
Definition utils.hpp:71
constexpr T conj(const T &x) noexcept
Extends std::conj() to real datatypes.
Definition utils.hpp:100
#define TLAPACK_SCALAR
Macro for tlapack::concepts::Scalar compatible with C++17.
Definition concepts.hpp:915
#define TLAPACK_MATRIX
Macro for tlapack::concepts::Matrix compatible with C++17.
Definition concepts.hpp:896
void hemm2(Side side, Uplo uplo, Op transB, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
Hermitian matrix-Hermitian matrix multiply:
Definition hemm2.hpp:75
#define tlapack_check_false(cond)
Throw an error if cond is true.
Definition exceptionHandling.hpp:113
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
Strong zero type.
Definition StrongZero.hpp:43