<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
trsv.hpp
Go to the documentation of this file.
1
//
// Copyright (c) 2017-2021, University of Tennessee. All rights reserved.
// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
//
// This file is part of <T>LAPACK.
// <T>LAPACK is free software: you can redistribute it and/or modify it under
// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
10
11#ifndef TLAPACK_BLAS_TRSV_HH
12#define TLAPACK_BLAS_TRSV_HH
13
16
17namespace tlapack {
18
59template <
60 TLAPACK_MATRIX matrixA_t,
61 TLAPACK_VECTOR vectorX_t,
62 class T = type_t<vectorX_t>,
63 disable_if_allow_optblas_t<pair<matrixA_t, T>, pair<vectorX_t, T> > = 0>
64void trsv(Uplo uplo, Op trans, Diag diag, const matrixA_t& A, vectorX_t& x)
65{
66 // data traits
67 using TA = type_t<matrixA_t>;
68 using TX = type_t<vectorX_t>;
69 using idx_t = size_type<matrixA_t>;
70
71 // constants
72 const idx_t n = nrows(A);
73 const bool nonunit = (diag == Diag::NonUnit);
74
75 // check arguments
76 tlapack_check_false(uplo != Uplo::Lower && uplo != Uplo::Upper);
77 tlapack_check_false(trans != Op::NoTrans && trans != Op::Trans &&
78 trans != Op::ConjTrans && trans != Op::Conj);
79 tlapack_check_false(diag != Diag::NonUnit && diag != Diag::Unit);
80 tlapack_check_false(nrows(A) != ncols(A));
81 tlapack_check_false(size(x) != n);
82
83 if (trans == Op::NoTrans) {
84 // Form x := A^{-1} * x
85 if (uplo == Uplo::Upper) {
86 // upper
87 for (idx_t j = n - 1; j != idx_t(-1); --j) {
88 // note: NOT skipping if x[j] is zero, for consistent NAN
89 // handling
90 if (nonunit) {
91 x[j] /= A(j, j);
92 }
93 for (idx_t i = j - 1; i != idx_t(-1); --i) {
94 x[i] -= x[j] * A(i, j);
95 }
96 }
97 }
98 else {
99 // lower
100 for (idx_t j = 0; j < n; ++j) {
101 // note: NOT skipping if x[j] is zero ...
102 if (nonunit) {
103 x[j] /= A(j, j);
104 }
105 for (idx_t i = j + 1; i < n; ++i) {
106 x[i] -= x[j] * A(i, j);
107 }
108 }
109 }
110 }
111 else if (trans == Op::Conj) {
112 // Form x := A^{-1} * x
113 if (uplo == Uplo::Upper) {
114 // upper
115 for (idx_t j = n - 1; j != idx_t(-1); --j) {
116 // note: NOT skipping if x[j] is zero, for consistent NAN
117 // handling
118 if (nonunit) {
119 x[j] /= conj(A(j, j));
120 }
121 for (idx_t i = j - 1; i != idx_t(-1); --i) {
122 x[i] -= x[j] * conj(A(i, j));
123 }
124 }
125 }
126 else {
127 // lower
128 for (idx_t j = 0; j < n; ++j) {
129 // note: NOT skipping if x[j] is zero ...
130 if (nonunit) {
131 x[j] /= conj(A(j, j));
132 }
133 for (idx_t i = j + 1; i < n; ++i) {
134 x[i] -= x[j] * conj(A(i, j));
135 }
136 }
137 }
138 }
139 else if (trans == Op::Trans) {
140 // Form x := A^{-T} * x
141
143
144 if (uplo == Uplo::Upper) {
145 // upper
146 for (idx_t j = 0; j < n; ++j) {
147 scalar_t tmp = x[j];
148 for (idx_t i = 0; i < j; ++i) {
149 tmp -= A(i, j) * x[i];
150 }
151 if (nonunit) {
152 tmp /= A(j, j);
153 }
154 x[j] = tmp;
155 }
156 }
157 else {
158 // lower
159 for (idx_t j = n - 1; j != idx_t(-1); --j) {
160 scalar_t tmp = x[j];
161 for (idx_t i = j + 1; i < n; ++i) {
162 tmp -= A(i, j) * x[i];
163 }
164 if (nonunit) {
165 tmp /= A(j, j);
166 }
167 x[j] = tmp;
168 }
169 }
170 }
171 else {
172 // Form x := A^{-H} * x
173 // same code as above A^{-T} * x case, except add conj()
174
176
177 if (uplo == Uplo::Upper) {
178 // upper
179 for (idx_t j = 0; j < n; ++j) {
180 scalar_t tmp = x[j];
181 for (idx_t i = 0; i < j; ++i) {
182 tmp -= conj(A(i, j)) * x[i];
183 }
184 if (nonunit) {
185 tmp /= conj(A(j, j));
186 }
187 x[j] = tmp;
188 }
189 }
190 else {
191 // lower
192 for (idx_t j = n - 1; j != idx_t(-1); --j) {
193 scalar_t tmp = x[j];
194 for (idx_t i = j + 1; i < n; ++i) {
195 tmp -= conj(A(i, j)) * x[i];
196 }
197 if (nonunit) {
198 tmp /= conj(A(j, j));
199 }
200 x[j] = tmp;
201 }
202 }
203 }
204}
205
#ifdef TLAPACK_USE_LAPACKPP

/**
 * Solve the triangular matrix-vector equation op(A) x = b by forwarding to
 * the optimized ::blas::trsv routine.
 *
 * This overload is enabled (via enable_if_allow_optblas_t) only for type
 * combinations allowed to use the optimized BLAS backend. It accepts legacy
 * matrix/vector types, whose pointer, leading dimension and increment are
 * passed directly to BLAS.
 *
 * @param[in] uplo  Which triangle of A is referenced (Uplo::Upper/Lower).
 * @param[in] trans Op::NoTrans, Op::Trans, Op::ConjTrans or Op::Conj.
 * @param[in] diag  Diag::Unit or Diag::NonUnit.
 * @param[in] A     n-by-n triangular matrix.
 * @param[in,out] x On entry, the right-hand side b; on exit, the solution.
 */
template <TLAPACK_LEGACY_MATRIX matrixA_t,
          TLAPACK_LEGACY_VECTOR vectorX_t,
          class T = type_t<vectorX_t>,
          enable_if_allow_optblas_t<pair<matrixA_t, T>, pair<vectorX_t, T> > = 0>
void trsv(Uplo uplo, Op trans, Diag diag, const matrixA_t& A, vectorX_t& x)
{
    // Legacy objects
    auto A_ = legacy_matrix(A);
    auto x_ = legacy_vector(x);

    // Constants to forward
    constexpr Layout L = layout<matrixA_t>;
    const auto& n = A_.n;

    if (trans == Op::Conj) {
        // BLAS trsv only accepts NoTrans/Trans/ConjTrans. Since
        //     conj(A) x = b   <=>   A conj(x) = conj(b),
        // conjugate the right-hand side, solve with A, then conjugate back.
        conjugate(x);
        ::blas::trsv(static_cast<::blas::Layout>(L),
                     static_cast<::blas::Uplo>(uplo), ::blas::Op::NoTrans,
                     static_cast<::blas::Diag>(diag), n, A_.ptr, A_.ldim,
                     x_.ptr, x_.inc);
        conjugate(x);
    }
    else {
        ::blas::trsv(static_cast<::blas::Layout>(L),
                     static_cast<::blas::Uplo>(uplo),
                     static_cast<::blas::Op>(trans),
                     static_cast<::blas::Diag>(diag), n, A_.ptr, A_.ldim,
                     x_.ptr, x_.inc);
    }
}

#endif

236} // namespace tlapack
237
238#endif // #ifndef TLAPACK_BLAS_TRSV_HH
Diag
Definition types.hpp:192
Op
Definition types.hpp:222
Uplo
Definition types.hpp:45
constexpr T conj(const T &x) noexcept
Extends std::conj() to real datatypes.
Definition utils.hpp:100
#define TLAPACK_LEGACY_VECTOR
Macro for tlapack::concepts::LegacyVector compatible with C++17.
Definition concepts.hpp:954
#define TLAPACK_LEGACY_MATRIX
Macro for tlapack::concepts::LegacyMatrix compatible with C++17.
Definition concepts.hpp:951
#define TLAPACK_VECTOR
Macro for tlapack::concepts::Vector compatible with C++17.
Definition concepts.hpp:906
#define TLAPACK_MATRIX
Macro for tlapack::concepts::Matrix compatible with C++17.
Definition concepts.hpp:896
constexpr auto diag(T &A, int diagIdx=0) noexcept
Get the Diagonal of an Eigen Matrix.
Definition eigen.hpp:576
void conjugate(vector_t &x)
Conjugates a vector.
Definition conjugate.hpp:24
void trsv(Uplo uplo, Op trans, Diag diag, const matrixA_t &A, vectorX_t &x)
Solve the triangular matrix-vector equation.
Definition trsv.hpp:64
#define tlapack_check_false(cond)
Throw an error if cond is true.
Definition exceptionHandling.hpp:113
Concept for types that represent tlapack::Diag.
Concept for types that represent tlapack::Op.
Concept for types that represent tlapack::Uplo.
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113