<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
aggressive_early_deflation.hpp
Go to the documentation of this file.
1
3//
4// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
5//
6// This file is part of <T>LAPACK.
7// <T>LAPACK is free software: you can redistribute it and/or modify it under
8// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
9
10#ifndef TLAPACK_AED_HH
11#define TLAPACK_AED_HH
12
14#include "tlapack/blas/gemm.hpp"
28
29namespace tlapack {
30
31namespace internal {
32
/**
 * Workspace query for the gehrd_work() call made by
 * aggressive_early_deflation_work().
 *
 * NOTE(review): original lines 33-52 (doc comment), 54-57 (function name and
 * leading parameters) and 61 are missing from this listing. The
 * cross-reference summary gives the signature as
 * aggressive_early_deflation_worksize_gehrd(ilo, ihi, nw, A).
 *
 * jw is computed exactly as in aggressive_early_deflation_work(). When
 * jw == ihi - ilo, the window starts at ilo, so the spike value is zero and
 * the Hessenberg reduction there is skipped; hence no workspace is requested.
 */
53 template <class T, TLAPACK_SMATRIX matrix_t>
58 const matrix_t& A)
59 {
60 using idx_t = size_type<matrix_t>;
62
63 const idx_t n = ncols(A);
// Window size cap imposed by using the lower part of A as workspace.
64 const idx_t nw_max = (n - 3) / 3;
65 const idx_t jw = min(min(nw, ihi - ilo), nw_max);
66
67 if (jw != ihi - ilo) {
68 // Hessenberg reduction
// gehrd_work is later called with (0, ns) and ns <= jw, so querying a
// jw-by-jw view presumably gives an upper bound on the requirement.
69 auto&& TW = slice(A, range{0, jw}, range{0, jw});
70 auto&& tau = slice(A, range{0, jw}, 0);
71 return gehrd_worksize<T>(0, jw, TW, tau);
72 }
73 else
74 return WorkInfo();
75 }
76} // namespace internal
77
/**
 * Workspace query for aggressive_early_deflation_work().
 *
 * NOTE(review): original lines 119, 121-123, 132 and 139 are missing from
 * this listing. The cross-reference summary gives the full parameter list as
 * (want_t, want_z, ilo, ihi, nw, A, s, Z, ns, nd, opts). Line 139 presumably
 * declares the `workinfo` object used below; confirm against the real header.
 *
 * The result combines the needs of the two routines that receive `work` in
 * aggressive_early_deflation_work():
 *  - multishift_qr_work, only when the window size jw >= opts.nmin (smaller
 *    windows use lahqr, which takes no workspace);
 *  - gehrd_work, via internal::aggressive_early_deflation_worksize_gehrd.
 */
