<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
her2k.hpp
Go to the documentation of this file.
1
3//
4// Copyright (c) 2017-2021, University of Tennessee. All rights reserved.
5// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
6//
7// This file is part of <T>LAPACK.
8// <T>LAPACK is free software: you can redistribute it and/or modify it under
9// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
10
11#ifndef TLAPACK_BLAS_HER2K_HH
12#define TLAPACK_BLAS_HER2K_HH
13
15
16namespace tlapack {
17
57template <TLAPACK_MATRIX matrixA_t,
58 TLAPACK_MATRIX matrixB_t,
59 TLAPACK_MATRIX matrixC_t,
60 TLAPACK_SCALAR alpha_t,
61 TLAPACK_REAL beta_t,
62 enable_if_t<(
63 /* Requires: */
64 is_real<beta_t>),
65 int> = 0,
66 class T = type_t<matrixC_t>,
67 disable_if_allow_optblas_t<pair<matrixA_t, T>,
68 pair<matrixB_t, T>,
69 pair<matrixC_t, T>,
70 pair<alpha_t, T>,
71 pair<beta_t, real_type<T> > > = 0>
73 Op trans,
74 const alpha_t& alpha,
75 const matrixA_t& A,
76 const matrixB_t& B,
77 const beta_t& beta,
78 matrixC_t& C)
79{
80 // data traits
81 using TA = type_t<matrixA_t>;
82 using TB = type_t<matrixB_t>;
83 using TC = type_t<matrixC_t>;
84 using idx_t = size_type<matrixA_t>;
85
86 // constants
87 const idx_t n = (trans == Op::NoTrans) ? nrows(A) : ncols(A);
88 const idx_t k = (trans == Op::NoTrans) ? ncols(A) : nrows(A);
89
90 // check arguments
91 tlapack_check_false(uplo != Uplo::Lower && uplo != Uplo::Upper &&
92 uplo != Uplo::General);
93 tlapack_check_false(trans != Op::NoTrans && trans != Op::ConjTrans);
94 tlapack_check_false(nrows(B) != nrows(A) || ncols(B) != ncols(A));
95 tlapack_check_false(nrows(C) != ncols(C));
96 tlapack_check_false(nrows(C) != n);
97
98 if (trans == Op::NoTrans) {
99 if (uplo != Uplo::Lower) {
100 // uplo == Uplo::Upper or uplo == Uplo::General
101 for (idx_t j = 0; j < n; ++j) {
102 for (idx_t i = 0; i < j; ++i)
103 C(i, j) *= beta;
104 C(j, j) = TC(beta * real(C(j, j)));
105
106 for (idx_t l = 0; l < k; ++l) {
108 alpha * conj(B(j, l));
110 conj(alpha * A(j, l));
111
112 for (idx_t i = 0; i < j; ++i)
113 C(i, j) +=
114 A(i, l) * alphaConjBjl + B(i, l) * conjAlphaAjl;
115 C(j, j) += 2 * (real(A(j, l)) * real(alphaConjBjl) -
116 imag(A(j, l)) * imag(alphaConjBjl));
117 }
118 }
119 }
120 else { // uplo == Uplo::Lower
121 for (idx_t j = 0; j < n; ++j) {
122 C(j, j) = TC(beta * real(C(j, j)));
123 for (idx_t i = j + 1; i < n; ++i)
124 C(i, j) *= beta;
125
126 for (idx_t l = 0; l < k; ++l) {
128 alpha * conj(B(j, l));
130 conj(alpha * A(j, l));
131
132 C(j, j) += 2 * (real(A(j, l)) * real(alphaConjBjl) -
133 imag(A(j, l)) * imag(alphaConjBjl));
134 for (idx_t i = j + 1; i < n; ++i)
135 C(i, j) +=
136 A(i, l) * alphaConjBjl + B(i, l) * conjAlphaAjl;
137 }
138 }
139 }
140 }
141 else { // trans == Op::ConjTrans
143
144 if (uplo != Uplo::Lower) {
145 // uplo == Uplo::Upper or uplo == Uplo::General
146 for (idx_t j = 0; j < n; ++j) {
147 for (idx_t i = 0; i <= j; ++i) {
148 scalar_t sum1(0);
149 scalar_t sum2(0);
150 for (idx_t l = 0; l < k; ++l) {
151 sum1 += conj(A(l, i)) * B(l, j);
152 sum2 += conj(B(l, i)) * A(l, j);
153 }
154
155 C(i, j) = (i < j) ? alpha * sum1 + conj(alpha) * sum2 +
156 beta * C(i, j)
157 : real(alpha) * real(sum1) -
158 imag(alpha) * imag(sum1) +
159 real(alpha) * real(sum2) +
160 imag(alpha) * imag(sum2) +
161 beta * real(C(i, j));
162 }
163 }
164 }
165 else { // uplo == Uplo::Lower
166 for (idx_t j = 0; j < n; ++j) {
167 for (idx_t i = j; i < n; ++i) {
168 scalar_t sum1(0);
169 scalar_t sum2(0);
170 for (idx_t l = 0; l < k; ++l) {
171 sum1 += conj(A(l, i)) * B(l, j);
172 sum2 += conj(B(l, i)) * A(l, j);
173 }
174
175 C(i, j) = (i > j) ? alpha * sum1 + conj(alpha) * sum2 +
176 beta * C(i, j)
177 : real(alpha) * real(sum1) -
178 imag(alpha) * imag(sum1) +
179 real(alpha) * real(sum2) +
180 imag(alpha) * imag(sum2) +
181 beta * real(C(i, j));
182 }
183 }
184 }
185 }
186
187 if (uplo == Uplo::General) {
188 for (idx_t j = 0; j < n; ++j) {
189 for (idx_t i = j + 1; i < n; ++i)
190 C(i, j) = conj(C(j, i));
191 }
192 }
193}
194
195#ifdef TLAPACK_USE_LAPACKPP
196
210template <TLAPACK_LEGACY_MATRIX matrixA_t,
211 TLAPACK_LEGACY_MATRIX matrixB_t,
212 TLAPACK_LEGACY_MATRIX matrixC_t,
213 TLAPACK_SCALAR alpha_t,
214 TLAPACK_REAL beta_t,
215 enable_if_t<(
216 /* Requires: */
217 is_real<beta_t>),
218 int> = 0,
219 class T = type_t<matrixC_t>,
220 enable_if_allow_optblas_t<pair<matrixA_t, T>,
221 pair<matrixB_t, T>,
222 pair<matrixC_t, T>,
223 pair<alpha_t, T>,
224 pair<beta_t, real_type<T> > > = 0>
225void her2k(Uplo uplo,
226 Op trans,
227 const alpha_t alpha,
228 const matrixA_t& A,
229 const matrixB_t& B,
230 const beta_t beta,
231 matrixC_t& C)
232{
233 // Legacy objects
234 auto A_ = legacy_matrix(A);
235 auto B_ = legacy_matrix(B);
236 auto C_ = legacy_matrix(C);
237
238 // Constants to forward
239 constexpr Layout L = layout<matrixC_t>;
240 const auto& n = C_.n;
241 const auto& k = (trans == Op::NoTrans) ? A_.n : A_.m;
242
243 // Warnings for NaNs and Infs
244 if (alpha == alpha_t(0))
246 -3, "Infs and NaNs in A or B will not propagate to C on output");
247 if (beta == beta_t(0) && !is_same_v<beta_t, StrongZero>)
249 -6,
250 "Infs and NaNs in C on input will not propagate to C on output");
251
252 return ::blas::her2k((::blas::Layout)L, (::blas::Uplo)uplo,
253 (::blas::Op)trans, n, k, alpha, A_.ptr, A_.ldim,
254 B_.ptr, B_.ldim, (real_type<T>)beta, C_.ptr, C_.ldim);
255}
256
257#endif
258
295template <TLAPACK_MATRIX matrixA_t,
296 TLAPACK_MATRIX matrixB_t,
297 TLAPACK_MATRIX matrixC_t,
298 TLAPACK_SCALAR alpha_t>
300 Op trans,
301 const alpha_t& alpha,
302 const matrixA_t& A,
303 const matrixB_t& B,
304 matrixC_t& C)
305{
306 return her2k(uplo, trans, alpha, A, B, StrongZero(), C);
307}
308
309} // namespace tlapack
310
311#endif // #ifndef TLAPACK_BLAS_HER2K_HH
Op
Definition types.hpp:222
Uplo
Definition types.hpp:45
constexpr real_type< T > real(const T &x) noexcept
Extends std::real() to real datatypes.
Definition utils.hpp:71
constexpr T conj(const T &x) noexcept
Extends std::conj() to real datatypes.
Definition utils.hpp:100
constexpr real_type< T > imag(const T &x) noexcept
Extends std::imag() to real datatypes.
Definition utils.hpp:86
#define TLAPACK_SCALAR
Macro for tlapack::concepts::Scalar compatible with C++17.
Definition concepts.hpp:915
#define TLAPACK_LEGACY_MATRIX
Macro for tlapack::concepts::LegacyMatrix compatible with C++17.
Definition concepts.hpp:951
#define TLAPACK_MATRIX
Macro for tlapack::concepts::Matrix compatible with C++17.
Definition concepts.hpp:896
#define TLAPACK_REAL
Macro for tlapack::concepts::Real compatible with C++17.
Definition concepts.hpp:918
void her2k(Uplo uplo, Op trans, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
Hermitian rank-2k update:
Definition her2k.hpp:72
#define tlapack_check_false(cond)
Throw an error if cond is true.
Definition exceptionHandling.hpp:113
#define tlapack_warning(info, detailedInfo)
Warning handler.
Definition exceptionHandling.hpp:156
Concept for types that represent tlapack::Op.
Concept for types that represent tlapack::Uplo.
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
Strong zero type.
Definition StrongZero.hpp:43