<T>LAPACK 0.1.2
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
trevc3_backsolve.hpp
Go to the documentation of this file.
1
3// based on A. Schwarz et al., "Scalable eigenvector computation for the
4// non-symmetric eigenvalue problem"
5//
6// Copyright (c) 2025, University of Colorado Denver. All rights reserved.
7//
8// This file is part of <T>LAPACK.
9// <T>LAPACK is free software: you can redistribute it and/or modify it under
10// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
11
12#ifndef TLAPACK_TREVC3_BACKSOLVE_HH
13#define TLAPACK_TREVC3_BACKSOLVE_HH
14
16#include "tlapack/blas/asum.hpp"
17#include "tlapack/blas/gemm.hpp"
20
21namespace tlapack {
22
35 work_t& work,
39{
40 using idx_t = size_type<matrix_T_t>;
41 using TT = type_t<matrix_T_t>;
42 using real_t = real_type<TT>;
44
45 const idx_t n = nrows(T);
46
47 tlapack_check(ncols(T) == n);
48
49 idx_t nk = ke - ks;
50 auto [shifts, work2] = reshape(work, nk);
51
52 laset(Uplo::General, TT(0), TT(0), X);
53
54 // Step 1: Calculate the eigenvectors of the ks:ke submatrix
55 // using trevc_backsolve_single or trevc_backsolve_double
56 // and update the rest of X using a matrix-matrix multiplication
57 {
58 // Calculate the eigenvectors of the ks:ke block
59 auto Tii = slice(T, range(ks, ke), range(ks, ke));
60 auto X_ii = slice(X, range(ks, ke), range(0, nk));
61
62 for (idx_t k = 0; k < nk;) {
63 bool pair = false;
64 if constexpr (is_real<TT>) {
65 if (k + 1 < nk) {
66 if (Tii(k + 1, k) != TT(0)) {
67 pair = true;
68 }
69 }
70 }
71
72 if (pair) {
73 if constexpr (is_real<TT>) {
74 TT alpha = Tii(k, k);
75 TT beta = Tii(k, k + 1);
76 TT gamma = Tii(k + 1, k);
77
78 // real part of eigenvalue
79 TT wr = alpha;
80 // imaginary part of eigenvalue
81 TT wi = sqrt(abs(beta)) * sqrt(abs(gamma));
82
83 shifts[k] = wr;
84 shifts[k + 1] = wi;
85 // todo: remember that k and k+1 form a pair somehow, maybe
86 // a bool array?
87
88 // Complex conjugate pair
89 auto x1 = col(X_ii, k);
90 auto x2 = col(X_ii, k + 1);
91 auto colN_ii = slice(colN, range(ks, ke));
93 }
94 k += 2;
95 }
96 else {
97 shifts[k] = Tii(k, k);
98 // Real eigenvalue
99 auto x1 = col(X_ii, k);
100 auto colN_ii = slice(colN, range(ks, ke));
102 k += 1;
103 }
104 }
105
106 auto T_ij = slice(T, range(0, ks), range(ks, ke));
107 auto X_i = slice(X, range(0, ks), range(0, nk));
108 // This is where you might think to initialize X_i as -T_ij
109 // But the multiplication below takes care of that because it also
110 // takes the element 1 into account
112 }
113
114 // Step 2: Now keep propagating the updates upwards, but keep the individual
115 // shifts in mind
116 for (idx_t iib = 0; iib < ks;) {
117 idx_t nb = min(blocksize, ks - iib);
118 // Start of the block
119 idx_t bs = ks - iib - nb;
120 // End of the block
121 idx_t be = ks - iib;
122
123 // Make sure we don't split 2x2 blocks
124 if constexpr (is_real<TT>) {
125 if (bs > 0) {
126 // TODO: find a better way to check this so we don't
127 // always access other blocks of T
128 if (T(bs, bs - 1) != TT(0)) {
129 bs -= 1;
130 nb += 1;
131 }
132 }
133 }
134
135 auto T_ii = slice(T, range(bs, be), range(bs, be));
136 auto X_ii = slice(X, range(bs, be), range(0, nk));
137
138 for (idx_t k = 0; k < nk;) {
139 bool pair = false;
140 if constexpr (is_real<TT>) {
141 if (k + 1 < nk) {
142 if (T(ks + k + 1, ks + k) != TT(0)) {
143 pair = true;
144 }
145 }
146 }
147
148 if (pair) {
149 TT wr = shifts[k];
150 TT wi = shifts[k + 1];
151
152 for (idx_t ii = 0; ii < nb;) {
153 idx_t i = nb - 1 - ii;
154 bool is_2x2_block = false;
155 auto v1_r = col(X_ii, k);
156 auto v1_i = col(X_ii, k + 1);
157 if (i > 0) {
158 if (T_ii(i, i - 1) != TT(0)) {
159 is_2x2_block = true;
160 }
161 }
162
163 if (is_2x2_block) {
164 // 2x2 block
165
166 // Solve the complex 2x2 system:
 167 // [T_ii(i-1,i-1) - (wr + i*wi)   T_ii(i-1,i)             ]
 168 // [T_ii(i, i-1)                 T_ii(i, i) - (wr + i*wi)]
169 // *
170 // x
171 // =
172 // [v1_r[i-1] + i*v1_i[i-1]]
173 // [v1_r[i] + i*v1_i[i] ]
174 // Using real arithmetic only with Cramer's rule
175
176 TT a11r = T_ii(i - 1, i - 1) - wr;
177 TT a11i = -wi;
178 TT a12 = T_ii(i - 1, i);
179 TT a21 = T_ii(i, i - 1);
180 TT a22r = T_ii(i, i) - wr;
181 TT a22i = -wi;
182
183 TT b1r = v1_r[i - 1];
184 TT b1i = v1_i[i - 1];
185 TT b2r = v1_r[i];
186 TT b2i = v1_i[i];
187
188 TT detr = a11r * a22r - a11i * a22i - a12 * a21;
189 TT deti = a11r * a22i + a11i * a22r;
190
191 TT denom = detr * detr + deti * deti;
192
193 TT c1r = a22r * b1r - a22i * b1i - a12 * b2r;
194 TT c1i = a22r * b1i + a22i * b1r - a12 * b2i;
195 TT x1r = (c1r * detr + c1i * deti) / denom;
196 TT x1i = (c1i * detr - c1r * deti) / denom;
197
198 TT c2r = (a11r * b2r - a11i * b2i) - (a21 * b1r);
199 TT c2i = (a11r * b2i + a11i * b2r) - (a21 * b1i);
200 TT x2r = (c2r * detr + c2i * deti) / denom;
201 TT x2i = (c2i * detr - c2r * deti) / denom;
202
203 v1_r[i - 1] = x1r;
204 v1_i[i - 1] = x1i;
205 v1_r[i] = x2r;
206 v1_i[i] = x2i;
207
208 // Update the right-hand side
209 for (idx_t j = 0; j + 1 < i; ++j) {
210 // Real part
211 v1_r[j] -= T_ii(j, i - 1) * v1_r[i - 1];
212 v1_r[j] -= T_ii(j, i) * v1_r[i];
213
214 // Imaginary part
215 v1_i[j] -= T_ii(j, i - 1) * v1_i[i - 1];
216 v1_i[j] -= T_ii(j, i) * v1_i[i];
217 }
218
219 ii += 2;
220 }
221 else {
222 // 1x1 block
223
224 // Do the complex division:
 225 // (v1_r[i] + i*v1_i[i]) / (T_ii(i, i) - (wr + i*wi))
226 // in real arithmetic only
227 TT a = v1_r[i];
228 TT b = v1_i[i];
229 TT c = T_ii(i, i) - wr;
230 TT d = -wi;
231 TT denom = c * c + d * d;
232 v1_r[i] = (a * c + b * d) / denom;
233 v1_i[i] = (b * c - a * d) / denom;
234
235 // Update the right-hand side
236 for (idx_t j = 0; j < i; ++j) {
237 v1_r[j] -= T_ii(j, i) * v1_r[i];
238 v1_i[j] -= T_ii(j, i) * v1_i[i];
239 }
240
241 ii += 1;
242 }
243 }
244
245 k += 2;
246 }
247 else {
248 TT w = shifts[k];
249
250 if constexpr (is_complex<TT>) {
251 // The matrix is complex, so there are no two-by-two blocks
252 // to consider
253
254 for (idx_t ii = 0; ii < nb; ++ii) {
255 idx_t i = nb - 1 - ii;
256
257 X_ii(i, k) = X_ii(i, k) / (T_ii(i, i) - w);
258
259 for (idx_t j = 0; j < i; ++j) {
260 X_ii(j, k) -= T_ii(j, i) * X_ii(i, k);
261 }
262 }
263 }
264 else {
265 // The matrix is real, so we need to consider potential
266 // 2x2 blocks
267
268 for (idx_t ii = 0; ii < nb;) {
269 idx_t i = nb - 1 - ii;
270 bool is_2x2_block = false;
271 if (i > 0) {
272 if (T_ii(i, i - 1) != TT(0)) {
273 is_2x2_block = true;
274 }
275 }
276
277 if (is_2x2_block) {
278 // 2x2 block
279 // Solve the 2x2 system:
 280 // [T_ii(i-1,i-1)-w  T_ii(i-1,i)  ] [v1[i-1]] = [rhs1]
 281 // [T_ii(i, i-1)     T_ii(i, i)-w ] [v1[i]  ]   [rhs2]
 282 // by applying the explicit inverse of the 2x2 matrix.
283 TT rhs1 = X_ii(i - 1, k);
284 TT rhs2 = X_ii(i, k);
285
286 TT a = T_ii(i - 1, i - 1) - w;
287 TT b = T_ii(i - 1, i);
288 TT c = T_ii(i, i - 1);
289 TT d = T_ii(i, i) - w;
290
291 TT det = a * d - b * c;
292
293 X_ii(i - 1, k) = (d * rhs1 - b * rhs2) / det;
294 X_ii(i, k) = (-c * rhs1 + a * rhs2) / det;
295
296 for (idx_t j = 0; j + 1 < i; ++j) {
297 X_ii(j, k) -= T_ii(j, i - 1) * X_ii(i - 1, k);
298 X_ii(j, k) -= T_ii(j, i) * X_ii(i, k);
299 }
300
301 ii += 2;
302 }
303 else {
304 // 1x1 block
305 X_ii(i, k) = X_ii(i, k) / (T_ii(i, i) - w);
306
307 for (idx_t j = 0; j < i; ++j) {
308 X_ii(j, k) -= T_ii(j, i) * X_ii(i, k);
309 }
310
311 ii += 1;
312 }
313 }
314 }
315
316 k += 1;
317 }
318 }
319
320 auto T_ij = slice(T, range(0, bs), range(bs, be));
321 auto X_i = slice(X, range(0, bs), range(0, nk));
323
324 iib += nb;
325 }
326}
327
328} // namespace tlapack
329
 330#endif // TLAPACK_TREVC3_BACKSOLVE_HH