115template <class T,
116 TLAPACK_SMATRIX matrix_t,
117 TLAPACK_SVECTOR vector_t,
118 enable_if_t<is_complex<type_t<vector_t>>, int>>
120 bool want_z,
124 const matrix_t& A,
125 const vector_t& s,
126 const matrix_t& Z,
127 const size_type<matrix_t>& ns,
128 const size_type<matrix_t>& nd,
129 const FrancisOpts& opts)
130{
131 using idx_t = size_type<matrix_t>;
133
134 const idx_t n = ncols(A);
// Same window size as computed in aggressive_early_deflation_work().
135 const idx_t nw_max = (n - 3) / 3;
136 const idx_t jw = min(min(nw, ihi - ilo), nw_max);
137
138 // quick return
// Small matrices (n < 9), windows of size <= 1, and active blocks with at
// most one row get the workinfo as initialized on the missing line 139.
140 if (n < 9 || nw <= 1 || ihi <= 1 + ilo) return workinfo;
141
// Mirrors the lahqr / multishift_qr_work switch on opts.nmin. The slices
// start at A(0, 0); presumably only their dimensions matter for the query.
142 if (jw >= (idx_t)opts.nmin) {
143 auto&& TW = slice(A, range{0, jw}, range{0, jw});
144 auto&& s_window = slice(s, range{0, jw});
145 auto&& V = slice(A, range{0, jw}, range{0, jw});
146 workinfo =
147 multishift_qr_worksize<T>(true, true, 0, jw, TW, s_window, V, opts);
148 }
149
// minMax presumably merges both requirements so that one buffer suffices
// for both calls; confirm against WorkInfo's definition.
150 workinfo.minMax(internal::aggressive_early_deflation_worksize_gehrd<T>(
151 ilo, ihi, nw, A));
152
153 return workinfo;
154}
155
/**
 * Workspace variant of aggressive early deflation (AED) applied to the active
 * block A(ilo:ihi, ilo:ihi) of an upper Hessenberg matrix A.
 *
 * NOTE(review): this listing is missing several original lines (the line
 * numbers jump, e.g. 168, 170-172, 176-177, 179, 184, 246, 299, 383, 411,
 * 416-417, 420-421, 424, 434, 468-469, 481-482, 494-495). The
 * cross-reference summary gives the full parameter list as
 * (want_t, want_z, ilo, ihi, nw, A, s, Z, ns, nd, work, opts).
 *
 * What the visible code does:
 *  - Takes the trailing jw-by-jw window of the active block, where
 *    jw = min(nw, ihi - ilo, (n - 3) / 3), starting at row kwtop = ihi - jw.
 *  - Computes its Schur form TW = V^H * window * V with lahqr when
 *    jw < opts.nmin, and with multishift_qr_work otherwise.
 *  - Tests each 1x1 block (or 2x2 block when T is real) at the bottom for
 *    deflation using the spike entries s_spike * V(0, :). Deflatable blocks
 *    stay at the bottom; the others are moved up with schur_move.
 *  - Bubble-sorts the deflated blocks with schur_swap and writes the window
 *    eigenvalues to s[kwtop .. ihi-1].
 *  - If the spike is nonzero, reflects it back and reduces TW to Hessenberg
 *    form (gehrd_work).
 *  - Copies TW back into A and applies V to the part of A right of the window
 *    and above it (range chosen by want_t), and to Z when want_z.
 *
 * Outputs: nd = number of deflated eigenvalues; ns = number of undeflated
 * eigenvalues minus infqr, the value returned by the inner QR routine.
 *
 * Workspace: V, TW, WH and WV are views into the lower part of A, which is
 * why jw is capped at (n - 3) / 3. `work` is passed to multishift_qr_work
 * and gehrd_work.
 */
