<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
hemm.hpp
Go to the documentation of this file.
1
3//
4// Copyright (c) 2017-2021, University of Tennessee. All rights reserved.
5// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
6//
7// This file is part of <T>LAPACK.
8// <T>LAPACK is free software: you can redistribute it and/or modify it under
9// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
10
11#ifndef TLAPACK_BLAS_HEMM_HH
12#define TLAPACK_BLAS_HEMM_HH
13
15
16namespace tlapack {
17
52template <TLAPACK_MATRIX matrixA_t,
53 TLAPACK_MATRIX matrixB_t,
54 TLAPACK_MATRIX matrixC_t,
55 TLAPACK_SCALAR alpha_t,
56 TLAPACK_SCALAR beta_t,
57 class T = type_t<matrixC_t>,
58 disable_if_allow_optblas_t<pair<matrixA_t, T>,
59 pair<matrixB_t, T>,
60 pair<matrixC_t, T>,
61 pair<alpha_t, T>,
62 pair<beta_t, T> > = 0>
64 Uplo uplo,
65 const alpha_t& alpha,
66 const matrixA_t& A,
67 const matrixB_t& B,
68 const beta_t& beta,
69 matrixC_t& C)
70{
71 // data traits
72 using TA = type_t<matrixA_t>;
73 using TB = type_t<matrixB_t>;
74 using idx_t = size_type<matrixB_t>;
75
76 // constants
77 const idx_t m = nrows(B);
78 const idx_t n = ncols(B);
79
80 // check arguments
81 tlapack_check_false(side != Side::Left && side != Side::Right);
82 tlapack_check_false(uplo != Uplo::Lower && uplo != Uplo::Upper &&
83 uplo != Uplo::General);
84 tlapack_check_false(nrows(A) != ncols(A));
85 tlapack_check_false(nrows(A) != ((side == Side::Left) ? m : n));
86 tlapack_check_false(nrows(C) != m);
87 tlapack_check_false(ncols(C) != n);
88
89 if (side == Side::Left) {
90 if (uplo != Uplo::Lower) {
91 // uplo == Uplo::Upper or uplo == Uplo::General
92 for (idx_t j = 0; j < n; ++j) {
93 for (idx_t i = 0; i < m; ++i) {
95 alpha * B(i, j);
97
98 for (idx_t k = 0; k < i; ++k) {
99 C(k, j) += A(k, i) * alphaTimesBij;
100 sum += conj(A(k, i)) * B(k, j);
101 }
102 C(i, j) = beta * C(i, j) + real(A(i, i)) * alphaTimesBij +
103 alpha * sum;
104 }
105 }
106 }
107 else {
108 // uplo == Uplo::Lower
109 for (idx_t j = 0; j < n; ++j) {
110 for (idx_t i = m - 1; i != idx_t(-1); --i) {
112 alpha * B(i, j);
114
115 for (idx_t k = i + 1; k < m; ++k) {
116 C(k, j) += A(k, i) * alphaTimesBij;
117 sum += conj(A(k, i)) * B(k, j);
118 }
119 C(i, j) = beta * C(i, j) + real(A(i, i)) * alphaTimesBij +
120 alpha * sum;
121 }
122 }
123 }
124 }
125 else { // side == Side::Right
126
128
129 if (uplo != Uplo::Lower) {
130 // uplo == Uplo::Upper or uplo == Uplo::General
131 for (idx_t j = 0; j < n; ++j) {
132 {
133 const scalar_t alphaTimesAjj = alpha * real(A(j, j));
134 for (idx_t i = 0; i < m; ++i)
135 C(i, j) = beta * C(i, j) + B(i, j) * alphaTimesAjj;
136 }
137
138 for (idx_t k = 0; k < j; ++k) {
139 const scalar_t alphaTimesAkj = alpha * A(k, j);
140 for (idx_t i = 0; i < m; ++i)
141 C(i, j) += B(i, k) * alphaTimesAkj;
142 }
143
144 for (idx_t k = j + 1; k < n; ++k) {
145 const scalar_t alphaTimesAjk = alpha * conj(A(j, k));
146 for (idx_t i = 0; i < m; ++i)
147 C(i, j) += B(i, k) * alphaTimesAjk;
148 }
149 }
150 }
151 else {
152 // uplo == Uplo::Lower
153 for (idx_t j = 0; j < n; ++j) {
154 {
155 const scalar_t alphaTimesAjj = alpha * real(A(j, j));
156 for (idx_t i = 0; i < m; ++i)
157 C(i, j) = beta * C(i, j) + B(i, j) * alphaTimesAjj;
158 }
159
160 for (idx_t k = 0; k < j; ++k) {
161 const scalar_t alphaTimesAjk = alpha * conj(A(j, k));
162 for (idx_t i = 0; i < m; ++i)
163 C(i, j) += B(i, k) * alphaTimesAjk;
164 }
165
166 for (idx_t k = j + 1; k < n; ++k) {
167 const scalar_t alphaTimesAkj = alpha * A(k, j);
168 for (idx_t i = 0; i < m; ++i)
169 C(i, j) += B(i, k) * alphaTimesAkj;
170 }
171 }
172 }
173 }
174}
175
176#ifdef TLAPACK_USE_LAPACKPP
177
191template <TLAPACK_LEGACY_MATRIX matrixA_t,
192 TLAPACK_LEGACY_MATRIX matrixB_t,
193 TLAPACK_LEGACY_MATRIX matrixC_t,
194 TLAPACK_SCALAR alpha_t,
195 TLAPACK_SCALAR beta_t,
196 class T = type_t<matrixC_t>,
197 enable_if_allow_optblas_t<pair<matrixA_t, T>,
198 pair<matrixB_t, T>,
199 pair<matrixC_t, T>,
200 pair<alpha_t, T>,
201 pair<beta_t, T> > = 0>
202void hemm(Side side,
203 Uplo uplo,
204 const alpha_t alpha,
205 const matrixA_t& A,
206 const matrixB_t& B,
207 const beta_t beta,
208 matrixC_t& C)
209{
210 // Legacy objects
211 auto A_ = legacy_matrix(A);
212 auto B_ = legacy_matrix(B);
213 auto C_ = legacy_matrix(C);
214
215 // Constants to forward
216 constexpr Layout L = layout<matrixC_t>;
217 const auto& m = C_.m;
218 const auto& n = C_.n;
219
220 // Warnings for NaNs and Infs
221 if (alpha == alpha_t(0))
223 -3, "Infs and NaNs in A or B will not propagate to C on output");
224 if (beta == beta_t(0) && !is_same_v<beta_t, StrongZero>)
226 -6,
227 "Infs and NaNs in C on input will not propagate to C on output");
228
229 return ::blas::hemm((::blas::Layout)L, (::blas::Side)side,
230 (::blas::Uplo)uplo, m, n, alpha, A_.ptr, A_.ldim,
231 B_.ptr, B_.ldim, (T)beta, C_.ptr, C_.ldim);
232}
233
234#endif
235
269template <TLAPACK_MATRIX matrixA_t,
270 TLAPACK_MATRIX matrixB_t,
271 TLAPACK_MATRIX matrixC_t,
272 TLAPACK_SCALAR alpha_t>
274 Uplo uplo,
275 const alpha_t& alpha,
276 const matrixA_t& A,
277 const matrixB_t& B,
278 matrixC_t& C)
279{
280 return hemm(side, uplo, alpha, A, B, StrongZero(), C);
281}
282
283} // namespace tlapack
284
285#endif // #ifndef TLAPACK_BLAS_HEMM_HH
Side
Definition types.hpp:266
Uplo
Definition types.hpp:45
constexpr real_type< T > real(const T &x) noexcept
Extends std::real() to real datatypes.
Definition utils.hpp:71
constexpr T conj(const T &x) noexcept
Extends std::conj() to real datatypes.
Definition utils.hpp:100
#define TLAPACK_SCALAR
Macro for tlapack::concepts::Scalar compatible with C++17.
Definition concepts.hpp:915
#define TLAPACK_LEGACY_MATRIX
Macro for tlapack::concepts::LegacyMatrix compatible with C++17.
Definition concepts.hpp:951
#define TLAPACK_MATRIX
Macro for tlapack::concepts::Matrix compatible with C++17.
Definition concepts.hpp:896
void hemm(Side side, Uplo uplo, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
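Hermitian matrix-matrix multiply: $C := \alpha A B + \beta C$ (side == Side::Left) or $C := \alpha B A + \beta C$ (side == Side::Right).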
Definition hemm.hpp:63
#define tlapack_check_false(cond)
Throw an error if cond is true.
Definition exceptionHandling.hpp:113
#define tlapack_warning(info, detailedInfo)
Warning handler.
Definition exceptionHandling.hpp:156
Concept for types that represent tlapack::Side.
Concept for types that represent tlapack::Uplo.
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
Strong zero type.
Definition StrongZero.hpp:43