<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
schur_swap.hpp
Go to the documentation of this file.
1
5//
6// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
7//
8// This file is part of <T>LAPACK.
9// <T>LAPACK is free software: you can redistribute it and/or modify it under
10// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
11
12#ifndef TLAPACK_SCHUR_SWAP_HH
13#define TLAPACK_SCHUR_SWAP_HH
14
16#include "tlapack/blas/rot.hpp"
17#include "tlapack/blas/rotg.hpp"
18#include "tlapack/blas/swap.hpp"
23
24namespace tlapack {
25
// schur_swap (real case): swaps two adjacent diagonal blocks of the real
// quasi-upper-triangular (Schur) matrix A by an orthogonal similarity
// transformation.
//
// want_q - if true, the same transformation is accumulated into the
//          columns of Q.
// A      - n-by-n matrix in real Schur form (1x1 and 2x2 diagonal blocks).
// Q      - n-by-n matrix; only modified when want_q is true.
// j0     - index of the first row/column of the first block.
// n1     - size of the first block (1 or 2).
// n2     - size of the second block (1 or 2); it starts at j0 + n1.
//
// Requires j0 + n1 + n2 <= n (checked below).
//
// Returns 1 when the 2x2 / 2x2 swap is rejected because the transformed
// off-diagonal block exceeds `thresh` (no write to A or Q precedes that
// return); returns 0 otherwise.
//
// NOTE(review): this rendered listing omits several original source lines,
// e.g. the line carrying the function name, the j0/n1/n2 parameter
// declarations, the declarations of `range` and of the new_*_matrix
// functors, and the larfg / lacpy / lahqr_schur22 call lines. The full
// signature is listed in the reference summary at the end of this page.
47template <TLAPACK_CSMATRIX matrix_t,
 48 enable_if_t<is_real<type_t<matrix_t>>, bool> = true>
 50 matrix_t& A,
 51 matrix_t& Q,
55{
 56 using idx_t = size_type<matrix_t>;
 57 using T = type_t<matrix_t>;
59
 60 // Functor for creating new matrices
64
 65 const idx_t n = ncols(A);
 66 const T zero(0);
 67 const T ten(10);
68
 69 tlapack_check(nrows(A) == n);
 70 tlapack_check(nrows(Q) == n);
 71 tlapack_check(ncols(Q) == n);
 72 tlapack_check(0 <= j0);
// NOTE(review): if size_type is unsigned, the check above is always true.
 73 tlapack_check(j0 + n1 + n2 <= n);
 74 tlapack_check(n1 == 1 or n1 == 2);
 75 tlapack_check(n2 == 1 or n2 == 2);
76
// Indices of the 2nd, 3rd and 4th rows/columns touched by the swap.
 77 const idx_t j1 = j0 + 1;
 78 const idx_t j2 = j0 + 2;
 79 const idx_t j3 = j0 + 3;
80
 81 // Check if the 2x2 eigenvalue blocks consist of 2 1x1 blocks
 82 // If so, treat them separately
// First block decoupled: move its second 1x1 piece (at j1) past block 2,
// then its first 1x1 piece (at j0).
 83 if (n1 == 2)
 84 if (A(j1, j0) == zero) {
 85 // only 2x2 swaps can fail, so we don't need to check for error
 86 schur_swap(want_q, A, Q, j1, (idx_t)1, n2);
 87 schur_swap(want_q, A, Q, j0, (idx_t)1, n2);
 88 return 0;
 89 }
// Second block decoupled: move block 1 past the first 1x1 piece of block 2
// (at j0 + n1), then past the second piece (block 1 now starts at j1).
 90 if (n2 == 2)
 91 if (A(j0 + n1 + 1, j0 + n1) == zero) {
 92 // only 2x2 swaps can fail, so we don't need to check for error
 93 schur_swap(want_q, A, Q, j0, n1, (idx_t)1);
 94 schur_swap(want_q, A, Q, j1, n1, (idx_t)1);
 95 return 0;
 96 }
97
 98 if (n1 == 1 and n2 == 1) {
 99 //
 100 // Swap two 1-by-1 blocks.
 101 //
 102 const T t00 = A(j0, j0);
 103 const T t11 = A(j1, j1);
 104 //
 105 // Determine the transformation to perform the interchange
 106 //
// rotg builds (cs, sn) from the pair (A(j0, j1), t11 - t00); see the rotg
// summary at the end of this page.
 107 T cs, sn;
 108 T temp = A(j0, j1);
 109 T temp2 = t11 - t00;
 110 rotg(temp, temp2, cs, sn);
111
// The diagonal entries simply trade places.
 112 A(j1, j1) = t00;
 113 A(j0, j0) = t11;
114
 115 // Apply transformation from the left
 116 if (j2 < n) {
 117 auto row1 = slice(A, j0, range{j2, n});
 118 auto row2 = slice(A, j1, range{j2, n});
 119 rot(row1, row2, cs, sn);
 120 }
 121 // Apply transformation from the right
 122 if (j0 > 0) {
 123 auto col1 = slice(A, range{0, j0}, j0);
 124 auto col2 = slice(A, range{0, j0}, j1);
 125 rot(col1, col2, cs, sn);
 126 }
// Accumulate the rotation into columns j0 and j1 of Q (despite their
// names, row1/row2 below are columns of Q).
 127 if (want_q) {
 128 auto row1 = col(Q, j0);
 129 auto row2 = col(Q, j1);
 130 rot(row1, row2, cs, sn);
 131 }
 132 }
 133 if (n1 == 1 and n2 == 2) {
 134 //
 135 // Swap 1-by-1 block with 2-by-2 block
 136 //
137
// 3x2 workspace: B = [ A(j0, j1:j2) ; A(j1:j2, j1:j2) - A(j0, j0) * I ].
 138 T B_[3 * 2];
 139 auto B = new_3by2_matrix(B_);
 140 B(0, 0) = A(j0, j1);
 141 B(1, 0) = A(j1, j1) - A(j0, j0);
 142 B(2, 0) = A(j2, j1);
 143 B(0, 1) = A(j0, j2);
 144 B(1, 1) = A(j1, j2);
 145 B(2, 1) = A(j2, j2) - A(j0, j0);
146
// Two Householder reflectors: H1 acts on rows 0..2 (v1 = column 0 of B),
// H2 acts on rows 1..2 (v2 = B(1:2, 1)). Every application below treats
// the leading entry of v1 / v2 as an implicit 1 (v1[0] and v2[0] are
// never read).
// NOTE(review): the larfg calls that generate (v1, tau1) and (v2, tau2)
// are not shown in this listing.
 147 // Make B upper triangular
 148 T tau1, tau2;
 149 auto v1 = slice(B, range{0, 3}, 0);
 150 auto v2 = slice(B, range{1, 3}, 1);
// Apply H1 to the second column of B (v2 is a view of that column).
 152 const T sum = B(0, 1) + v1[1] * B(1, 1) + v1[2] * B(2, 1);
 153 B(0, 1) = B(0, 1) - sum * tau1;
 154 B(1, 1) = B(1, 1) - sum * tau1 * v1[1];
 155 B(2, 1) = B(2, 1) - sum * tau1 * v1[2];
157
 158 //
 159 // Apply reflections to A and Q
 160 //
161
// Apply H1 then H2 to rows j0..j2 of A, columns j0..n-1.
 162 // Reflections from the left
 163 for (idx_t j = j0; j < n; ++j) {
 164 T sum = A(j0, j) + v1[1] * A(j1, j) + v1[2] * A(j2, j);
 165 A(j0, j) = A(j0, j) - sum * tau1;
 166 A(j1, j) = A(j1, j) - sum * tau1 * v1[1];
 167 A(j2, j) = A(j2, j) - sum * tau1 * v1[2];
168
 169 sum = A(j1, j) + v2[1] * A(j2, j);
 170 A(j1, j) = A(j1, j) - sum * tau2;
 171 A(j2, j) = A(j2, j) - sum * tau2 * v2[1];
 172 }
// Apply H1 then H2 to columns j0..j2 of A, rows 0..j0+2.
 173 // Reflections from the right
 174 for (idx_t j = 0; j < j3; ++j) {
 175 T sum = A(j, j0) + v1[1] * A(j, j1) + v1[2] * A(j, j2);
 176 A(j, j0) = A(j, j0) - sum * tau1;
 177 A(j, j1) = A(j, j1) - sum * tau1 * v1[1];
 178 A(j, j2) = A(j, j2) - sum * tau1 * v1[2];
179
 180 sum = A(j, j1) + v2[1] * A(j, j2);
 181 A(j, j1) = A(j, j1) - sum * tau2;
 182 A(j, j2) = A(j, j2) - sum * tau2 * v2[1];
 183 }
184
// Accumulate H1 and H2 into columns j0..j2 of Q.
 185 if (want_q) {
 186 for (idx_t j = 0; j < n; ++j) {
 187 T sum = Q(j, j0) + v1[1] * Q(j, j1) + v1[2] * Q(j, j2);
 188 Q(j, j0) = Q(j, j0) - sum * tau1;
 189 Q(j, j1) = Q(j, j1) - sum * tau1 * v1[1];
 190 Q(j, j2) = Q(j, j2) - sum * tau1 * v1[2];
191
 192 sum = Q(j, j1) + v2[1] * Q(j, j2);
 193 Q(j, j1) = Q(j, j1) - sum * tau2;
 194 Q(j, j2) = Q(j, j2) - sum * tau2 * v2[1];
 195 }
 196 }
197
// The 2x2 block now occupies rows/columns j0..j1; the entries below it in
// row j2 are forced to exact zero.
 198 A(j2, j0) = zero;
 199 A(j2, j1) = zero;
 200 }
 201 if (n1 == 2 and n2 == 1) {
 202 //
 203 // Swap 2-by-2 block with 1-by-1 block
 204 //
205
// Mirror image of the 1x2 case: rows/columns are processed in reversed
// order (j2, j1, j0). The 3x2 workspace is
//   B = [ A(j1, j2)             A(j0, j2)
//         A(j1, j1) - A(j2, j2) A(j0, j1)
//         A(j1, j0)             A(j0, j0) - A(j2, j2) ].
 206 T B_[3 * 2];
 207 auto B = new_3by2_matrix(B_);
 208 B(0, 0) = A(j1, j2);
 209 B(1, 0) = A(j1, j1) - A(j2, j2);
 210 B(2, 0) = A(j1, j0);
 211 B(0, 1) = A(j0, j2);
 212 B(1, 1) = A(j0, j1);
 213 B(2, 1) = A(j0, j0) - A(j2, j2);
214
// NOTE(review): the larfg calls that generate (v1, tau1) and (v2, tau2)
// are not shown in this listing. As in the 1x2 case, v1[0] and v2[0] are
// treated as an implicit 1.
 215 // Make B upper triangular
 216 T tau1, tau2;
 217 auto v1 = slice(B, range{0, 3}, 0);
 218 auto v2 = slice(B, range{1, 3}, 1);
 220 const T sum = B(0, 1) + v1[1] * B(1, 1) + v1[2] * B(2, 1);
 221 B(0, 1) = B(0, 1) - sum * tau1;
 222 B(1, 1) = B(1, 1) - sum * tau1 * v1[1];
 223 B(2, 1) = B(2, 1) - sum * tau1 * v1[2];
225
 226 //
 227 // Apply reflections to A and Q
 228 //
229
 230 // Reflections from the left
 231 for (idx_t j = j0; j < n; ++j) {
 232 T sum = A(j2, j) + v1[1] * A(j1, j) + v1[2] * A(j0, j);
 233 A(j2, j) = A(j2, j) - sum * tau1;
 234 A(j1, j) = A(j1, j) - sum * tau1 * v1[1];
 235 A(j0, j) = A(j0, j) - sum * tau1 * v1[2];
236
 237 sum = A(j1, j) + v2[1] * A(j0, j);
 238 A(j1, j) = A(j1, j) - sum * tau2;
 239 A(j0, j) = A(j0, j) - sum * tau2 * v2[1];
 240 }
 241 // Reflections from the right
 242 for (idx_t j = 0; j < j3; ++j) {
 243 T sum = A(j, j2) + v1[1] * A(j, j1) + v1[2] * A(j, j0);
 244 A(j, j2) = A(j, j2) - sum * tau1;
 245 A(j, j1) = A(j, j1) - sum * tau1 * v1[1];
 246 A(j, j0) = A(j, j0) - sum * tau1 * v1[2];
247
 248 sum = A(j, j1) + v2[1] * A(j, j0);
 249 A(j, j1) = A(j, j1) - sum * tau2;
 250 A(j, j0) = A(j, j0) - sum * tau2 * v2[1];
 251 }
252
 253 if (want_q) {
 254 for (idx_t j = 0; j < n; ++j) {
 255 T sum = Q(j, j2) + v1[1] * Q(j, j1) + v1[2] * Q(j, j0);
 256 Q(j, j2) = Q(j, j2) - sum * tau1;
 257 Q(j, j1) = Q(j, j1) - sum * tau1 * v1[1];
 258 Q(j, j0) = Q(j, j0) - sum * tau1 * v1[2];
259
 260 sum = Q(j, j1) + v2[1] * Q(j, j0);
 261 Q(j, j1) = Q(j, j1) - sum * tau2;
 262 Q(j, j0) = Q(j, j0) - sum * tau2 * v2[1];
 263 }
 264 }
265
// The 1x1 block now sits at j0; the entries below it in column j0 are
// forced to exact zero.
 266 A(j1, j0) = zero;
 267 A(j2, j0) = zero;
 268 }
 269 if (n1 == 2 and n2 == 2) {
// Work on a 4x4 copy D of A(j0:j0+3, j0:j0+3) so that the swap can be
// tested before A is touched.
// NOTE(review): the copy itself (a lacpy call, per the reference summary)
// is not shown in this listing.
 270 T D_[4 * 4];
 271 auto D = new_4by4_matrix(D_);
272
 273 auto AD_slice = slice(A, range{j0, j0 + 4}, range{j0, j0 + 4});
 275 auto dnorm = lange(MAX_NORM, D);
276
// Acceptance threshold: max(10 * eps * max|D(i,j)|, safe_min / eps).
 277 const T eps = ulp<T>();
 278 const T small_num = safe_min<T>() / eps;
 279 T thresh = max(ten * eps * dnorm, small_num);
 280 // Note: max() may not propagate NaNs.
281
// Solve the Sylvester equation TL * X - X * TR = scale * B (isign = -1)
// for X, which is stored in V(0:1, 0:1). The int returned by lasy2 is
// not checked here.
 282 T V_[4 * 2];
 283 auto V = new_4by2_matrix(V_);
 284 auto X = slice(V, range{0, 2}, range{0, 2});
 285 auto TL = slice(D, range{0, 2}, range{0, 2});
 286 auto TR = slice(D, range{2, 4}, range{2, 4});
 287 auto B = slice(D, range{0, 2}, range{2, 4});
 288 T scale, xnorm;
 289 lasy2(NO_TRANS, NO_TRANS, -1, TL, TR, B, scale, X, xnorm);
290
// Complete V = [ X ; -scale * I ] (4x2).
 291 V(2, 0) = -scale;
 292 V(2, 1) = zero;
 293 V(3, 0) = zero;
 294 V(3, 1) = -scale;
295
// Two Householder reflectors: H1 acts on rows 0..3 (v1 = column 0 of V),
// H2 acts on rows 1..3 (v2 = V(1:3, 1)); v1[0] and v2[0] are treated as
// an implicit 1.
// NOTE(review): the larfg calls that generate (v1, tau1) and (v2, tau2)
// are not shown in this listing.
 296 // Make V upper triangular
 297 T tau1, tau2;
 298 auto v1 = slice(V, range{0, 4}, 0);
 299 auto v2 = slice(V, range{1, 4}, 1);
 301 const T sum =
 302 V(0, 1) + v1[1] * V(1, 1) + v1[2] * V(2, 1) + v1[3] * V(3, 1);
 303 V(0, 1) = V(0, 1) - sum * tau1;
 304 V(1, 1) = V(1, 1) - sum * tau1 * v1[1];
 305 V(2, 1) = V(2, 1) - sum * tau1 * v1[2];
 306 V(3, 1) = V(3, 1) - sum * tau1 * v1[3];
308
// Trial run on the copy: transform D from the left, then from the right.
 309 // Apply reflections to D to check error
 310 for (idx_t j = 0; j < 4; ++j) {
 311 T sum =
 312 D(0, j) + v1[1] * D(1, j) + v1[2] * D(2, j) + v1[3] * D(3, j);
 313 D(0, j) = D(0, j) - sum * tau1;
 314 D(1, j) = D(1, j) - sum * tau1 * v1[1];
 315 D(2, j) = D(2, j) - sum * tau1 * v1[2];
 316 D(3, j) = D(3, j) - sum * tau1 * v1[3];
317
 318 sum = D(1, j) + v2[1] * D(2, j) + v2[2] * D(3, j);
 319 D(1, j) = D(1, j) - sum * tau2;
 320 D(2, j) = D(2, j) - sum * tau2 * v2[1];
 321 D(3, j) = D(3, j) - sum * tau2 * v2[2];
 322 }
 323 for (idx_t j = 0; j < 4; ++j) {
 324 T sum =
 325 D(j, 0) + v1[1] * D(j, 1) + v1[2] * D(j, 2) + v1[3] * D(j, 3);
 326 D(j, 0) = D(j, 0) - sum * tau1;
 327 D(j, 1) = D(j, 1) - sum * tau1 * v1[1];
 328 D(j, 2) = D(j, 2) - sum * tau1 * v1[2];
 329 D(j, 3) = D(j, 3) - sum * tau1 * v1[3];
330
 331 sum = D(j, 1) + v2[1] * D(j, 2) + v2[2] * D(j, 3);
 332 D(j, 1) = D(j, 1) - sum * tau2;
 333 D(j, 2) = D(j, 2) - sum * tau2 * v2[1];
 334 D(j, 3) = D(j, 3) - sum * tau2 * v2[2];
 335 }
336
// Reject the swap if any entry of the lower-left 2x2 block of the
// transformed D exceeds thresh. Nothing in A or Q has been written yet.
 337 if (max(max(abs(D(2, 0)), abs(D(2, 1))),
 338 max(abs(D(3, 0)), abs(D(3, 1)))) > thresh)
 339 return 1;
340
// Swap accepted: apply H1 then H2 to rows j0..j3 of A (columns j0..n-1).
 341 // Reflections from the left
 342 for (idx_t j = j0; j < n; ++j) {
 343 T sum = A(j0, j) + v1[1] * A(j1, j) + v1[2] * A(j2, j) +
 344 v1[3] * A(j3, j);
 345 A(j0, j) = A(j0, j) - sum * tau1;
 346 A(j1, j) = A(j1, j) - sum * tau1 * v1[1];
 347 A(j2, j) = A(j2, j) - sum * tau1 * v1[2];
 348 A(j3, j) = A(j3, j) - sum * tau1 * v1[3];
349
 350 sum = A(j1, j) + v2[1] * A(j2, j) + v2[2] * A(j3, j);
 351 A(j1, j) = A(j1, j) - sum * tau2;
 352 A(j2, j) = A(j2, j) - sum * tau2 * v2[1];
 353 A(j3, j) = A(j3, j) - sum * tau2 * v2[2];
 354 }
// ...and to columns j0..j3 of A (rows 0..j0+3).
 355 // Reflections from the right
 356 for (idx_t j = 0; j < j0 + 4; ++j) {
 357 T sum = A(j, j0) + v1[1] * A(j, j1) + v1[2] * A(j, j2) +
 358 v1[3] * A(j, j3);
 359 A(j, j0) = A(j, j0) - sum * tau1;
 360 A(j, j1) = A(j, j1) - sum * tau1 * v1[1];
 361 A(j, j2) = A(j, j2) - sum * tau1 * v1[2];
 362 A(j, j3) = A(j, j3) - sum * tau1 * v1[3];
363
 364 sum = A(j, j1) + v2[1] * A(j, j2) + v2[2] * A(j, j3);
 365 A(j, j1) = A(j, j1) - sum * tau2;
 366 A(j, j2) = A(j, j2) - sum * tau2 * v2[1];
 367 A(j, j3) = A(j, j3) - sum * tau2 * v2[2];
 368 }
369
// Accumulate H1 and H2 into columns j0..j3 of Q.
 370 if (want_q) {
 371 for (idx_t j = 0; j < n; ++j) {
 372 T sum = Q(j, j0) + v1[1] * Q(j, j1) + v1[2] * Q(j, j2) +
 373 v1[3] * Q(j, j3);
 374 Q(j, j0) = Q(j, j0) - sum * tau1;
 375 Q(j, j1) = Q(j, j1) - sum * tau1 * v1[1];
 376 Q(j, j2) = Q(j, j2) - sum * tau1 * v1[2];
 377 Q(j, j3) = Q(j, j3) - sum * tau1 * v1[3];
378
 379 sum = Q(j, j1) + v2[1] * Q(j, j2) + v2[2] * Q(j, j3);
 380 Q(j, j1) = Q(j, j1) - sum * tau2;
 381 Q(j, j2) = Q(j, j2) - sum * tau2 * v2[1];
 382 Q(j, j3) = Q(j, j3) - sum * tau2 * v2[2];
 383 }
 384 }
385
// Force the lower-left 2x2 block of the swapped pair to exact zero.
 386 A(j2, j0) = zero;
 387 A(j2, j1) = zero;
 388 A(j3, j0) = zero;
 389 A(j3, j1) = zero;
 390 }
391
// After the swap, the second block (size n2) starts at j0 and the first
// block (size n1) starts at j0 + n2. For each 2x2 block, lahqr_schur22
// recomputes the block in place and returns a rotation (cs, sn), which is
// then applied to the rest of the affected rows/columns of A and to Q.
// NOTE(review): the declaration of s1, s2 and the first lines of the
// second lahqr_schur22 call are not shown in this listing.
 392 // Standardize the 2x2 Schur blocks (if any)
 393 if (n2 == 2) {
 394 T cs, sn;
 396 lahqr_schur22(A(j0, j0), A(j0, j1), A(j1, j0), A(j1, j1), s1, s2, cs,
 397 sn); // Apply transformation from the left
 398 if (j2 < n) {
 399 auto row1 = slice(A, j0, range{j2, n});
 400 auto row2 = slice(A, j1, range{j2, n});
 401 rot(row1, row2, cs, sn);
 402 }
 403 // Apply transformation from the right
 404 if (j0 > 0) {
 405 auto col1 = slice(A, range{0, j0}, j0);
 406 auto col2 = slice(A, range{0, j0}, j1);
 407 rot(col1, col2, cs, sn);
 408 }
 409 if (want_q) {
 410 auto row1 = col(Q, j0);
 411 auto row2 = col(Q, j1);
 412 rot(row1, row2, cs, sn);
 413 }
 414 }
 415 if (n1 == 2) {
 416 idx_t j0_2 = j0 + n2;
 417 idx_t j1_2 = j0_2 + 1;
418
 419 T cs, sn;
 422 A(j1_2, j1_2), s1, s2, cs,
 423 sn); // Apply transformation from the left
 424 if (j0_2 + 2 < n) {
 425 auto row1 = slice(A, j0_2, range{j0_2 + 2, n});
 426 auto row2 = slice(A, j1_2, range{j0_2 + 2, n});
 427 rot(row1, row2, cs, sn);
 428 }
 429 // Apply transformation from the right
 430 if (j0_2 > 0) {
 431 auto col1 = slice(A, range{0, j0_2}, j0_2);
 432 auto col2 = slice(A, range{0, j0_2}, j1_2);
 433 rot(col1, col2, cs, sn);
 434 }
 435 if (want_q) {
 436 auto row1 = col(Q, j0_2);
 437 auto row2 = col(Q, j1_2);
 438 rot(row1, row2, cs, sn);
 439 }
 440 }
441
 442 return 0;
443}
444
451template <TLAPACK_CSMATRIX matrix_t,
452 enable_if_t<is_complex<type_t<matrix_t>>, bool> = true>
453int schur_swap(bool want_q,
454 matrix_t& A,
455 matrix_t& Q,
456 const size_type<matrix_t>& j0,
457 const size_type<matrix_t>& n1,
458 const size_type<matrix_t>& n2)
459{
460 using idx_t = size_type<matrix_t>;
461 using T = type_t<matrix_t>;
462 using real_t = real_type<T>;
463 using range = pair<idx_t, idx_t>;
464
465 const idx_t n = ncols(A);
466
467 tlapack_check(nrows(A) == n);
468 tlapack_check(nrows(Q) == n);
469 tlapack_check(ncols(Q) == n);
470 tlapack_check(0 <= j0 and j0 < n);
471 tlapack_check(n1 == 1);
472 tlapack_check(n2 == 1);
473
474 const idx_t j1 = j0 + 1;
475 const idx_t j2 = j0 + 2;
476
477 //
478 // In the complex case, there can only be 1x1 blocks to swap
479 //
480 const T t00 = A(j0, j0);
481 const T t11 = A(j1, j1);
482 //
483 // Determine the transformation to perform the interchange
484 //
485 real_t cs;
486 T sn;
487 T temp = A(j0, j1);
488 T temp2 = t11 - t00;
489 rotg(temp, temp2, cs, sn);
490
491 A(j1, j1) = t00;
492 A(j0, j0) = t11;
493
494 // Apply transformation from the left
495 if (j2 < n) {
496 auto row1 = slice(A, j0, range{j2, n});
497 auto row2 = slice(A, j1, range{j2, n});
498 rot(row1, row2, cs, sn);
499 }
500 // Apply transformation from the right
501 if (j0 > 0) {
502 auto col1 = slice(A, range{0, j0}, j0);
503 auto col2 = slice(A, range{0, j0}, j1);
504 rot(col1, col2, cs, conj(sn));
505 }
506 if (want_q) {
507 auto row1 = col(Q, j0);
508 auto row2 = col(Q, j1);
509 rot(row1, row2, cs, conj(sn));
510 }
511
512 return 0;
513}
514
515} // namespace tlapack
516
517#endif // TLAPACK_SCHUR_SWAP_HH
constexpr internal::MaxNorm MAX_NORM
max norm
Definition types.hpp:334
constexpr internal::Forward FORWARD
Forward direction.
Definition types.hpp:376
constexpr internal::GeneralAccess GENERAL
General access.
Definition types.hpp:175
constexpr internal::ColumnwiseStorage COLUMNWISE_STORAGE
Columnwise storage.
Definition types.hpp:409
constexpr internal::NoTranspose NO_TRANS
no transpose
Definition types.hpp:255
constexpr T conj(const T &x) noexcept
Extends std::conj() to real datatypes.
Definition utils.hpp:100
#define TLAPACK_CSMATRIX
Macro for tlapack::concepts::ConstructableAndSliceableMatrix compatible with C++17.
Definition concepts.hpp:961
auto lange(norm_t normType, const matrix_t &A)
Calculates the norm of a matrix.
Definition lange.hpp:38
int schur_swap(bool want_q, matrix_t &A, matrix_t &Q, const size_type< matrix_t > &j0, const size_type< matrix_t > &n1, const size_type< matrix_t > &n2)
schur_swap, swaps 2 eigenvalues of A.
Definition schur_swap.hpp:49
void larfg(storage_t storeMode, type_t< vector_t > &alpha, vector_t &x, type_t< vector_t > &tau)
Generates an elementary Householder reflection.
Definition larfg.hpp:73
void lahqr_schur22(T &a, T &b, T &c, T &d, complex_type< T > &s1, complex_type< T > &s2, T &cs, T &sn)
Computes the Schur factorization of a 2x2 matrix A.
Definition lahqr_schur22.hpp:44
void lacpy(uplo_t uplo, const matrixA_t &A, matrixB_t &B)
Copies a matrix from A to B.
Definition lacpy.hpp:38
int lasy2(Op trans_l, Op trans_r, int isign, const matrixT_t &TL, const matrixT_t &TR, const matrixB_t &B, type_t< matrixX_t > &scale, matrixX_t &X, type_t< matrixX_t > &xnorm)
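lasy2 solves the Sylvester matrix equation $\operatorname{op}(T_L)\, X + \mathrm{isign}\cdot X \operatorname{op}(T_R) = \mathrm{scale}\cdot B$ for $X$, where the matrices are of order 1 or 2.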
Definition lasy2.hpp:42
void rotg(T &a, T &b, T &c, T &s)
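Construct plane rotation that eliminates b, such that: $\begin{bmatrix} r \\ 0 \end{bmatrix} = \begin{bmatrix} c & s \\ -\bar{s} & c \end{bmatrix} \begin{bmatrix} a \\ b \end{bmatrix}$ (with $\bar{s} = s$ for real types).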
Definition rotg.hpp:39
void rot(vectorX_t &x, vectorY_t &y, const c_type &c, const s_type &s)
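Apply plane rotation: $\begin{bmatrix} x_i \\ y_i \end{bmatrix} \leftarrow \begin{bmatrix} c & s \\ -\bar{s} & c \end{bmatrix} \begin{bmatrix} x_i \\ y_i \end{bmatrix}$ for every index $i$ (with $\bar{s} = s$ for real $s$).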
Definition rot.hpp:44
#define tlapack_check(cond)
Throw an error if cond is false.
Definition exceptionHandling.hpp:98
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113