164template <TLAPACK_SMATRIX matrix_t,
165 TLAPACK_SVECTOR vector_t,
166 TLAPACK_WORKSPACE work_t,
167 enable_if_t<is_complex<type_t<vector_t>>, int>>
169 bool want_z,
173 matrix_t& A,
174 vector_t& s,
175 matrix_t& Z,
178 work_t& work,
180{
181 using T = type_t<matrix_t>;
182 using real_t = real_type<T>;
183 using idx_t = size_type<matrix_t>;
185
186 // Constants
187 const real_t one(1);
188 const real_t zero(0);
189 const idx_t n = ncols(A);
190 // Because we will use the lower triangular part of A as workspace,
191 // we have a maximum window size.
192 const idx_t nw_max = (n - 3) / 3;
193 const real_t eps = ulp<real_t>();
// Absolute floor used in the deflation tests below:
// max(small_num, eps * |diagonal|).
194 const real_t small_num = safe_min<real_t>() * ((real_t)n / eps);
195 // Size of the deflation window
196 const idx_t jw = min(min(nw, ihi - ilo), nw_max);
197 // First row index in the deflation window
198 const idx_t kwtop = ihi - jw;
199
200 // check arguments
201 tlapack_check(nrows(A) == n);
202 if (want_z) {
203 tlapack_check(ncols(Z) == n);
204 tlapack_check(nrows(Z) == n);
205 }
206 tlapack_check((idx_t)size(s) == n);
207
// NOTE(review): there is no quick return for jw == 0 (nw == 0, ihi == ilo,
// or a small n that makes nw_max == 0). In that case kwtop == ihi, and the
// read below of A(kwtop, kwtop - 1) falls outside A when ihi == n and
// ihi != ilo. LAPACK's xLAQR3 returns early for an empty window. Confirm
// that callers never pass jw == 0.
208 // s_spike is the value just outside the window. It determines the spike
209 // together with the orthogonal (unitary) Schur factors.
210 T s_spike;
211 if (kwtop == ilo)
212 s_spike = zero;
213 else
214 s_spike = A(kwtop, kwtop - 1);
215
216 if (kwtop + 1 == ihi) {
217 // 1x1 deflation window, not much to do
// Deflate when |s_spike| <= max(small_num, eps * |A(kwtop, kwtop)|);
// the subdiagonal entry is then set to zero if it lies inside the block.
218 s[kwtop] = A(kwtop, kwtop);
219 ns = 1;
220 nd = 0;
221 if (abs1(s_spike) <= max(small_num, eps * abs1(A(kwtop, kwtop)))) {
222 ns = 0;
223 nd = 1;
224 if (kwtop > ilo) A(kwtop, kwtop - 1) = zero;
225 }
226 return;
227 // Note: The max() above may not propagate a NaN in A(kwtop, kwtop).
228 }
229
230 // Define workspace matrices
231 // We use the lower triangular part of A as workspace
232 // TW and WH overlap, but WH is only used after we no longer need
233 // TW so it is ok.
// V: jw x jw accumulated window transformation.
// TW: jw x jw copy of the window, overwritten by its Schur form.
// WH: buffer for the horizontal (row-block) update.
// WV: buffer for the vertical updates; columns 0 and 1 are also reused as
// scratch vectors in the spike/Hessenberg step.
234 auto V = slice(A, range{n - jw, n}, range{0, jw});
235 auto TW = slice(A, range{n - jw, n}, range{jw, 2 * jw});
236 auto WH = slice(A, range{n - jw, n}, range{jw, n - jw - 3});
237 auto WV = slice(A, range{jw + 3, n - jw}, range{0, jw});
238
239 // Convert the window to spike-triangular form. i.e. calculate the
240 // Schur form of the deflation window.
241 // If the QR algorithm fails to converge, it can still be
242 // partially in Schur form. In that case we continue on a smaller
243 // window (note the use of infqr later in the code).
244 auto A_window = slice(A, range{kwtop, ihi}, range{kwtop, ihi});
245 auto s_window = slice(s, range{kwtop, ihi});
// Copy only the Hessenberg part (up to the first subdiagonal) into TW and
// start V from the identity.
247 for (idx_t j = 0; j < jw; ++j)
248 for (idx_t i = 0; i < min(j + 2, jw); ++i)
249 TW(i, j) = A_window(i, j);
250 laset(GENERAL, zero, one, V);
251 int infqr;
252 if (jw < (idx_t)opts.nmin)
253 infqr = lahqr(true, true, 0, jw, TW, s_window, V);
254 else {
255 infqr =
256 multishift_qr_work(true, true, 0, jw, TW, s_window, V, work, opts);
// Clear everything below the first subdiagonal of TW.
257 for (idx_t j = 0; j < jw; ++j)
258 for (idx_t i = j + 2; i < jw; ++i)
259 TW(i, j) = zero;
260 }
261
262 // Deflation detection loop
263 // one eigenvalue block at a time, we will check if it is deflatable
264 // by checking the bottom spike element. If it is not deflatable,
265 // we move the block up. This moves other blocks down to check.
// Blocks at indices [0, infqr) are not examined (ilst starts at infqr).
266 ns = jw;
267 idx_t ilst = infqr;
268 while (ilst < ns) {
// For real T, a nonzero subdiagonal entry marks a 2x2 block at the bottom.
269 bool bulge = false;
270 if (is_real<T>)
271 if (ns > 1)
272 if (TW(ns - 1, ns - 2) != zero) bulge = true;
273
274 if (!bulge) {
275 // 1x1 eigenvalue block
276 real_t foo = abs1(TW(ns - 1, ns - 1));
277 if (foo == zero) foo = abs1(s_spike);
278 if (abs1(s_spike) * abs1(V(0, ns - 1)) <=
279 max(small_num, eps * foo)) {
280 // Eigenvalue is deflatable
281 ns = ns - 1;
282 }
283 else {
284 // Eigenvalue is not deflatable.
285 // Move it up out of the way.
286 idx_t ifst = ns - 1;
287 schur_move(true, TW, V, ifst, ilst);
288 ilst = ilst + 1;
289 }
290 // Note: The max() above may not propagate a NaN in TW(ns-1, ns-1).
291 }
292 else {
293 // 2x2 eigenvalue block
294 real_t foo =
295 abs(TW(ns - 1, ns - 1)) +
296 sqrt(abs(TW(ns - 1, ns - 2))) * sqrt(abs(TW(ns - 2, ns - 1)));
297 if (foo == zero) foo = abs(s_spike);
// NOTE(review): the remainder of this condition (original line 299) is
// missing from this listing.
298 if (max(abs(s_spike * V(0, ns - 1)), abs(s_spike * V(0, ns - 2))) <=
300 // Eigenvalue pair is deflatable
301 ns = ns - 2;
302 }
303 else {
304 // Eigenvalue pair is not deflatable.
305 // Move it up out of the way.
306 idx_t ifst = ns - 2;
307 schur_move(true, TW, V, ifst, ilst);
308 ilst = ilst + 2;
309 }
310 }
311 }
312
// Everything deflated: the spike no longer couples the window to the rest.
313 if (ns == 0) s_spike = zero;
314
315 if (ns == jw) {
316 // Aggressive early deflation didn't deflate any eigenvalues
317 // We don't need to apply the update to the rest of the matrix
318 nd = jw - ns;
319 ns = ns - infqr;
320 return;
321 }
322
323 // sorting diagonal blocks of T improves accuracy for graded matrices.
324 // Bubble sort deals well with exchange failures.
// Only the deflated part TW(ns:, ns:) is sorted, largest magnitude first:
// blocks are swapped whenever ev1 <= ev2.
325 bool sorted = false;
326 // Window to be checked (other eigenvalues are sorted)
327 idx_t sorting_window_size = jw;
328 while (!sorted) {
329 sorted = true;
330
331 // Index of last eigenvalue that was swapped
332 idx_t ilst = 0;
333
334 // Index of the first block
335 idx_t i1 = ns;
336
337 while (i1 + 1 < sorting_window_size) {
338 // Size of the first block
339 idx_t n1 = 1;
340 if (is_real<T>)
341 if (TW(i1 + 1, i1) != zero) n1 = 2;
342
343 // Check if there is a next block
344 if (i1 + n1 == jw) {
345 ilst = ilst - n1;
346 break;
347 }
348
349 // Index of the second block
350 idx_t i2 = i1 + n1;
351
352 // Size of the second block
353 idx_t n2 = 1;
354 if (is_real<T>)
355 if (i2 + 1 < jw)
356 if (TW(i2 + 1, i2) != zero) n2 = 2;
357
// Magnitude estimates of the two blocks' eigenvalues.
358 real_t ev1, ev2;
359 if (n1 == 1)
360 ev1 = abs1(TW(i1, i1));
361 else
362 ev1 = abs(TW(i1, i1)) +
363 sqrt(abs(TW(i1 + 1, i1))) * sqrt(abs(TW(i1, i1 + 1)));
364 if (n2 == 1)
365 ev2 = abs1(TW(i2, i2));
366 else
367 ev2 = abs(TW(i2, i2)) +
368 sqrt(abs(TW(i2 + 1, i2))) * sqrt(abs(TW(i2, i2 + 1)));
369
370 if (ev1 > ev2) {
371 i1 = i2;
372 }
373 else {
374 sorted = false;
// On a failed swap (ierr != 0), skip past the first block instead.
375 int ierr = schur_swap(true, TW, V, i1, n1, n2);
376 if (ierr == 0)
377 i1 = i1 + n2;
378 else
379 i1 = i2;
380 ilst = i1;
381 }
382 }
// NOTE(review): original line 383 is missing from this listing. Given the
// ilst bookkeeping above, it presumably shrinks sorting_window_size; confirm
// against the real header.
384 }
385
386 // Recalculate the eigenvalues
387 idx_t i = 0;
388 while (i < jw) {
389 idx_t n1 = 1;
390 if (is_real<T>)
391 if (i + 1 < jw)
392 if (TW(i + 1, i) != zero) n1 = 2;
393
394 if (n1 == 1)
395 s[kwtop + i] = TW(i, i);
396 else
397 lahqr_eig22(TW(i, i), TW(i, i + 1), TW(i + 1, i), TW(i + 1, i + 1),
398 s[kwtop + i], s[kwtop + i + 1]);
399 i = i + n1;
400 }
401
402 // Reduce A back to Hessenberg form (if necessary)
403 if (s_spike != zero) {
404 // Reflect spike back
// v = conj(first row of V) over the undeflated part, stored in WV(:, 0).
// NOTE(review): original lines 411, 416-417, 420-421 and 424 are missing
// from this listing. From the cross-references, they are presumably the
// larfg call building the reflector (tau) from v, and larf_work calls
// applying it to TW_slice, TW_slice2 and V_slice with Wv_aux as scratch.
405 {
406 T tau;
407 auto v = slice(WV, range{0, ns}, 0);
408 for (idx_t i = 0; i < ns; ++i) {
409 v[i] = conj(V(0, i));
410 }
412
413 auto Wv_aux = slice(WV, range{0, jw}, range{1, 2});
414
415 auto TW_slice = slice(TW, range{0, ns}, range{0, jw});
418
419 auto TW_slice2 = slice(TW, range{0, jw}, range{0, ns});
422
423 auto V_slice = slice(V, range{0, jw}, range{0, ns});
425 Wv_aux);
426 }
427
428 // Hessenberg reduction
// NOTE(review): original line 434 is missing. It presumably applies the
// gehrd reflectors to V (unmhr_work is in the cross-references) using work2.
429 {
430 auto tau = slice(WV, range{0, jw}, 0);
431 gehrd_work(0, ns, TW, tau, work);
432
433 auto work2 = slice(WV, range{0, jw}, range{1, 2});
435 }
436 }
437
438 // Copy the deflation window back into place
// The new subdiagonal entry to the left of the window is s_spike * conj(V(0, 0)).
439 if (kwtop > 0) A(kwtop, kwtop - 1) = s_spike * conj(V(0, 0));
440 for (idx_t j = 0; j < jw; ++j)
441 for (idx_t i = 0; i < min(j + 2, jw); ++i)
442 A(kwtop + i, kwtop + j) = TW(i, j);
443
444 // Store number of deflated eigenvalues
445 nd = jw - ns;
446 ns = ns - infqr;
447
448 //
449 // Update rest of the matrix using matrix matrix multiplication
450 //
// With want_t the whole of A outside the window is updated; otherwise only
// the active block [ilo, ihi).
451 idx_t istart_m, istop_m;
452 if (want_t) {
453 istart_m = 0;
454 istop_m = n;
455 }
456 else {
457 istart_m = ilo;
458 istop_m = ihi;
459 }
460 // Horizontal multiply
// Columns right of the window, in chunks of at most ncols(WH).
// NOTE(review): original lines 468-469 are missing; presumably a gemm with
// V into WH_slice followed by a copy back into A_slice (gemm and lacpy are in
// the cross-references).
461 if (ihi < istop_m) {
462 idx_t i = ihi;
463 while (i < istop_m) {
464 idx_t iblock = std::min<idx_t>(istop_m - i, ncols(WH));
465 auto A_slice = slice(A, range{kwtop, ihi}, range{i, i + iblock});
466 auto WH_slice =
467 slice(WH, range{0, nrows(A_slice)}, range{0, ncols(A_slice)});
470 i = i + iblock;
471 }
472 }
473 // Vertical multiply
// Rows above the window, in chunks of at most nrows(WV).
// NOTE(review): original lines 481-482 (presumably gemm + copy) are missing.
474 if (istart_m < kwtop) {
475 idx_t i = istart_m;
476 while (i < kwtop) {
477 idx_t iblock = std::min<idx_t>(kwtop - i, nrows(WV));
478 auto A_slice = slice(A, range{i, i + iblock}, range{kwtop, ihi});
479 auto WV_slice =
480 slice(WV, range{0, nrows(A_slice)}, range{0, ncols(A_slice)});
483 i = i + iblock;
484 }
485 }
486 // Update Z (also a vertical multiplication)
// All n rows of Z, columns [kwtop, ihi).
// NOTE(review): original lines 494-495 (presumably gemm + copy) are missing.
487 if (want_z) {
488 idx_t i = 0;
489 while (i < n) {
490 idx_t iblock = std::min<idx_t>(n - i, nrows(WV));
491 auto Z_slice = slice(Z, range{i, i + iblock}, range{kwtop, ihi});
492 auto WV_slice =
493 slice(WV, range{0, nrows(Z_slice)}, range{0, ncols(Z_slice)});
496 i = i + iblock;
497 }
498 }
499}
500
515template <TLAPACK_MATRIX matrix_t,
516 TLAPACK_VECTOR vector_t,
517 TLAPACK_MATRIX work_t,
518 enable_if_t<is_complex<type_t<vector_t>>, int> = 0>
519void aggressive_early_deflation_work(bool want_t,
520 bool want_z,
521 size_type<matrix_t> ilo,
522 size_type<matrix_t> ihi,
523 size_type<matrix_t> nw,
524 matrix_t& A,
525 vector_t& s,
526 matrix_t& Z,
527 size_type<matrix_t>& ns,
528 size_type<matrix_t>& nd,
529 work_t& work)
530{
531 FrancisOpts opts = {};
532 aggressive_early_deflation_work(want_t, want_z, ilo, ihi, nw, A, s, Z, ns,
533 nd, work, opts);
534}
535
/**
 * Allocating variant of aggressive early deflation. It handles the 1x1 window
 * directly, otherwise queries the workspace size, allocates it in a
 * std::vector<T>, and forwards `work` and `opts` to the computational routine.
 *
 * NOTE(review): original lines 596, 598-600, 604-606, 613, 650 and 655 are
 * missing from this listing. The cross-reference summary gives the signature
 * as aggressive_early_deflation(want_t, want_z, ilo, ihi, nw, A, s, Z, ns, nd,
 * FrancisOpts& opts). Line 650 presumably begins the
 * aggressive_early_deflation_worksize<T>(...) query that initializes
 * `workinfo`, and line 655 presumably begins the call to
 * aggressive_early_deflation_work(...). Confirm against the real header.
 *
 * NOTE(review): unlike aggressive_early_deflation_work(), this function does
 * not run the tlapack_check argument checks before its 1x1 early return.
 */
