<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
gemv.hpp
Go to the documentation of this file.
1
3//
4// Copyright (c) 2017-2021, University of Tennessee. All rights reserved.
5// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
6//
7// This file is part of <T>LAPACK.
8// <T>LAPACK is free software: you can redistribute it and/or modify it under
9// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
10
11#ifndef TLAPACK_BLAS_GEMV_HH
12#define TLAPACK_BLAS_GEMV_HH
13
16
17namespace tlapack {
18
46template <TLAPACK_MATRIX matrixA_t,
47 TLAPACK_VECTOR vectorX_t,
48 TLAPACK_VECTOR vectorY_t,
49 TLAPACK_SCALAR alpha_t,
50 TLAPACK_SCALAR beta_t,
51 class T = type_t<vectorY_t>,
52 disable_if_allow_optblas_t<pair<alpha_t, T>,
53 pair<matrixA_t, T>,
54 pair<vectorX_t, T>,
55 pair<vectorY_t, T>,
56 pair<beta_t, T> > = 0>
58 const alpha_t& alpha,
59 const matrixA_t& A,
60 const vectorX_t& x,
61 const beta_t& beta,
62 vectorY_t& y)
63{
64 // data traits
65 using TA = type_t<matrixA_t>;
66 using TX = type_t<vectorX_t>;
67 using idx_t = size_type<matrixA_t>;
68
69 // constants
70 const idx_t m =
71 (trans == Op::NoTrans || trans == Op::Conj) ? nrows(A) : ncols(A);
72 const idx_t n =
73 (trans == Op::NoTrans || trans == Op::Conj) ? ncols(A) : nrows(A);
74
75 // check arguments
76 tlapack_check_false(trans != Op::NoTrans && trans != Op::Trans &&
77 trans != Op::ConjTrans && trans != Op::Conj);
78 tlapack_check_false((idx_t)size(x) != n);
79 tlapack_check_false((idx_t)size(y) != m);
80
81 // quick return
82 if (m == 0 || n == 0) return;
83
84 // form y := beta*y
85 for (idx_t i = 0; i < m; ++i)
86 y[i] *= beta;
87
88 if (trans == Op::NoTrans) {
89 // form y += alpha * A * x
90 for (idx_t j = 0; j < n; ++j) {
92 for (idx_t i = 0; i < m; ++i) {
93 y[i] += tmp * A(i, j);
94 }
95 }
96 }
97 else if (trans == Op::Conj) {
98 // form y += alpha * conj( A ) * x
99 for (idx_t j = 0; j < n; ++j) {
100 const scalar_type<alpha_t, TX> tmp = alpha * x[j];
101 for (idx_t i = 0; i < m; ++i) {
102 y[i] += tmp * conj(A(i, j));
103 }
104 }
105 }
106 else if (trans == Op::Trans) {
107 // form y += alpha * A^T * x
108 for (idx_t i = 0; i < m; ++i) {
110 for (idx_t j = 0; j < n; ++j) {
111 tmp += A(j, i) * x[j];
112 }
113 y[i] += alpha * tmp;
114 }
115 }
116 else {
117 // form y += alpha * A^H * x
118 for (idx_t i = 0; i < m; ++i) {
120 for (idx_t j = 0; j < n; ++j) {
121 tmp += conj(A(j, i)) * x[j];
122 }
123 y[i] += alpha * tmp;
124 }
125 }
126}
127
128#ifdef TLAPACK_USE_LAPACKPP
129
142template <TLAPACK_LEGACY_MATRIX matrixA_t,
143 TLAPACK_LEGACY_VECTOR vectorX_t,
144 TLAPACK_LEGACY_VECTOR vectorY_t,
145 TLAPACK_SCALAR alpha_t,
146 TLAPACK_SCALAR beta_t,
147 class T = type_t<vectorY_t>,
148 enable_if_allow_optblas_t<pair<alpha_t, T>,
149 pair<matrixA_t, T>,
150 pair<vectorX_t, T>,
151 pair<vectorY_t, T>,
152 pair<beta_t, T> > = 0>
153void gemv(Op trans,
154 const alpha_t alpha,
155 const matrixA_t& A,
156 const vectorX_t& x,
157 const beta_t beta,
158 vectorY_t& y)
159{
160 using idx_t = size_type<matrixA_t>;
161
162 // Legacy objects
163 auto A_ = legacy_matrix(A);
164 auto x_ = legacy_vector(x);
165 auto y_ = legacy_vector(y);
166
167 // Constants to forward
168 constexpr Layout L = layout<matrixA_t>;
169 const auto& m = A_.m;
170 const auto& n = A_.n;
171
172 // Warnings for NaNs and Infs
173 if (alpha == alpha_t(0))
175 -2, "Infs and NaNs in A or x will not propagate to y on output");
176 if (beta == beta_t(0) && !is_same_v<beta_t, StrongZero>)
178 -5,
179 "Infs and NaNs in y on input will not propagate to y on output");
180
181 if (trans != Op::Conj)
182 ::blas::gemv((::blas::Layout)L, (::blas::Op)trans, m, n, alpha, A_.ptr,
183 A_.ldim, x_.ptr, x_.inc, (T)beta, y_.ptr, y_.inc);
184 else {
185 T* x2 = const_cast<T*>(x_.ptr);
186 for (idx_t i = 0; i < x_.n; ++i)
187 x2[i * x_.inc] = conj(x2[i * x_.inc]);
188 conjugate(y);
189 ::blas::gemv((::blas::Layout)L, ::blas::Op::NoTrans, m, n, conj(alpha),
190 A_.ptr, A_.ldim, x_.ptr, x_.inc, conj((T)beta), y_.ptr,
191 y_.inc);
192 for (idx_t i = 0; i < x_.n; ++i)
193 x2[i * x_.inc] = conj(x2[i * x_.inc]);
194 conjugate(y);
195 }
196}
197
198#endif
199
226template <TLAPACK_MATRIX matrixA_t,
227 TLAPACK_VECTOR vectorX_t,
228 TLAPACK_VECTOR vectorY_t,
229 TLAPACK_SCALAR alpha_t>
231 const alpha_t& alpha,
232 const matrixA_t& A,
233 const vectorX_t& x,
234 vectorY_t& y)
235{
236 return gemv(trans, alpha, A, x, StrongZero(), y);
237}
238
239} // namespace tlapack
240
241#endif // #ifndef TLAPACK_BLAS_GEMV_HH
Op
Definition types.hpp:222
constexpr T conj(const T &x) noexcept
Extends std::conj() to real datatypes.
Definition utils.hpp:100
#define TLAPACK_SCALAR
Macro for tlapack::concepts::Scalar compatible with C++17.
Definition concepts.hpp:915
#define TLAPACK_LEGACY_VECTOR
Macro for tlapack::concepts::LegacyVector compatible with C++17.
Definition concepts.hpp:954
#define TLAPACK_LEGACY_MATRIX
Macro for tlapack::concepts::LegacyMatrix compatible with C++17.
Definition concepts.hpp:951
#define TLAPACK_VECTOR
Macro for tlapack::concepts::Vector compatible with C++17.
Definition concepts.hpp:906
#define TLAPACK_MATRIX
Macro for tlapack::concepts::Matrix compatible with C++17.
Definition concepts.hpp:896
void conjugate(vector_t &x)
Conjugates a vector.
Definition conjugate.hpp:24
void gemv(Op trans, const alpha_t &alpha, const matrixA_t &A, const vectorX_t &x, const beta_t &beta, vectorY_t &y)
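General matrix-vector multiply: $y := \alpha\,\mathrm{op}(A)\,x + \beta\,y$, where $\mathrm{op}(A)$ is one of $A$, $A^T$, $A^H$, or $\mathrm{conj}(A)$.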
Definition gemv.hpp:57
#define tlapack_check_false(cond)
Throw an error if cond is true.
Definition exceptionHandling.hpp:113
#define tlapack_warning(info, detailedInfo)
Warning handler.
Definition exceptionHandling.hpp:156
Concept for types that represent tlapack::Op.
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
Strong zero type.
Definition StrongZero.hpp:43