#define TLAPACK_WORKSPACE
Macro for tlapack::concepts::Workspace compatible with C++17.
Definition concepts.hpp:912
#define TLAPACK_VECTOR
Macro for tlapack::concepts::Vector compatible with C++17.
Definition concepts.hpp:906
#define TLAPACK_MATRIX
Macro for tlapack::concepts::Matrix compatible with C++17.
Definition concepts.hpp:896
void laset(uplo_t uplo, const type_t< matrix_t > &alpha, const type_t< matrix_t > &beta, matrix_t &A)
Initializes a matrix to diagonal and off-diagonal values.
Definition laset.hpp:38
void gemm(Op transA, Op transB, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
General matrix-matrix multiply:
Definition gemm.hpp:61
#define tlapack_check(cond)
Throw an error if cond is false.
Definition exceptionHandling.hpp:98
Sort the numbers in D in increasing order (if ID = 'I') or in decreasing order (if ID = 'D' ).
Definition arrayTraits.hpp:15
void trevc_backsolve_double(const matrix_T_t &T, vector_v_t &v_r, vector_v_t &v_i, const size_type< matrix_T_t > k, const vector_colN_t &colN)
Calculate the k-th right eigenvector pair of T using backsubstitution.
Definition trevc_backsolve.hpp:415
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
@ NoTrans
no transpose
void trevc3_backsolve(const matrix_T_t &T, matrix_X_t &X, vector_colN_t &colN, work_t &work, size_type< matrix_T_t > ks, size_type< matrix_T_t > ke, size_type< matrix_T_t > blocksize)
Calculate the ks-th through ke-th (not inclusive) right eigenvectors of T using a blocked backsubstitution.
Definition trevc3_backsolve.hpp:32
@ General
0 <= i <= m, 0 <= j <= n.
void trevc_backsolve_single(const matrix_T_t &T, vector_v_t &v, const size_type< matrix_T_t > k, const vector_colN_t &colN)
Calculate the k-th right eigenvector of T using backsubstitution.
Definition trevc_backsolve.hpp:79