593template <TLAPACK_MATRIX matrix_t,
594 TLAPACK_VECTOR vector_t,
595 enable_if_t<is_complex<type_t<vector_t>>, int> = 0>
597 bool want_z,
601 matrix_t& A,
602 vector_t& s,
603 matrix_t& Z,
607{
608 using T = type_t<matrix_t>;
609 using real_t = real_type<T>;
610 using idx_t = size_type<matrix_t>;
611
// NOTE(review): the content under this heading (original line 613) is
// missing from this listing.
612 // Functors
614
615 // Constants
616 const real_t zero(0);
617 const idx_t n = ncols(A);
618 // Because we will use the lower triangular part of A as workspace,
619 // we have a maximum window size.
620 const idx_t nw_max = (n - 3) / 3;
621 const real_t eps = ulp<real_t>();
622 const real_t small_num = safe_min<real_t>() * ((real_t)n / eps);
623 // Size of the deflation window
624 const idx_t jw = min(min(nw, ihi - ilo), nw_max);
625 // First row index in the deflation window
626 const idx_t kwtop = ihi - jw;
627
628 // s_spike is the value just outside the window. It determines the spike
629 // together with the orthogonal (unitary) Schur factors.
630 T s_spike;
631 if (kwtop == ilo)
632 s_spike = zero;
633 else
634 s_spike = A(kwtop, kwtop - 1);
635
// The 1x1 case is resolved here, before any workspace query or allocation.
// The logic matches the 1x1 branch of aggressive_early_deflation_work().
636 if (kwtop + 1 == ihi) {
637 // 1x1 deflation window, not much to do
638 s[kwtop] = A(kwtop, kwtop);
639 ns = 1;
640 nd = 0;
641 if (abs1(s_spike) <= max(small_num, eps * abs1(A(kwtop, kwtop)))) {
642 ns = 0;
643 nd = 1;
644 if (kwtop > ilo) A(kwtop, kwtop - 1) = zero;
645 }
646 return;
647 }
648
649 // Allocates workspace
// work_ owns the storage; new_matrix presumably sizes it and returns a
// workinfo.m-by-workinfo.n view of it.
651 want_t, want_z, ilo, ihi, nw, A, s, Z, ns, nd, opts);
652 std::vector<T> work_;
653 auto work = new_matrix(work_, workinfo.m, workinfo.n);
654
656 nd, work, opts);
657}
658
672template <TLAPACK_MATRIX matrix_t,
673 TLAPACK_VECTOR vector_t,
674 enable_if_t<is_complex<type_t<vector_t>>, int> = 0>
675void aggressive_early_deflation(bool want_t,
676 bool want_z,
677 size_type<matrix_t> ilo,
678 size_type<matrix_t> ihi,
679 size_type<matrix_t> nw,
680 matrix_t& A,
681 vector_t& s,
682 matrix_t& Z,
683 size_type<matrix_t>& ns,
684 size_type<matrix_t>& nd)
685{
686 FrancisOpts opts = {};
687 aggressive_early_deflation(want_t, want_z, ilo, ihi, nw, A, s, Z, ns, nd,
688 opts);
689}
690
691} // namespace tlapack
692
693#endif // TLAPACK_AED_HH
constexpr internal::LowerTriangle LOWER_TRIANGLE
Lower Triangle access.
Definition types.hpp:183
constexpr internal::RightSide RIGHT_SIDE
right side
Definition types.hpp:291
constexpr internal::Forward FORWARD
Forward direction.
Definition types.hpp:376
constexpr internal::GeneralAccess GENERAL
General access.
Definition types.hpp:175
constexpr internal::ConjTranspose CONJ_TRANS
conjugate transpose
Definition types.hpp:259
constexpr internal::ColumnwiseStorage COLUMNWISE_STORAGE
Columnwise storage.
Definition types.hpp:409
constexpr internal::NoTranspose NO_TRANS
no transpose
Definition types.hpp:255
constexpr internal::LeftSide LEFT_SIDE
left side
Definition types.hpp:289
constexpr T conj(const T &x) noexcept
Extends std::conj() to real datatypes.
Definition utils.hpp:100
constexpr real_type< T > abs1(const T &x)
1-norm absolute value, |Re(x)| + |Im(x)|
Definition utils.hpp:133
#define TLAPACK_SVECTOR
Macro for tlapack::concepts::SliceableVector compatible with C++17.
Definition concepts.hpp:909
#define TLAPACK_SMATRIX
Macro for tlapack::concepts::SliceableMatrix compatible with C++17.
Definition concepts.hpp:899
#define TLAPACK_WORKSPACE
Macro for tlapack::concepts::Workspace compatible with C++17.
Definition concepts.hpp:912
#define TLAPACK_VECTOR
Macro for tlapack::concepts::Vector compatible with C++17.
Definition concepts.hpp:906
#define TLAPACK_MATRIX
Macro for tlapack::concepts::Matrix compatible with C++17.
Definition concepts.hpp:896
void aggressive_early_deflation(bool want_t, bool want_z, size_type< matrix_t > ilo, size_type< matrix_t > ihi, size_type< matrix_t > nw, matrix_t &A, vector_t &s, matrix_t &Z, size_type< matrix_t > &ns, size_type< matrix_t > &nd, FrancisOpts &opts)
aggressive_early_deflation accepts as input an upper Hessenberg matrix H and performs an orthogonal s...
Definition aggressive_early_deflation.hpp:596
void lahqr_eig22(T a00, T a01, T a10, T a11, complex_type< T > &s1, complex_type< T > &s2)
Computes the eigenvalues of a 2x2 matrix A.
Definition lahqr_eig22.hpp:34
void laset(uplo_t uplo, const type_t< matrix_t > &alpha, const type_t< matrix_t > &beta, matrix_t &A)
Initializes a matrix to diagonal and off-diagonal values.
Definition laset.hpp:38
int lahqr(bool want_t, bool want_z, size_type< matrix_t > ilo, size_type< matrix_t > ihi, matrix_t &A, vector_t &w, matrix_t &Z)
lahqr computes the eigenvalues and optionally the Schur factorization of an upper Hessenberg matrix,...
Definition lahqr.hpp:73
int schur_swap(bool want_q, matrix_t &A, matrix_t &Q, const size_type< matrix_t > &j0, const size_type< matrix_t > &n1, const size_type< matrix_t > &n2)
schur_swap, swaps 2 eigenvalues of A.
Definition schur_swap.hpp:49
void larfg(storage_t storeMode, type_t< vector_t > &alpha, vector_t &x, type_t< vector_t > &tau)
Generates an elementary Householder reflection.
Definition larfg.hpp:73
int schur_move(bool want_q, matrix_t &A, matrix_t &Q, size_type< matrix_t > &ifst, size_type< matrix_t > &ilst)
schur_move reorders the Schur factorization of a matrix S = Q*A*Q**H, so that the diagonal element of...
Definition schur_move.hpp:47
void lacpy(uplo_t uplo, const matrixA_t &A, matrixB_t &B)
Copies a matrix from A to B.
Definition lacpy.hpp:38
void larf_work(side_t side, storage_t storeMode, vector_t const &x, const tau_t &tau, vectorC0_t &C0, matrixC1_t &C1, work_t &work)
Applies an elementary reflector defined by tau and v to an m-by-n matrix C decomposed into C0 and C1....
Definition larf.hpp:48
void gemm(Op transA, Op transB, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
General matrix-matrix multiply:
Definition gemm.hpp:61
int gehrd_work(size_type< matrix_t > ilo, size_type< matrix_t > ihi, matrix_t &A, vector_t &tau, work_t &work, const GehrdOpts &opts={})
Reduces a general square matrix to upper Hessenberg form. Workspace is provided as an argument.
Definition gehrd.hpp:99
void aggressive_early_deflation_work(bool want_t, bool want_z, size_type< matrix_t > ilo, size_type< matrix_t > ihi, size_type< matrix_t > nw, matrix_t &A, vector_t &s, matrix_t &Z, size_type< matrix_t > &ns, size_type< matrix_t > &nd, work_t &work, FrancisOpts &opts)
aggressive_early_deflation accepts as input an upper Hessenberg matrix H and performs an orthogonal s...
Definition aggressive_early_deflation.hpp:168
int multishift_qr_work(bool want_t, bool want_z, size_type< matrix_t > ilo, size_type< matrix_t > ihi, matrix_t &A, vector_t &w, matrix_t &Z, work_t &work, FrancisOpts &opts)
multishift_qr computes the eigenvalues and optionally the Schur factorization of an upper Hessenberg ...
Definition multishift_qr.hpp:114
int unmhr_work(Side side, Op trans, size_type< matrix_t > ilo, size_type< matrix_t > ihi, const matrix_t &A, const vector_t &tau, matrix_t &C, work_t &work)
Applies unitary matrix Q to a matrix C. Workspace is provided as an argument.
Definition unmhr.hpp:83
#define tlapack_check(cond)
Throw an error if cond is false.
Definition exceptionHandling.hpp:98
constexpr WorkInfo aggressive_early_deflation_worksize_gehrd(size_type< matrix_t > ilo, size_type< matrix_t > ihi, size_type< matrix_t > nw, const matrix_t &A)
Workspace query for gehrd() in aggressive_early_deflation().
Definition aggressive_early_deflation.hpp:54
WorkInfo aggressive_early_deflation_worksize(bool want_t, bool want_z, size_type< matrix_t > ilo, size_type< matrix_t > ihi, size_type< matrix_t > nw, const matrix_t &A, const vector_t &s, const matrix_t &Z, const size_type< matrix_t > &ns, const size_type< matrix_t > &nd, const FrancisOpts &opts)
Workspace query of aggressive_early_deflation().
Definition aggressive_early_deflation.hpp:119
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
Options struct for multishift_qr().
Definition FrancisOpts.hpp:23
Output information in the workspace query.
Definition workspace.hpp:16