<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
generalized_schur_swap.hpp
Go to the documentation of this file.
1
5//
6// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
7//
8// This file is part of <T>LAPACK.
9// <T>LAPACK is free software: you can redistribute it and/or modify it under
10// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
11
12#ifndef TLAPACK_GENERALIZED_SCHUR_SWAP_HH
13#define TLAPACK_GENERALIZED_SCHUR_SWAP_HH
14
16#include "tlapack/blas/rot.hpp"
17#include "tlapack/blas/rotg.hpp"
18#include "tlapack/blas/swap.hpp"
19#include "tlapack/blas/trsv.hpp"
26
27namespace tlapack {
28
// NOTE(review): this Doxygen listing dropped several source lines of this
// definition (see the gaps in the embedded line numbers), including the
// return-type/name line and the last three parameters. Per the declaration
// shown in this page's glossary, the full signature is
//   int generalized_schur_swap(bool want_q, bool want_z, matrix_t& A,
//                              matrix_t& B, matrix_t& Q, matrix_t& Z,
//                              const size_type<matrix_t>& j0,
//                              const size_type<matrix_t>& n1,
//                              const size_type<matrix_t>& n2)
//
// Real-arithmetic overload. Swaps two adjacent diagonal blocks of the
// generalized Schur pencil (A, B): a leading block of size n1 and a trailing
// block of size n2 (each 1 or 2), starting at row/column j0. Transformations
// are applied to A and B from the left (accumulated into columns of Q) and
// from the right (accumulated into columns of Z).
//
// Parameters:
//   want_q, want_z  Only consulted in the 1x1-with-1x1 path (see the
//                   review notes below).
//   A, B            Pencil, modified in place. A must be square and
//                   j0 + n1 + n2 <= ncols(A) (both checked).
//   Q               Must be n-by-n (checked).
//   Z               Its size is not checked in this function.
//   j0              First row/column of the leading block.
//   n1, n2          Sizes of the leading and trailing blocks; 1 or 2 each.
//
// Returns 0 on success. Returns 1 only from the 2x2-with-2x2 path: either
// getrf reports a nonzero code for the 8x8 system, or the weak stability
// test fails. A, B, Q and Z are not modified before either of those returns.
56template <TLAPACK_CSMATRIX matrix_t,
57 enable_if_t<is_real<type_t<matrix_t>>, bool> = true>
59 bool want_z,
60 matrix_t& A,
61 matrix_t& B,
62 matrix_t& Q,
63 matrix_t& Z,
67{
68 using idx_t = size_type<matrix_t>;
69 using T = type_t<matrix_t>;
    // NOTE(review): `range` and `new_matrix`, both used below, are declared
    // on lines that are missing from this listing.
71
72 // Functor for creating new matrices
74
75 const idx_t n = ncols(A);
76 const T zero(0);
77
78 tlapack_check(nrows(A) == n);
79 tlapack_check(nrows(Q) == n);
80 tlapack_check(ncols(Q) == n);
81 tlapack_check(0 <= j0);
82 tlapack_check(j0 + n1 + n2 <= n);
83 tlapack_check(n1 == 1 or n1 == 2);
84 tlapack_check(n2 == 1 or n2 == 2);
85
    // j1, j2, j3 name the rows/columns following j0. j3 is indexed directly
    // only in the 2x2-with-2x2 path; elsewhere it serves as an exclusive bound.
86 const idx_t j1 = j0 + 1;
87 const idx_t j2 = j0 + 2;
88 const idx_t j3 = j0 + 3;
89
90 // Check if the 2x2 eigenvalue blocks consist of 2 1x1 blocks
91 // If so, treat them separately
92 if (n1 == 2)
93 if (A(j1, j0) == zero) {
94 // only 2x2 swaps can fail, so we don't need to check for error
    // Move the lower 1x1 piece (at j1) past the trailing block first, then
    // the upper piece (at j0).
95 generalized_schur_swap(want_q, want_z, A, B, Q, Z, j1, (idx_t)1,
96 n2);
97 generalized_schur_swap(want_q, want_z, A, B, Q, Z, j0, (idx_t)1,
98 n2);
99 return 0;
100 }
101 if (n2 == 2)
102 if (A(j0 + n1 + 1, j0 + n1) == zero) {
103 // only 2x2 swaps can fail, so we don't need to check for error
    // NOTE(review): the two recursive calls that split the trailing block
    // into 1x1 pieces sit on lines missing from this listing; only their
    // final arguments are visible below.
105 (idx_t)1);
107 (idx_t)1);
108 return 0;
109 }
110
111 if (n1 == 1 and n2 == 1) {
112 //
113 // Swap two 1-by-1 blocks.
114 //
115 const T a00 = A(j0, j0);
116 const T a01 = A(j0, j1);
117 const T a11 = A(j1, j1);
118 const T b00 = B(j0, j0);
119 const T b01 = B(j0, j1);
120 const T b11 = B(j1, j1);
121
    // Selects whether the left rotation below is built from B or from A.
122 const bool use_b = abs(b11 * a00) > abs(b00 * a11);
123
124 //
125 // Determine the transformation to perform the interchange
126 //
127 T cl, sl, cr, sr;
    // temp / temp2 are row j0 of (b11*A - a11*B) in columns j0 / j1. The
    // right rotation built from them zeroes that row's column-j0 entry.
128 T temp = b11 * a00 - a11 * b00;
129 T temp2 = b11 * a01 - a11 * b01;
130 rotg(temp2, temp, cr, sr);
131
132 // Apply transformation from the right
133 {
134 auto a1 = slice(A, range{0, j1 + 1}, j0);
135 auto a2 = slice(A, range{0, j1 + 1}, j1);
136 rot(a2, a1, cr, sr);
137 auto b1 = slice(B, range{0, j1 + 1}, j0);
138 auto b2 = slice(B, range{0, j1 + 1}, j1);
139 rot(b2, b1, cr, sr);
140 if (want_z) {
141 auto z1 = col(Z, j0);
142 auto z2 = col(Z, j1);
143 rot(z2, z1, cr, sr);
144 }
145 }
146
    // Left rotation: eliminate entry (j1, j0) of B (if use_b) or of A.
147 if (use_b) {
148 temp = B(j0, j0);
149 temp2 = B(j1, j0);
150 }
151 else {
152 temp = A(j0, j0);
153 temp2 = A(j1, j0);
154 }
155 rotg(temp, temp2, cl, sl);
156
157 // Apply transformation from the left
158 {
159 auto a1 = slice(A, j0, range{j0, n});
160 auto a2 = slice(A, j1, range{j0, n});
161 rot(a1, a2, cl, sl);
162 auto b1 = slice(B, j0, range{j0, n});
163 auto b2 = slice(B, j1, range{j0, n});
164 rot(b1, b2, cl, sl);
165 if (want_q) {
166 auto q1 = col(Q, j0);
167 auto q2 = col(Q, j1);
168 rot(q1, q2, cl, sl);
169 }
170 }
171
    // Both subdiagonal entries are zero in exact arithmetic; set explicitly.
172 A(j1, j0) = (T)0;
173 B(j1, j0) = (T)0;
174 }
175 if (n1 == 1 and n2 == 2) {
176 //
177 // Swap 1-by-1 block with 2-by-2 block
178 //
179
180 std::vector<T> H_;
181 auto H = new_matrix(H_, 2, 3);
182 std::vector<T> v(3);
183
    // NOTE(review): the declaration of alpha1/alpha2 and the call that fills
    // alpha1 and beta1 (presumably lahqz_eig22 on (A1, B1), per this page's
    // glossary) are on lines missing from this listing.
185 T beta1, beta2;
186
187 auto A1 = slice(A, range(j1, j3), range(j1, j3));
188 auto B1 = slice(B, range(j1, j3), range(j1, j3));
190 auto a00 = A(j0, j0);
191 auto b00 = B(j0, j0);
192
193 bool use_b = abs(b00 * alpha1) > abs(beta1 * a00);
194
    // Rows of H are columns j1 and j2 of (b00*A - a00*B), each listed
    // bottom-up (rows j2, j1, j0). The reflector (v, tau) from inv_house3
    // acts on indices in the order j2, j1, j0; v[0] is never read and is
    // treated as 1 in the updates.
195 H(0, 0) = b00 * A(j2, j1) - a00 * B(j2, j1);
196 H(0, 1) = b00 * A(j1, j1) - a00 * B(j1, j1);
197 H(0, 2) = b00 * A(j0, j1) - a00 * B(j0, j1);
198 H(1, 0) = b00 * A(j2, j2) - a00 * B(j2, j2);
199 H(1, 1) = b00 * A(j1, j2) - a00 * B(j1, j2);
200 H(1, 2) = b00 * A(j0, j2) - a00 * B(j0, j2);
201
202 T tau;
203 inv_house3(H, v, tau);
204
205 // Apply update from the left
206 for (idx_t j = j0; j < n; ++j) {
207 T sum = A(j2, j) + v[1] * A(j1, j) + v[2] * A(j0, j);
208 A(j2, j) = A(j2, j) - sum * tau;
209 A(j1, j) = A(j1, j) - sum * tau * v[1];
210 A(j0, j) = A(j0, j) - sum * tau * v[2];
211 }
212 for (idx_t j = j0; j < n; ++j) {
213 T sum = B(j2, j) + v[1] * B(j1, j) + v[2] * B(j0, j);
214 B(j2, j) = B(j2, j) - sum * tau;
215 B(j1, j) = B(j1, j) - sum * tau * v[1];
216 B(j0, j) = B(j0, j) - sum * tau * v[2];
217 }
    // NOTE(review): Q is updated here even when want_q is false.
218 for (idx_t j = 0; j < n; ++j) {
219 T sum = Q(j, j2) + v[1] * Q(j, j1) + v[2] * Q(j, j0);
220 Q(j, j2) = Q(j, j2) - sum * tau;
221 Q(j, j1) = Q(j, j1) - sum * tau * v[1];
222 Q(j, j0) = Q(j, j0) - sum * tau * v[2];
223 }
224
    // Second reflector: built from row j2 of B (if use_b) or of A, taken
    // bottom-up (columns j2, j1, j0), and applied from the right.
225 if (use_b) {
226 v[0] = B(j2, j2);
227 v[1] = B(j2, j1);
228 v[2] = B(j2, j0);
229 }
230 else {
231 v[0] = A(j2, j2);
232 v[1] = A(j2, j1);
233 v[2] = A(j2, j0);
234 }
235
    // NOTE(review): the call that turns v into a reflector (v, tau) is on a
    // line missing from this listing (presumably larfg, per the glossary).
237
238 // Apply update from the right
239 for (idx_t j = 0; j < j3; ++j) {
240 T sum = A(j, j2) + v[1] * A(j, j1) + v[2] * A(j, j0);
241 A(j, j2) = A(j, j2) - sum * tau;
242 A(j, j1) = A(j, j1) - sum * tau * v[1];
243 A(j, j0) = A(j, j0) - sum * tau * v[2];
244 }
245 for (idx_t j = 0; j < j3; ++j) {
246 T sum = B(j, j2) + v[1] * B(j, j1) + v[2] * B(j, j0);
247 B(j, j2) = B(j, j2) - sum * tau;
248 B(j, j1) = B(j, j1) - sum * tau * v[1];
249 B(j, j0) = B(j, j0) - sum * tau * v[2];
250 }
    // NOTE(review): Z is updated here even when want_z is false.
251 for (idx_t j = 0; j < n; ++j) {
252 T sum = Z(j, j2) + v[1] * Z(j, j1) + v[2] * Z(j, j0);
253 Z(j, j2) = Z(j, j2) - sum * tau;
254 Z(j, j1) = Z(j, j1) - sum * tau * v[1];
255 Z(j, j0) = Z(j, j0) - sum * tau * v[2];
256 }
257
    // The 1x1 block now sits at j2; clear the entries left of it in row j2.
258 A(j2, j0) = (T)0;
259 A(j2, j1) = (T)0;
260 B(j2, j0) = (T)0;
261 B(j2, j1) = (T)0;
262 }
263 if (n1 == 2 and n2 == 1) {
264 //
265 // Swap 2-by-2 block with 1-by-1 block
266 //
267 std::vector<T> H_;
268 auto H = new_matrix(H_, 2, 3);
269 std::vector<T> v(3);
270
    // NOTE(review): as in the previous branch, alpha1 and the call that
    // fills alpha1/beta1 are on lines missing from this listing.
272 T beta1, beta2;
273
274 auto A1 = slice(A, range(j0, j2), range(j0, j2));
275 auto B1 = slice(B, range(j0, j2), range(j0, j2));
277 auto a22 = A(j2, j2);
278 auto b22 = B(j2, j2);
279
280 bool use_b = abs(b22 * alpha1) > abs(beta1 * a22);
281
    // Rows of H are rows j0 and j1 of (b22*A - a22*B) over columns j0..j2.
    // The reflector acts on indices in the order j0, j1, j2; v[0] is treated
    // as 1.
282 H(0, 0) = b22 * A(j0, j0) - a22 * B(j0, j0);
283 H(0, 1) = b22 * A(j0, j1) - a22 * B(j0, j1);
284 H(0, 2) = b22 * A(j0, j2) - a22 * B(j0, j2);
285 H(1, 0) = b22 * A(j1, j0) - a22 * B(j1, j0);
286 H(1, 1) = b22 * A(j1, j1) - a22 * B(j1, j1);
287 H(1, 2) = b22 * A(j1, j2) - a22 * B(j1, j2);
288
289 T tau;
290 inv_house3(H, v, tau);
291
292 // Apply update from the right
293 for (idx_t j = 0; j < j3; ++j) {
294 T sum = A(j, j0) + v[1] * A(j, j1) + v[2] * A(j, j2);
295 A(j, j0) = A(j, j0) - sum * tau;
296 A(j, j1) = A(j, j1) - sum * tau * v[1];
297 A(j, j2) = A(j, j2) - sum * tau * v[2];
298 }
299 for (idx_t j = 0; j < j3; ++j) {
300 T sum = B(j, j0) + v[1] * B(j, j1) + v[2] * B(j, j2);
301 B(j, j0) = B(j, j0) - sum * tau;
302 B(j, j1) = B(j, j1) - sum * tau * v[1];
303 B(j, j2) = B(j, j2) - sum * tau * v[2];
304 }
    // NOTE(review): Z is updated here even when want_z is false.
305 for (idx_t j = 0; j < n; ++j) {
306 T sum = Z(j, j0) + v[1] * Z(j, j1) + v[2] * Z(j, j2);
307 Z(j, j0) = Z(j, j0) - sum * tau;
308 Z(j, j1) = Z(j, j1) - sum * tau * v[1];
309 Z(j, j2) = Z(j, j2) - sum * tau * v[2];
310 }
311
    // Second reflector: built from column j0 of B (if use_b) or of A and
    // applied from the left.
312 if (use_b) {
313 v[0] = B(j0, j0);
314 v[1] = B(j1, j0);
315 v[2] = B(j2, j0);
316 }
317 else {
318 v[0] = A(j0, j0);
319 v[1] = A(j1, j0);
320 v[2] = A(j2, j0);
321 }
    // NOTE(review): the call that turns v into a reflector (v, tau) is on a
    // line missing from this listing.
323
324 // Apply update from the left
325 for (idx_t j = j0; j < n; ++j) {
326 T sum = A(j0, j) + v[1] * A(j1, j) + v[2] * A(j2, j);
327 A(j0, j) = A(j0, j) - sum * tau;
328 A(j1, j) = A(j1, j) - sum * tau * v[1];
329 A(j2, j) = A(j2, j) - sum * tau * v[2];
330 }
331 for (idx_t j = j0; j < n; ++j) {
332 T sum = B(j0, j) + v[1] * B(j1, j) + v[2] * B(j2, j);
333 B(j0, j) = B(j0, j) - sum * tau;
334 B(j1, j) = B(j1, j) - sum * tau * v[1];
335 B(j2, j) = B(j2, j) - sum * tau * v[2];
336 }
    // NOTE(review): Q is updated here even when want_q is false.
337 for (idx_t j = 0; j < n; ++j) {
338 T sum = Q(j, j0) + v[1] * Q(j, j1) + v[2] * Q(j, j2);
339 Q(j, j0) = Q(j, j0) - sum * tau;
340 Q(j, j1) = Q(j, j1) - sum * tau * v[1];
341 Q(j, j2) = Q(j, j2) - sum * tau * v[2];
342 }
343
    // The 1x1 block now sits at j0; clear the entries below it in column j0.
344 A(j1, j0) = (T)0;
345 A(j2, j0) = (T)0;
346 B(j1, j0) = (T)0;
347 B(j2, j0) = (T)0;
348 }
349 if (n1 == 2 and n2 == 2) {
350 //
351 // Swap 2-by-2 block with 2-by-2 block
352 //
353 std::vector<T> M_;
354 auto M = new_matrix(M_, 8, 8);
355 std::vector<T> x(8);
356 std::vector<idx_t> piv(8);
357
358 for (idx_t j = 0; j < 8; ++j)
359 for (idx_t i = 0; i < 8; ++i)
360 M(i, j) = (T)0;
361
    // M * x = rhs is the vectorized coupled Sylvester-type system
    //   A00*R - L*A11 = A01,   B00*R - L*B11 = B01
    // for the diagonal blocks A00 = A(j0:j1, j0:j1), A11 = A(j2:j3, j2:j3)
    // (same for B). The entry placement below implies that x[0..3] holds R
    // column-major and x[4..7] holds L row-major.
362 // Construct matrix with kronecker structure
363 // I (x) A00
364 M(0, 0) = A(j0, j0);
365 M(0, 1) = A(j0, j1);
366 M(1, 0) = A(j1, j0);
367 M(1, 1) = A(j1, j1);
368 M(2, 2) = A(j0, j0);
369 M(2, 3) = A(j0, j1);
370 M(3, 2) = A(j1, j0);
371 M(3, 3) = A(j1, j1);
372 // I (x) B00
373 M(4, 0) = B(j0, j0);
374 M(4, 1) = B(j0, j1);
375 M(5, 0) = B(j1, j0);
376 M(5, 1) = B(j1, j1);
377 M(6, 2) = B(j0, j0);
378 M(6, 3) = B(j0, j1);
379 M(7, 2) = B(j1, j0);
380 M(7, 3) = B(j1, j1);
381 // A11T (x) I
382 M(0, 4) = -A(j2, j2);
383 M(0, 5) = -A(j3, j2);
384 M(1, 6) = -A(j2, j2);
385 M(1, 7) = -A(j3, j2);
386 M(2, 4) = -A(j2, j3);
387 M(2, 5) = -A(j3, j3);
388 M(3, 6) = -A(j2, j3);
389 M(3, 7) = -A(j3, j3);
390 // B11T (x) I
391 M(4, 4) = -B(j2, j2);
392 M(4, 5) = -B(j3, j2);
393 M(5, 6) = -B(j2, j2);
394 M(5, 7) = -B(j3, j2);
395 M(6, 4) = -B(j2, j3);
396 M(6, 5) = -B(j3, j3);
397 M(7, 6) = -B(j2, j3);
398 M(7, 7) = -B(j3, j3);
399 // RHS
400 x[0] = A(j0, j2);
401 x[1] = A(j1, j2);
402 x[2] = A(j0, j3);
403 x[3] = A(j1, j3);
404 x[4] = B(j0, j2);
405 x[5] = B(j1, j2);
406 x[6] = B(j0, j3);
407 x[7] = B(j1, j3);
408 // LU of M
    // A nonzero getrf code rejects the swap before A, B, Q, Z are touched.
409 int ierr = getrf(M, piv);
410 if (ierr != 0) return 1;
411 // Apply pivot to rhs
412 for (idx_t i = 0; i < 8; ++i) {
413 if (i != piv[i]) std::swap(x[i], x[piv[i]]);
414 }
    // NOTE(review): the two triangular solves (presumably trsv with the unit
    // lower and the upper factor of M) are on lines missing from this listing.
415 // Solve Ly = rhs
417 // Solve Ux = y
419
420 // Find Zc so that
421 // [ -x[0] -x[2] ] [ * * ]
422 // Zc^T [ -x[1] -x[3] ] = [ * * ]
423 // [ 1 0 ] [ 0 0 ]
424 // [ 0 1 ] [ 0 0 ]
425
426 // Rotation to make X upper triangular
427 T cxl1, sxl1;
428 rotg(x[0], x[1], cxl1, sxl1);
429 x[1] = (T)0;
430 T rottemp = cxl1 * x[2] + sxl1 * x[3];
431 x[3] = -sxl1 * x[2] + cxl1 * x[3];
432 x[2] = rottemp;
433 // SVD of (upper triangular) X
434 T cxl2, sxl2, cxr, sxr, ssx1, ssx2;
435 svd22(x[0], x[2], x[3], ssx2, ssx1, cxl2, sxl2, cxr, sxr);
436 // Fuse left rotations
437 T cxl, sxl;
438 cxl = cxl1 * cxl2 - sxl1 * sxl2;
439 sxl = cxl2 * sxl1 + sxl2 * cxl1;
440 // Rotations based on the singular values
441 ssx1 = -ssx1;
442 ssx2 = -ssx2;
443 T temp = (T)1;
444 T cx1, sx1, cx2, sx2;
445 rotg(ssx1, temp, cx1, sx1);
446 temp = (T)1;
447 rotg(ssx2, temp, cx2, sx2);
448
449 // Find Qc so that
450 // [ 1 0 ] [ 0 0 ]
451 // Qc^T [ 0 1 ] = [ 0 0 ]
452 // [ x[4] x[6] ] [ * * ]
453 // [ x[5] x[7] ] [ * * ]
454
455 // Rotation to make Y^T upper triangular
456 T cyl1, syl1;
457 rotg(x[4], x[5], cyl1, syl1);
458 x[5] = (T)0;
459 rottemp = cyl1 * x[6] + syl1 * x[7];
460 x[7] = -syl1 * x[6] + cyl1 * x[7];
461 x[6] = rottemp;
462 // SVD of (upper triangular) Y
463 T cyl2, syl2, cyr, syr, ssy1, ssy2;
464 svd22(x[4], x[6], x[7], ssy2, ssy1, cyl2, syl2, cyr, syr);
465 // Fuse left rotations
466 T cyl, syl;
467 cyl = cyl1 * cyl2 - syl1 * syl2;
468 syl = cyl2 * syl1 + syl2 * cyl1;
469 // Rotations based on the singular values
470 temp = (T)1;
471 T cy1, sy1, cy2, sy2;
472 rotg(ssy1, temp, cy1, sy1);
473 temp = (T)1;
474 rotg(ssy2, temp, cy2, sy2);
475
476 // Perform the swap on a local matrix and check the error
477 std::vector<T> AA_;
478 auto AA = new_matrix(AA_, 4, 4);
479 std::vector<T> BB_;
480 auto BB = new_matrix(BB_, 4, 4);
481
    // AA / BB are copies of the 4x4 pencil A(j0:j3, j0:j3), B(j0:j3, j0:j3).
482 lacpy(GENERAL, slice(A, range(j0, j3 + 1), range(j0, j3 + 1)), AA);
483 lacpy(GENERAL, slice(B, range(j0, j3 + 1), range(j0, j3 + 1)), BB);
484
485 auto norma = lange(FROB_NORM, AA);
486 auto normb = lange(FROB_NORM, BB);
487
488 // Apply rotations from the left to local matrices
489 {
490 auto a0 = row(AA, 0);
491 auto a1 = row(AA, 1);
492 auto a2 = row(AA, 2);
493 auto a3 = row(AA, 3);
494 rot(a0, a1, cyr, syr);
495 rot(a2, a3, cyl, syl);
496 rot(a2, a0, cy1, sy1);
497 rot(a3, a1, cy2, sy2);
498
499 auto b0 = row(BB, 0);
500 auto b1 = row(BB, 1);
501 auto b2 = row(BB, 2);
502 auto b3 = row(BB, 3);
503 rot(b0, b1, cyr, syr);
504 rot(b2, b3, cyl, syl);
505 rot(b2, b0, cy1, sy1);
506 rot(b3, b1, cy2, sy2);
507 }
508 // Apply rotations from the right to local matrices
509 {
510 auto a0 = col(AA, 0);
511 auto a1 = col(AA, 1);
512 auto a2 = col(AA, 2);
513 auto a3 = col(AA, 3);
514 rot(a0, a1, cxl, sxl);
515 rot(a2, a3, cxr, sxr);
516 rot(a0, a2, cx1, sx1);
517 rot(a1, a3, cx2, sx2);
518
519 auto b0 = col(BB, 0);
520 auto b1 = col(BB, 1);
521 auto b2 = col(BB, 2);
522 auto b3 = col(BB, 3);
523 rot(b0, b1, cxl, sxl);
524 rot(b2, b3, cxr, sxr);
525 rot(b0, b2, cx1, sx1);
526 rot(b1, b3, cx2, sx2);
527 }
528
    // Reject the swap if the transformed lower-left 2x2 block of either local
    // copy exceeds max(20 * eps * ||original copy||_F, safe_min). A, B, Q, Z
    // have not been modified yet at this point.
529 // Weak stability test
530 auto enorma = lange(FROB_NORM, slice(AA, range(2, 4), range(0, 2)));
531 auto enormb = lange(FROB_NORM, slice(BB, range(2, 4), range(0, 2)));
532 const T eps = ulp<T>();
533 const T small_num = safe_min<T>();
534 if (enorma > max((T)20 * norma * eps, small_num)) return 1;
535 if (enormb > max((T)20 * normb * eps, small_num)) return 1;
536
537 // TODO: strong stability test
538
    // NOTE(review): Q (below) and Z (in the next scope) are rotated even when
    // want_q / want_z is false.
539 // Apply rotations from the left
540 {
541 auto a0 = slice(A, j0, range(j0, n));
542 auto a1 = slice(A, j1, range(j0, n));
543 auto a2 = slice(A, j2, range(j0, n));
544 auto a3 = slice(A, j3, range(j0, n));
545 rot(a0, a1, cyr, syr);
546 rot(a2, a3, cyl, syl);
547 rot(a2, a0, cy1, sy1);
548 rot(a3, a1, cy2, sy2);
549
550 auto b0 = slice(B, j0, range(j0, n));
551 auto b1 = slice(B, j1, range(j0, n));
552 auto b2 = slice(B, j2, range(j0, n));
553 auto b3 = slice(B, j3, range(j0, n));
554 rot(b0, b1, cyr, syr);
555 rot(b2, b3, cyl, syl);
556 rot(b2, b0, cy1, sy1);
557 rot(b3, b1, cy2, sy2);
558
559 auto q0 = col(Q, j0);
560 auto q1 = col(Q, j1);
561 auto q2 = col(Q, j2);
562 auto q3 = col(Q, j3);
563 rot(q0, q1, cyr, syr);
564 rot(q2, q3, cyl, syl);
565 rot(q2, q0, cy1, sy1);
566 rot(q3, q1, cy2, sy2);
567 }
568
569 // Apply rotations from the right
570 {
571 auto a0 = slice(A, range(0, j3 + 1), j0);
572 auto a1 = slice(A, range(0, j3 + 1), j1);
573 auto a2 = slice(A, range(0, j3 + 1), j2);
574 auto a3 = slice(A, range(0, j3 + 1), j3);
575 rot(a0, a1, cxl, sxl);
576 rot(a2, a3, cxr, sxr);
577 rot(a0, a2, cx1, sx1);
578 rot(a1, a3, cx2, sx2);
579
580 auto b0 = slice(B, range(0, j3 + 1), j0);
581 auto b1 = slice(B, range(0, j3 + 1), j1);
582 auto b2 = slice(B, range(0, j3 + 1), j2);
583 auto b3 = slice(B, range(0, j3 + 1), j3);
584 rot(b0, b1, cxl, sxl);
585 rot(b2, b3, cxr, sxr);
586 rot(b0, b2, cx1, sx1);
587 rot(b1, b3, cx2, sx2);
588
589 auto z0 = col(Z, j0);
590 auto z1 = col(Z, j1);
591 auto z2 = col(Z, j2);
592 auto z3 = col(Z, j3);
593 rot(z0, z1, cxl, sxl);
594 rot(z2, z3, cxr, sxr);
595 rot(z0, z2, cx1, sx1);
596 rot(z1, z3, cx2, sx2);
597 }
598
599 // Set relevant parts to zero
600 A(j2, j0) = (T)0;
601 A(j3, j0) = (T)0;
602 A(j2, j1) = (T)0;
603 A(j3, j1) = (T)0;
604 B(j2, j0) = (T)0;
605 B(j3, j0) = (T)0;
606 B(j2, j1) = (T)0;
607 B(j3, j1) = (T)0;
608 }
609
    // After the swap, a former trailing 2x2 block sits at j0 (n2 == 2) and a
    // former leading 2x2 block sits at j0 + n2 (n1 == 2). Each is
    // re-standardized by rotating the B block to diagonal form
    // diag(ssmax, ssmin), with the sign flipped when ssmax < 0.
    // NOTE(review): Q and Z are rotated here regardless of want_q / want_z.
610 // Standardize the 2x2 Schur blocks (if any)
611 if (n2 == 2) {
612 // Make B upper triangular
613 T cl1, sl1;
614 rotg(B(j0, j0), B(j1, j0), cl1, sl1);
615 B(j1, j0) = (T)0;
616 {
617 auto b1 = slice(B, j0, range(j1, j2));
618 auto b2 = slice(B, j1, range(j1, j2));
619 rot(b1, b2, cl1, sl1);
620 }
621 // Standard form
622 T ssmin, ssmax, cl2, sl2, cr, sr;
623 svd22(B(j0, j0), B(j0, j1), B(j1, j1), ssmin, ssmax, cl2, sl2, cr, sr);
624 if (ssmax < (T)0) {
625 cr = -cr;
626 sr = -sr;
627 ssmin = -ssmin;
628 ssmax = -ssmax;
629 }
630 B(j0, j0) = ssmax;
631 B(j1, j1) = ssmin;
632 B(j0, j1) = (T)0;
    // The fused rotation is applied to the parts of A, B and Q that have not
    // yet been rotated by (cl1, sl1): all of A's rows, and B only to the
    // right of the block.
633 // Fuse left rotations
634 T cl, sl;
635 cl = cl1 * cl2 - sl1 * sl2;
636 sl = cl2 * sl1 + sl2 * cl1;
637 // Apply left rotation
638 {
639 auto a1 = slice(A, j0, range(j0, n));
640 auto a2 = slice(A, j1, range(j0, n));
641 rot(a1, a2, cl, sl);
642 auto b1 = slice(B, j0, range(j2, n));
643 auto b2 = slice(B, j1, range(j2, n));
644 rot(b1, b2, cl, sl);
645 auto q0 = col(Q, j0);
646 auto q1 = col(Q, j1);
647 rot(q0, q1, cl, sl);
648 }
649 // Apply right rotation
650 {
651 auto a1 = slice(A, range(0, j2), j0);
652 auto a2 = slice(A, range(0, j2), j1);
653 rot(a1, a2, cr, sr);
654 auto b1 = slice(B, range(0, j0), j0);
655 auto b2 = slice(B, range(0, j0), j1);
656 rot(b1, b2, cr, sr);
657 auto z0 = col(Z, j0);
658 auto z1 = col(Z, j1);
659 rot(z0, z1, cr, sr);
660 }
661 }
662 if (n1 == 2) {
663 // Make B upper triangular
664 T cl1, sl1;
665 rotg(B(j0 + n2, j0 + n2), B(j1 + n2, j0 + n2), cl1, sl1);
666 B(j1 + n2, j0 + n2) = (T)0;
667 {
668 auto b1 = slice(B, j0 + n2, range(j1 + n2, j2 + n2));
669 auto b2 = slice(B, j1 + n2, range(j1 + n2, j2 + n2));
670 rot(b1, b2, cl1, sl1);
671 }
672 // Standard form
673 T ssmin, ssmax, cl2, sl2, cr, sr;
674 svd22(B(j0 + n2, j0 + n2), B(j0 + n2, j1 + n2), B(j1 + n2, j1 + n2),
675 ssmin, ssmax, cl2, sl2, cr, sr);
676 if (ssmax < (T)0) {
677 cr = -cr;
678 sr = -sr;
679 ssmin = -ssmin;
680 ssmax = -ssmax;
681 }
682 B(j0 + n2, j0 + n2) = ssmax;
683 B(j1 + n2, j1 + n2) = ssmin;
684 B(j0 + n2, j1 + n2) = (T)0;
685 // Fuse left rotations
686 T cl, sl;
687 cl = cl1 * cl2 - sl1 * sl2;
688 sl = cl2 * sl1 + sl2 * cl1;
689 // Apply left rotation
690 {
691 auto a1 = slice(A, j0 + n2, range(j0 + n2, n));
692 auto a2 = slice(A, j1 + n2, range(j0 + n2, n));
693 rot(a1, a2, cl, sl);
694 auto b1 = slice(B, j0 + n2, range(j2 + n2, n));
695 auto b2 = slice(B, j1 + n2, range(j2 + n2, n));
696 rot(b1, b2, cl, sl);
697 auto q0 = col(Q, j0 + n2);
698 auto q1 = col(Q, j1 + n2);
699 rot(q0, q1, cl, sl);
700 }
701 // Apply right rotation
702 {
703 auto a1 = slice(A, range(0, j2 + n2), j0 + n2);
704 auto a2 = slice(A, range(0, j2 + n2), j1 + n2);
705 rot(a1, a2, cr, sr);
706 auto b1 = slice(B, range(0, j0 + n2), j0 + n2);
707 auto b2 = slice(B, range(0, j0 + n2), j1 + n2);
708 rot(b1, b2, cr, sr);
709 auto z0 = col(Z, j0 + n2);
710 auto z1 = col(Z, j1 + n2);
711 rot(z0, z1, cr, sr);
712 }
713 }
714
715 return 0;
716}
717
724template <TLAPACK_CSMATRIX matrix_t,
725 enable_if_t<is_complex<type_t<matrix_t>>, bool> = true>
726int generalized_schur_swap(bool want_q,
727 bool want_z,
728 matrix_t& A,
729 matrix_t& B,
730 matrix_t& Q,
731 matrix_t& Z,
732 const size_type<matrix_t>& j0,
733 const size_type<matrix_t>& n1,
734 const size_type<matrix_t>& n2)
735{
736 using idx_t = size_type<matrix_t>;
737 using T = type_t<matrix_t>;
738 using real_t = real_type<T>;
739 using range = pair<idx_t, idx_t>;
740
741 const idx_t n = ncols(A);
742
743 tlapack_check(nrows(A) == n);
744 tlapack_check(nrows(Q) == n);
745 tlapack_check(ncols(Q) == n);
746 tlapack_check(0 <= j0 and j0 < n);
747 tlapack_check(n1 == 1);
748 tlapack_check(n2 == 1);
749
750 const idx_t j1 = j0 + 1;
751
752 //
753 // In the complex case, there can only be 1x1 blocks to swap
754 //
755 const T a00 = A(j0, j0);
756 const T a01 = A(j0, j1);
757 const T a11 = A(j1, j1);
758 const T b00 = B(j0, j0);
759 const T b01 = B(j0, j1);
760 const T b11 = B(j1, j1);
761
762 const bool use_b = abs(b11 * a00) > abs(b00 * a11);
763
764 //
765 // Determine the transformation to perform the interchange
766 //
767 real_t cl, cr;
768 T sl, sr;
769 T temp = b11 * a00 - a11 * b00;
770 T temp2 = b11 * a01 - a11 * b01;
771 rotg(temp2, temp, cr, sr);
772
773 // Apply transformation from the right
774 {
775 auto a1 = slice(A, range{0, j1 + 1}, j0);
776 auto a2 = slice(A, range{0, j1 + 1}, j1);
777 rot(a2, a1, cr, sr);
778 auto b1 = slice(B, range{0, j1 + 1}, j0);
779 auto b2 = slice(B, range{0, j1 + 1}, j1);
780 rot(b2, b1, cr, sr);
781 if (want_z) {
782 auto z1 = col(Z, j0);
783 auto z2 = col(Z, j1);
784 rot(z2, z1, cr, sr);
785 }
786 }
787
788 if (use_b) {
789 temp = B(j0, j0);
790 temp2 = B(j1, j0);
791 }
792 else {
793 temp = A(j0, j0);
794 temp2 = A(j1, j0);
795 }
796 rotg(temp, temp2, cl, sl);
797
798 // Apply transformation from the left
799 {
800 auto a1 = slice(A, j0, range{j0, n});
801 auto a2 = slice(A, j1, range{j0, n});
802 rot(a1, a2, cl, sl);
803 auto b1 = slice(B, j0, range{j0, n});
804 auto b2 = slice(B, j1, range{j0, n});
805 rot(b1, b2, cl, sl);
806 if (want_q) {
807 auto q1 = col(Q, j0);
808 auto q2 = col(Q, j1);
809 rot(q1, q2, cl, conj(sl));
810 }
811 }
812
813 A(j1, j0) = (T)0;
814 B(j1, j0) = (T)0;
815
816 return 0;
817}
818
819} // namespace tlapack
820
821#endif // TLAPACK_GENERALIZED_SCHUR_SWAP_HH
constexpr internal::FrobNorm FROB_NORM
Frobenius norm of matrices.
Definition types.hpp:342
constexpr internal::LowerTriangle LOWER_TRIANGLE
Lower Triangle access.
Definition types.hpp:183
constexpr internal::UpperTriangle UPPER_TRIANGLE
Upper Triangle access.
Definition types.hpp:181
constexpr internal::Forward FORWARD
Forward direction.
Definition types.hpp:376
constexpr internal::UnitDiagonal UNIT_DIAG
The main diagonal is assumed to consist of 1's.
Definition types.hpp:217
constexpr internal::GeneralAccess GENERAL
General access.
Definition types.hpp:175
constexpr internal::NonUnitDiagonal NON_UNIT_DIAG
The main diagonal is not assumed to consist of 1's.
Definition types.hpp:215
constexpr internal::ColumnwiseStorage COLUMNWISE_STORAGE
Columnwise storage.
Definition types.hpp:409
constexpr internal::NoTranspose NO_TRANS
no transpose
Definition types.hpp:255
constexpr T conj(const T &x) noexcept
Extends std::conj() to real datatypes.
Definition utils.hpp:100
#define TLAPACK_CSMATRIX
Macro for tlapack::concepts::ConstructableAndSliceableMatrix compatible with C++17.
Definition concepts.hpp:961
int generalized_schur_swap(bool want_q, bool want_z, matrix_t &A, matrix_t &B, matrix_t &Q, matrix_t &Z, const size_type< matrix_t > &j0, const size_type< matrix_t > &n1, const size_type< matrix_t > &n2)
generalized_schur_swap swaps two adjacent diagonal blocks (eigenvalues) of the generalized Schur pencil (A, B).
Definition generalized_schur_swap.hpp:58
auto lange(norm_t normType, const matrix_t &A)
Calculates the norm of a matrix.
Definition lange.hpp:38
void svd22(const T &f, const T &g, const T &h, T &ssmin, T &ssmax, T &csl, T &snl, T &csr, T &snr)
Computes the singular value decomposition of a 2-by-2 real upper triangular matrix $\begin{bmatrix} f & g \\ 0 & h \end{bmatrix}$.
Definition svd22.hpp:55
void larfg(storage_t storeMode, type_t< vector_t > &alpha, vector_t &x, type_t< vector_t > &tau)
Generates an elementary Householder reflection.
Definition larfg.hpp:73
void lacpy(uplo_t uplo, const matrixA_t &A, matrixB_t &B)
Copies a matrix from A to B.
Definition lacpy.hpp:38
void inv_house3(const matrix_t &A, vector_t &v, type_t< vector_t > &tau)
inv_house3 calculates a reflector that, applied from the right, reduces the first column of a 2x3 matrix A to zero.
Definition inv_house3.hpp:44
void rotg(T &a, T &b, T &c, T &s)
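Construct plane rotation that eliminates b, such that: $\begin{bmatrix} r \\ 0 \end{bmatrix} = \begin{bmatrix} c & s \\ -\bar{s} & c \end{bmatrix} \begin{bmatrix} a \\ b \end{bmatrix}$ (for real types, $\bar{s} = s$).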
Definition rotg.hpp:39
void rot(vectorX_t &x, vectorY_t &y, const c_type &c, const s_type &s)
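Apply plane rotation: $\begin{bmatrix} x_i \\ y_i \end{bmatrix} := \begin{bmatrix} c & s \\ -\bar{s} & c \end{bmatrix} \begin{bmatrix} x_i \\ y_i \end{bmatrix}$ for every index $i$ (for real types, $\bar{s} = s$).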
Definition rot.hpp:44
void trsv(Uplo uplo, Op trans, Diag diag, const matrixA_t &A, vectorX_t &x)
Solve the triangular matrix-vector equation.
Definition trsv.hpp:64
void syr(Uplo uplo, const alpha_t &alpha, const vectorX_t &x, matrixA_t &A)
Symmetric matrix rank-1 update: $A := \alpha x x^T + A$.
Definition syr.hpp:45
#define tlapack_check(cond)
Throw an error if cond is false.
Definition exceptionHandling.hpp:98
int getrf(matrix_t &A, piv_t &piv, const GetrfOpts &opts={})
getrf computes an LU factorization of a general m-by-n matrix A.
Definition getrf.hpp:64
void lahqz_eig22(const A_t &A, const B_t &B, complex_type< T > &alpha1, complex_type< T > &alpha2, T &beta1, T &beta2)
Computes the generalized eigenvalues of a 2x2 pencil (A,B) with B upper triangular.
Definition lahqz_eig22.hpp:35
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113