<T>LAPACK 0.1.2
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
trevc3_forwardsolve.hpp
Go to the documentation of this file.
1
3// based on A. Schwarz et al., "Scalable eigenvector computation for the
4// non-symmetric eigenvalue problem"
5//
6// Copyright (c) 2025, University of Colorado Denver. All rights reserved.
7//
8// This file is part of <T>LAPACK.
9// <T>LAPACK is free software: you can redistribute it and/or modify it under
10// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
11
12#ifndef TLAPACK_TREVC3_FORWARDSOLVE_HH
13#define TLAPACK_TREVC3_FORWARDSOLVE_HH
14
16#include "tlapack/blas/asum.hpp"
17#include "tlapack/blas/gemm.hpp"
20
21namespace tlapack {
22
34 work_t& work,
38{
39 using idx_t = size_type<matrix_T_t>;
40 using TT = type_t<matrix_T_t>;
41 using real_t = real_type<TT>;
43
44 const idx_t n = nrows(T);
45
46 tlapack_check(ncols(T) == n);
47
48 idx_t nk = ke - ks;
49 auto [shifts, work2] = reshape(work, nk);
50
51 laset(Uplo::General, TT(0), TT(0), X);
52
53 // Step 1: Calculate the eigenvectors of the ks:ke submatrix
54 // using trevc_forwardsolve_single or trevc_forwardsolve_double
55 // and update the rest of X using a matrix-matrix multiplication
56 {
57 // Calculate the eigenvectors of the ks:ke block
58 auto Tii = slice(T, range(ks, ke), range(ks, ke));
59 auto X_ii = slice(X, range(ks, ke), range(0, nk));
60
61 for (idx_t k = 0; k < nk;) {
62 bool pair = false;
63 if constexpr (is_real<TT>) {
64 if (k + 1 < nk) {
65 if (Tii(k + 1, k) != TT(0)) {
66 pair = true;
67 }
68 }
69 }
70
71 if (pair) {
72 if constexpr (is_real<TT>) {
73 TT alpha = Tii(k, k);
74 TT beta = Tii(k, k + 1);
75 TT gamma = Tii(k + 1, k);
76
77 // real part of eigenvalue
78 TT wr = alpha;
79 // imaginary part of eigenvalue
80 TT wi = sqrt(abs(beta)) * sqrt(abs(gamma));
81
82 shifts[k] = wr;
83 shifts[k + 1] = wi;
84 // todo: remember that k and k+1 form a pair somehow, maybe
85 // a bool array?
86
87 // Complex conjugate pair
88 auto x1 = col(X_ii, k);
89 auto x2 = col(X_ii, k + 1);
90 auto colN_ii = slice(colN, range(ks, ke));
92 }
93 k += 2;
94 }
95 else {
96 shifts[k] = Tii(k, k);
97 // Real eigenvalue
98 auto x1 = col(X_ii, k);
99 auto colN_ii = slice(colN, range(ks, ke));
101 k += 1;
102 }
103 }
104
105 auto T_ij = slice(T, range(ks, ke), range(ke, n));
106 auto X_i = slice(X, range(ke, n), range(0, nk));
107 // This is where you might think to initialize X_i as -T_ij
108 // But the multiplication below takes care of that because it also
109 // takes the element 1 into account
111 }
112
113 // Step 2: Now keep propagating the updates downwards, but keep the
114 // individual shifts in mind
115 for (idx_t iib = ke; iib < n;) {
116 idx_t nb = min(blocksize, n - iib);
117 // Start of the block
118 idx_t bs = iib;
119 // End of the block
120 idx_t be = iib + nb;
121
122 // Make sure we don't split 2x2 blocks
123 if constexpr (is_real<TT>) {
124 if (be < n) {
125 // TODO: find a better way to check this so we don't
126 // always access other blocks of T
127 if (T(be, be - 1) != TT(0)) {
128 be += 1;
129 nb += 1;
130 }
131 }
132 }
133
134 auto T_ii = slice(T, range(bs, be), range(bs, be));
135 auto X_ii = slice(X, range(bs, be), range(0, nk));
136
137 for (idx_t k = 0; k < nk;) {
138 bool pair = false;
139 if constexpr (is_real<TT>) {
140 if (k + 1 < nk) {
141 if (T(ks + k + 1, ks + k) != TT(0)) {
142 pair = true;
143 }
144 }
145 }
146
147 if (pair) {
148 TT wr = shifts[k];
149 TT wi = shifts[k + 1];
150
151 for (idx_t i = 0; i < nb;) {
152 bool is_2x2_block = false;
153 if (i + 1 < nb) {
154 if (T_ii(i + 1, i) != TT(0)) {
155 is_2x2_block = true;
156 }
157 }
158
159 if (is_2x2_block) {
160 // 2x2 block
161
162 for (idx_t j = 0; j < i; ++j) {
163 X_ii(i, k) -= T_ii(j, i) * X_ii(j, k);
164 X_ii(i, k + 1) -= T_ii(j, i) * X_ii(j, k + 1);
165 X_ii(i + 1, k) -= T_ii(j, i + 1) * X_ii(j, k);
166 X_ii(i + 1, k + 1) -=
167 T_ii(j, i + 1) * X_ii(j, k + 1);
168 }
169
170 // Solve the complex 2x2 system
171 // Using real arithmetic only with Cramer's rule
172
173 TT a11r = T_ii(i, i) - wr;
174 TT a11i = wi;
175 // a12 and a21 are switched to transpose the system
176 TT a12 = T_ii(i + 1, i);
177 TT a21 = T_ii(i, i + 1);
178 TT a22r = T_ii(i + 1, i + 1) - wr;
179 TT a22i = wi;
180
181 TT b1r = X_ii(i, k);
182 TT b1i = X_ii(i, k + 1);
183 TT b2r = X_ii(i + 1, k);
184 TT b2i = X_ii(i + 1, k + 1);
185
186 TT detr = a11r * a22r - a11i * a22i - a12 * a21;
187 TT deti = a11r * a22i + a11i * a22r;
188
189 TT denom = detr * detr + deti * deti;
190
191 TT c1r = a22r * b1r - a22i * b1i - a12 * b2r;
192 TT c1i = a22r * b1i + a22i * b1r - a12 * b2i;
193 TT x1r = (c1r * detr + c1i * deti) / denom;
194 TT x1i = (c1i * detr - c1r * deti) / denom;
195
196 TT c2r = (a11r * b2r - a11i * b2i) - (a21 * b1r);
197 TT c2i = (a11r * b2i + a11i * b2r) - (a21 * b1i);
198 TT x2r = (c2r * detr + c2i * deti) / denom;
199 TT x2i = (c2i * detr - c2r * deti) / denom;
200
201 X_ii(i, k) = x1r;
202 X_ii(i, k + 1) = x1i;
203 X_ii(i + 1, k) = x2r;
204 X_ii(i + 1, k + 1) = x2i;
205
206 i += 2;
207 }
208 else {
209 // 1x1 block
210 for (idx_t j = 0; j < i; ++j) {
211 X_ii(i, k) -= T_ii(j, i) * X_ii(j, k);
212 X_ii(i, k + 1) -= T_ii(j, i) * X_ii(j, k + 1);
213 }
214
215 // Do the complex division:
216 // (X_ii(i, k) + i*X_ii(i, k + 1)) / ((T_ii(i, i) - wr) + i*wi)
217 // in real arithmetic only
218 TT a = X_ii(i, k);
219 TT b = X_ii(i, k + 1);
220 TT c = T_ii(i, i) - wr;
221 TT d = wi;
222 TT denom = c * c + d * d;
223 X_ii(i, k) = (a * c + b * d) / denom;
224 X_ii(i, k + 1) = (b * c - a * d) / denom;
225
226 i += 1;
227 }
228 }
229
230 k += 2;
231 }
232 else {
233 TT w = shifts[k];
234
235 if constexpr (is_complex<TT>) {
236 // The matrix is complex, so there are no two-by-two blocks
237 // to consider
238
239 for (idx_t i = 0; i < nb; ++i) {
240 for (idx_t j = 0; j < i; ++j) {
241 X_ii(i, k) -= conj(T_ii(j, i)) * X_ii(j, k);
242 }
243
244 X_ii(i, k) = X_ii(i, k) / conj(T_ii(i, i) - w);
245 }
246 }
247 else {
250 // The matrix is real, so we need to consider potential
251 // 2x2 blocks for complex conjugate eigenvalue pairs
252 idx_t i = 0;
253 while (i < nb) {
254 bool is_2x2_block = false;
255 if (i + 1 < nb) {
256 if (T_ii(i + 1, i) != TT(0)) {
257 is_2x2_block = true;
258 }
259 }
260
261 if (is_2x2_block) {
262 // 2x2 block
263
264 for (idx_t j = 0; j < i; ++j) {
265 X_ii(i, k) -= T_ii(j, i) * X_ii(j, k);
266 X_ii(i + 1, k) -= T_ii(j, i + 1) * X_ii(j, k);
267 }
268
269 // Solve the 2x2 (transposed) system:
270 // [T_ii(i,i)-w T_ii(i+1,i) ] [X_ii(i,k) ] = [rhs1]
271 // [T_ii(i,i+1) T_ii(i+1,i+1)-w] [X_ii(i+1,k)] [rhs2]
272 TT rhs1 = X_ii(i, k);
273 TT rhs2 = X_ii(i + 1, k);
274
275 TT a = T_ii(i, i) - w;
276 TT b = T_ii(i + 1, i);
277 TT c = T_ii(i, i + 1);
278 TT d = T_ii(i + 1, i + 1) - w;
279
280 TT det = a * d - b * c;
281
282 X_ii(i, k) = (d * rhs1 - b * rhs2) / det;
283 X_ii(i + 1, k) = (-c * rhs1 + a * rhs2) / det;
284
285 i += 2;
286 }
287 else {
288 // 1x1 block
289 for (idx_t j = 0; j < i; ++j) {
290 X_ii(i, k) -= T_ii(j, i) * X_ii(j, k);
291 }
292
293 X_ii(i, k) = X_ii(i, k) / (T_ii(i, i) - w);
294
295 i += 1;
296 }
297 }
298 }
299
300 k += 1;
301 }
302 }
303
304 auto T_ij = slice(T, range(bs, be), range(be, n));
305 auto X_i = slice(X, range(be, n), range(0, nk));
307
308 iib += nb;
309 }
310}
311
312} // namespace tlapack
313
314#endif // TLAPACK_TREVC3_FORWARDSOLVE_HH
#define TLAPACK_WORKSPACE
Macro for tlapack::concepts::Workspace compatible with C++17.
Definition concepts.hpp:912
#define TLAPACK_VECTOR
Macro for tlapack::concepts::Vector compatible with C++17.
Definition concepts.hpp:906
#define TLAPACK_MATRIX
Macro for tlapack::concepts::Matrix compatible with C++17.
Definition concepts.hpp:896
void laset(uplo_t uplo, const type_t< matrix_t > &alpha, const type_t< matrix_t > &beta, matrix_t &A)
Initializes a matrix to diagonal and off-diagonal values.
Definition laset.hpp:38
void gemm(Op transA, Op transB, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
General matrix-matrix multiply:
Definition gemm.hpp:61
#define tlapack_check(cond)
Throw an error if cond is false.
Definition exceptionHandling.hpp:98
Sort the numbers in D in increasing order (if ID = 'I') or in decreasing order (if ID = 'D' ).
Definition arrayTraits.hpp:15
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
void trevc_forwardsolve_double(const matrix_T_t &T, vector_v_t &v_r, vector_v_t &v_i, const size_type< matrix_T_t > k, const vector_colN_t &colN)
Calculate the k-th left eigenvector pair of T using forward substitution.
Definition trevc_forwardsolve.hpp:371
constexpr T conj(const T &x) noexcept
Extends std::conj() to real datatypes.
Definition utils.hpp:100
void trevc_forwardsolve_single(const matrix_T_t &T, vector_v_t &v, const size_type< matrix_T_t > k, const vector_colN_t &colN)
Calculate the k-th left eigenvector of T using forward substitution.
Definition trevc_forwardsolve.hpp:73
@ NoTrans
no transpose
@ ConjTrans
conjugate transpose
@ General
0 <= i <= m, 0 <= j <= n.
void trevc3_forwardsolve(const matrix_T_t &T, matrix_X_t &X, vector_colN_t &colN, work_t &work, size_type< matrix_T_t > ks, size_type< matrix_T_t > ke, size_type< matrix_T_t > blocksize)
Calculate the ks-th through ke-th (not inclusive) left eigenvectors of T using a blocked forward substitution.
Definition trevc3_forwardsolve.hpp:31