<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
syr2k.hpp
Go to the documentation of this file.
1
3//
4// Copyright (c) 2017-2021, University of Tennessee. All rights reserved.
5// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
6//
7// This file is part of <T>LAPACK.
8// <T>LAPACK is free software: you can redistribute it and/or modify it under
9// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
10
11#ifndef TLAPACK_BLAS_SYR2K_HH
12#define TLAPACK_BLAS_SYR2K_HH
13
15
16namespace tlapack {
17
53template <TLAPACK_MATRIX matrixA_t,
54 TLAPACK_MATRIX matrixB_t,
55 TLAPACK_MATRIX matrixC_t,
56 TLAPACK_SCALAR alpha_t,
57 TLAPACK_SCALAR beta_t,
58 class T = type_t<matrixC_t>,
59 disable_if_allow_optblas_t<pair<matrixA_t, T>,
60 pair<matrixB_t, T>,
61 pair<matrixC_t, T>,
62 pair<alpha_t, T>,
63 pair<beta_t, T> > = 0>
65 Op trans,
66 const alpha_t& alpha,
67 const matrixA_t& A,
68 const matrixB_t& B,
69 const beta_t& beta,
70 matrixC_t& C)
71{
72 // data traits
73 using TA = type_t<matrixA_t>;
74 using TB = type_t<matrixB_t>;
75 using idx_t = size_type<matrixA_t>;
76
77 // constants
78 const idx_t n = (trans == Op::NoTrans) ? nrows(A) : ncols(A);
79 const idx_t k = (trans == Op::NoTrans) ? ncols(A) : nrows(A);
80
81 // check arguments
82 tlapack_check_false(uplo != Uplo::Lower && uplo != Uplo::Upper &&
83 uplo != Uplo::General);
84 tlapack_check_false(trans != Op::NoTrans && trans != Op::Trans);
85 tlapack_check_false(nrows(B) != nrows(A) || ncols(B) != ncols(A));
86 tlapack_check_false(nrows(C) != ncols(C));
87 tlapack_check_false(nrows(C) != n);
88
89 if (trans == Op::NoTrans) {
90 if (uplo != Uplo::Lower) {
91 // uplo == Uplo::Upper or uplo == Uplo::General
92 for (idx_t j = 0; j < n; ++j) {
93 for (idx_t i = 0; i <= j; ++i)
94 C(i, j) *= beta;
95
96 for (idx_t l = 0; l < k; ++l) {
99 for (idx_t i = 0; i <= j; ++i)
100 C(i, j) += A(i, l) * alphaBjl + B(i, l) * alphaAjl;
101 }
102 }
103 }
104 else { // uplo == Uplo::Lower
105 for (idx_t j = 0; j < n; ++j) {
106 for (idx_t i = j; i < n; ++i)
107 C(i, j) *= beta;
108
109 for (idx_t l = 0; l < k; ++l) {
112 for (idx_t i = j; i < n; ++i)
113 C(i, j) += A(i, l) * alphaBjl + B(i, l) * alphaAjl;
114 }
115 }
116 }
117 }
118 else { // trans == Op::Trans
120
121 if (uplo != Uplo::Lower) {
122 // uplo == Uplo::Upper or uplo == Uplo::General
123 for (idx_t j = 0; j < n; ++j) {
124 for (idx_t i = 0; i <= j; ++i) {
125 scalar_t sum1(0);
126 scalar_t sum2(0);
127 for (idx_t l = 0; l < k; ++l) {
128 sum1 += A(l, i) * B(l, j);
129 sum2 += B(l, i) * A(l, j);
130 }
131 C(i, j) = alpha * sum1 + alpha * sum2 + beta * C(i, j);
132 }
133 }
134 }
135 else { // uplo == Uplo::Lower
136 for (idx_t j = 0; j < n; ++j) {
137 for (idx_t i = j; i < n; ++i) {
138 scalar_t sum1(0);
139 scalar_t sum2(0);
140 for (idx_t l = 0; l < k; ++l) {
141 sum1 += A(l, i) * B(l, j);
142 sum2 += B(l, i) * A(l, j);
143 }
144 C(i, j) = alpha * sum1 + alpha * sum2 + beta * C(i, j);
145 }
146 }
147 }
148 }
149
150 if (uplo == Uplo::General) {
151 for (idx_t j = 0; j < n; ++j) {
152 for (idx_t i = j + 1; i < n; ++i)
153 C(i, j) = C(j, i);
154 }
155 }
156}
157
158#ifdef TLAPACK_USE_LAPACKPP
159
173template <TLAPACK_LEGACY_MATRIX matrixA_t,
174 TLAPACK_LEGACY_MATRIX matrixB_t,
175 TLAPACK_LEGACY_MATRIX matrixC_t,
176 TLAPACK_SCALAR alpha_t,
177 TLAPACK_SCALAR beta_t,
178 class T = type_t<matrixC_t>,
179 enable_if_allow_optblas_t<pair<matrixA_t, T>,
180 pair<matrixB_t, T>,
181 pair<matrixC_t, T>,
182 pair<alpha_t, T>,
183 pair<beta_t, T> > = 0>
184void syr2k(Uplo uplo,
185 Op trans,
186 const alpha_t alpha,
187 const matrixA_t& A,
188 const matrixB_t& B,
189 const beta_t beta,
190 matrixC_t& C)
191{
192 // Legacy objects
193 auto A_ = legacy_matrix(A);
194 auto B_ = legacy_matrix(B);
195 auto C_ = legacy_matrix(C);
196
197 // Constants to forward
198 constexpr Layout L = layout<matrixC_t>;
199 const auto& n = C_.n;
200 const auto& k = (trans == Op::NoTrans) ? A_.n : A_.m;
201
202 // Warnings for NaNs and Infs
203 if (alpha == alpha_t(0))
205 -3, "Infs and NaNs in A or B will not propagate to C on output");
206 if (beta == beta_t(0) && !is_same_v<beta_t, StrongZero>)
208 -6,
209 "Infs and NaNs in C on input will not propagate to C on output");
210
211 return ::blas::syr2k((::blas::Layout)L, (::blas::Uplo)uplo,
212 (::blas::Op)trans, n, k, alpha, A_.ptr, A_.ldim,
213 B_.ptr, B_.ldim, (T)beta, C_.ptr, C_.ldim);
214}
215
216#endif
217
252template <TLAPACK_MATRIX matrixA_t,
253 TLAPACK_MATRIX matrixB_t,
254 TLAPACK_MATRIX matrixC_t,
255 TLAPACK_SCALAR alpha_t>
257 Op trans,
258 const alpha_t& alpha,
259 const matrixA_t& A,
260 const matrixB_t& B,
261 matrixC_t& C)
262{
263 return syr2k(uplo, trans, alpha, A, B, StrongZero(), C);
264}
265
266} // namespace tlapack
267
268#endif // #ifndef TLAPACK_BLAS_SYR2K_HH
Op
Definition types.hpp:222
Uplo
Definition types.hpp:45
#define TLAPACK_SCALAR
Macro for tlapack::concepts::Scalar compatible with C++17.
Definition concepts.hpp:915
#define TLAPACK_LEGACY_MATRIX
Macro for tlapack::concepts::LegacyMatrix compatible with C++17.
Definition concepts.hpp:951
#define TLAPACK_MATRIX
Macro for tlapack::concepts::Matrix compatible with C++17.
Definition concepts.hpp:896
void syr2k(Uplo uplo, Op trans, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
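Symmetric rank-2k update: $C := \alpha A B^T + \alpha B A^T + \beta C$ (trans = NoTrans) or $C := \alpha A^T B + \alpha B^T A + \beta C$ (trans = Trans), where C is an n-by-n symmetric matrix.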
Definition syr2k.hpp:64
#define tlapack_check_false(cond)
Throw an error if cond is true.
Definition exceptionHandling.hpp:113
#define tlapack_warning(info, detailedInfo)
Warning handler.
Definition exceptionHandling.hpp:156
Concept for types that represent tlapack::Op.
Concept for types that represent tlapack::Uplo.
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
Strong zero type.
Definition StrongZero.hpp:43