<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
aggressive_early_deflation_generalized.hpp
Go to the documentation of this file.
1
3//
4// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
5//
6// This file is part of <T>LAPACK.
7// <T>LAPACK is free software: you can redistribute it and/or modify it under
8// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
9
10#ifndef TLAPACK_AED_GENERALIZED_HH
11#define TLAPACK_AED_GENERALIZED_HH
12
14#include "tlapack/blas/gemm.hpp"
24
25namespace tlapack {
26
// aggressive_early_deflation_generalized: aggressive early deflation (AED)
// step for a Hessenberg-triangular pencil (A, B).
//
// NOTE(review): this text is a Doxygen source-browser extraction. Each code
// line is prefixed with its original source line number, the documentation
// block (source lines 27-96) is stripped, and several code lines are absent:
//   - 100      (line naming the function and declaring `want_s`)
//   - 103-105  (the `ilo`, `ihi`, `nw` parameters)
//   - 112-114  (the `ns`, `nd`, `opts` parameters)
//   - 119      (presumably declares the `range` type used below)
//   - 192-193, 265
//   - the two statements inside each matrix-update loop
//     (391-392, 403-404, 416-417, 428-429, 441-442, 454-455)
// Restore these from the original source before compiling.
//
// Full signature as listed by Doxygen:
//   void aggressive_early_deflation_generalized(
//       bool want_s, bool want_q, bool want_z,
//       size_type<matrix_t> ilo, size_type<matrix_t> ihi,
//       size_type<matrix_t> nw, matrix_t& A, matrix_t& B,
//       alpha_t& alpha, beta_t& beta, matrix_t& Q, matrix_t& Z,
//       size_type<matrix_t>& ns, size_type<matrix_t>& nd,
//       FrancisOpts& opts)
//
// What the visible code does:
//  - Selects a deflation window of size
//    jw = min(nw, ihi - ilo, (n - 3) / 3).
//    The window covers rows/columns [kwtop, ihi) with kwtop = ihi - jw.
//  - Copies the window into workspace (Aw, Bw). It reduces the copy to
//    generalized Schur form, using lahqz when jw < opts.nmin and
//    multishift_qz otherwise. The left/right factors accumulate in Qc/Zc.
//  - Tests the trailing 1x1 block (or, for real T, 2x2 block) of the window.
//    The test compares the spike s_spike * Qc(0, :) against a small
//    threshold. The block is then either counted as deflated or moved up
//    with generalized_schur_move.
//  - Writes the window's eigenvalues into alpha/beta[kwtop, ihi).
//  - When something deflated:
//      - restores Hessenberg-triangular form of the undeflated part
//        (rotations + gghrd),
//      - copies the window back into A/B,
//      - applies Qc/Zc to the rest of A and B (columns [ihi, istop_m) and
//        rows [istart_m, kwtop)), and to Q/Z when requested.
//
// Parameters (as used by the visible code):
//  want_s   true: off-window updates span rows/columns [0, n).
//           false: they span only [ilo, ihi).
//  want_q   If true, Q (checked to be n x n) is updated on columns
//           [kwtop, ihi).
//  want_z   If true, Z (checked to be n x n) is updated on columns
//           [kwtop, ihi).
//  ilo, ihi Bounds of the active block [ilo, ihi).
//  nw       Requested deflation-window size.
//  A, B     n x n pencil. Bottom-left regions of BOTH A and B are
//           overwritten as workspace (Qc, Aw, WH, WV in A; Zc, Bw in B).
//           NOTE(review): these workspace entries are not re-zeroed here;
//           confirm callers do not rely on them being zero afterwards.
//  alpha, beta  Entries [kwtop, ihi) are written, by the QZ call and/or the
//           eigenvalue-recomputation loop.
//  Q, Z     Accumulated transformations. Only touched when want_q / want_z.
//  ns       On exit: undeflated eigenvalues in the window, minus infqz (the
//           value returned by lahqz/multishift_qz).
//  nd       On exit: number of deflated eigenvalues (jw - undeflated).
//  opts     opts.nmin selects lahqz vs multishift_qz. It is also forwarded
//           to multishift_qz.
97template <TLAPACK_SMATRIX matrix_t,
98 TLAPACK_SVECTOR alpha_t,
99 TLAPACK_SVECTOR beta_t>
101 bool want_q,
102 bool want_z,
106 matrix_t& A,
107 matrix_t& B,
108 alpha_t& alpha,
109 beta_t& beta,
110 matrix_t& Q,
111 matrix_t& Z,
115{
116 using T = type_t<matrix_t>;
117 using real_t = real_type<T>;
118 using idx_t = size_type<matrix_t>;
120
121 // Constants
122 const real_t one(1);
123 const real_t zero(0);
124 const idx_t n = ncols(A);
125 // Because we will use the lower triangular part of A as workspace,
126 // we have a maximum window size
// NOTE(review): if idx_t is unsigned and n < 3, (n - 3) wraps around.
// For 3 <= n < 6, nw_max == 0 and hence jw == 0 (see the note below).
127 const idx_t nw_max = (n - 3) / 3;
128 const real_t eps = ulp<real_t>();
// Threshold below which a spike entry is treated as negligible.
129 const real_t small_num = safe_min<real_t>() * ((real_t)n / eps);
130 // Size of the deflation window
131 const idx_t jw = min(min(nw, ihi - ilo), nw_max);
132 // First row index in the deflation window
// NOTE(review): if jw == 0, kwtop == ihi. The read A(kwtop, kwtop - 1)
// below then touches row ihi, which is out of range when ihi == n.
// Confirm callers guarantee jw >= 1.
133 const idx_t kwtop = ihi - jw;
134
135 // check arguments
136 tlapack_check(nrows(A) == n);
137 tlapack_check(ncols(B) == n);
138 tlapack_check(nrows(B) == n);
139 if (want_q) {
140 tlapack_check(ncols(Q) == n);
141 tlapack_check(nrows(Q) == n);
142 }
143 if (want_z) {
144 tlapack_check(ncols(Z) == n);
145 tlapack_check(nrows(Z) == n);
146 }
// NOTE(review): only alpha's size is checked. beta is indexed the same
// way (beta[kwtop + i]) but its size is never checked.
147 tlapack_check((idx_t)size(alpha) == n);
148
149 // s is the value just outside the window. It determines the spike
150 // together with the orthogonal schur factors.
151 T s_spike;
152 if (kwtop == ilo)
153 s_spike = zero;
154 else
155 s_spike = A(kwtop, kwtop - 1);
156
157 if (kwtop + 1 == ihi) {
158 // 1x1 deflation window, not much to do
159 alpha[kwtop] = A(kwtop, kwtop);
160 beta[kwtop] = B(kwtop, kwtop);
161 ns = 1;
162 nd = 0;
// Deflate when the subdiagonal entry is negligible relative to
// A(kwtop, kwtop). B does not enter this test.
163 if (abs1(s_spike) <= max(small_num, eps * abs1(A(kwtop, kwtop)))) {
164 ns = 0;
165 nd = 1;
166 if (kwtop > ilo) A(kwtop, kwtop - 1) = zero;
167 }
168 return;
169 // Note: The max() above may not propagate a NaN in A(kwtop, kwtop).
170 }
171
172 // Define workspace matrices
173 // We use the lower triangular parts of A and B as workspace
174 // Aw and WH overlap, but WH is only used after we no longer need
175 // Aw so it is ok.
176 auto Qc = slice(A, range{n - jw, n}, range{0, jw});
177 auto Zc = slice(B, range{n - jw, n}, range{0, jw});
178 auto Aw = slice(A, range{n - jw, n}, range{jw, 2 * jw});
179 auto Bw = slice(B, range{n - jw, n}, range{jw, 2 * jw});
180 auto WH = slice(A, range{n - jw, n}, range{jw, n - jw - 3});
181 auto WV = slice(A, range{jw + 3, n - jw}, range{0, jw});
182
183 // Convert the window to spike-triangular form. i.e. calculate the
184 // Schur form of the deflation window.
185 // If the QZ algorithm fails to converge, it can still be
186 // partially in Schur form. In that case we continue on a smaller
187 // window (note the use of infqz later in the code).
188 auto A_window = slice(A, range{kwtop, ihi}, range{kwtop, ihi});
189 auto B_window = slice(B, range{kwtop, ihi}, range{kwtop, ihi});
190 auto alpha_window = slice(alpha, range{kwtop, ihi});
191 auto beta_window = slice(beta, range{kwtop, ihi});
// Copy the Hessenberg part of the A window and the upper-triangular part
// of the B window into the workspace.
// Qc and Zc start as identity matrices.
194 for (idx_t j = 0; j < jw; ++j)
195 for (idx_t i = 0; i < min(j + 2, jw); ++i)
196 Aw(i, j) = A_window(i, j);
197 for (idx_t j = 0; j < jw; ++j)
198 for (idx_t i = 0; i < min(j + 1, jw); ++i)
199 Bw(i, j) = B_window(i, j);
200 laset(GENERAL, zero, one, Qc);
201 laset(GENERAL, zero, one, Zc);
202 int infqz;
203 if (jw < (idx_t)opts.nmin) {
204 infqz = lahqz(true, true, true, 0, jw, Aw, Bw, alpha_window,
205 beta_window, Qc, Zc);
206 }
207 else {
208 infqz = multishift_qz(true, true, true, 0, jw, Aw, Bw, alpha_window,
209 beta_window, Qc, Zc, opts);
// Clear anything multishift_qz left below the first subdiagonal of Aw.
210 for (idx_t j = 0; j < jw; ++j)
211 for (idx_t i = j + 2; i < jw; ++i)
212 Aw(i, j) = zero;
213 }
214
215 // TODO: use multishift_qz recursively
216 // if (jw < (idx_t)opts.nmin)
217 // infqr = lahqr(true, true, 0, jw, TW, s_window, V);
218 // else {
219 // infqr =
220 // multishift_qr_work(true, true, 0, jw, TW, s_window, V, work,
221 // opts);
222 // for (idx_t j = 0; j < jw; ++j)
223 // for (idx_t i = j + 2; i < jw; ++i)
224 // TW(i, j) = zero;
225 // }
226
227 // Deflation detection loop
228 // one eigenvalue block at a time, we will check if it is deflatable
229 // by checking the bottom spike element. If it is not deflatable,
230 // we move the block up. This moves other blocks down to check.
// Invariant: blocks in [0, ilst) were moved up (undeflatable), blocks in
// [ns, jw) are deflated, and [ilst, ns) is still to be examined.
231 ns = jw;
232 idx_t ilst = infqz;
233 while (ilst < ns) {
234 bool bulge = false;
235 if (is_real<T>)
236 if (ns > 1)
237 if (Aw(ns - 1, ns - 2) != zero) bulge = true;
238
239 if (!bulge) {
240 // 1x1 eigenvalue block
241 real_t foo = abs1(Aw(ns - 1, ns - 1));
242 if (foo == zero) foo = abs1(s_spike);
243 if (abs1(s_spike) * abs1(Qc(0, ns - 1)) <=
244 max(small_num, eps * foo)) {
245 // Eigenvalue is deflatable
246 ns = ns - 1;
247 }
248 else {
249 // Eigenvalue is not deflatable.
250 // Move it up out of the way.
251 idx_t ifst = ns - 1;
252 generalized_schur_move(true, true, Aw, Bw, Qc, Zc, ifst, ilst);
253 ilst = ilst + 1;
254 }
255 // Note: The max() above may not propagate a NaN in Aw(ns-1,ns-1).
256 }
257 else {
258 // 2x2 eigenvalue block
259 real_t foo =
260 abs(Aw(ns - 1, ns - 1)) +
261 sqrt(abs(Aw(ns - 1, ns - 2))) * sqrt(abs(Aw(ns - 2, ns - 1)));
262 if (foo == zero) foo = abs(s_spike);
263 if (max(abs(s_spike * Qc(0, ns - 1)),
264 abs(s_spike * Qc(0, ns - 2))) <=
// NOTE(review): source line 265 (the right-hand side of this comparison
// and the opening brace) is missing from this listing. It is presumably
// `max(small_num, eps * foo)) {` as in the 1x1 case; confirm.
266 // Eigenvalue pair is deflatable
267 ns = ns - 2;
268 }
269 else {
270 // Eigenvalue pair is not deflatable.
271 // Move it up out of the way.
272 idx_t ifst = ns - 2;
273 generalized_schur_move(true, true, Aw, Bw, Qc, Zc, ifst, ilst);
274 ilst = ilst + 2;
275 }
276 }
277 }
278
279 if (ns == 0) s_spike = zero;
280
281 if (ns == jw) {
282 // Aggressive early deflation didn't deflate any eigenvalues
283 // We don't need to apply the update to the rest of the matrix
// A and B are left untouched. alpha/beta for the window were passed to
// the QZ call above (presumably overwritten there).
284 nd = jw - ns;
285 ns = ns - infqz;
286 return;
287 }
288
289 // Recalculate the eigenvalues
// Walk the diagonal of (Aw, Bw). For real T, a nonzero subdiagonal entry
// marks a 2x2 block, whose eigenvalues come from lahqz_eig22.
290 idx_t i = 0;
291 while (i < jw) {
292 idx_t n1 = 1;
293 if (is_real<T>)
294 if (i + 1 < jw)
295 if (Aw(i + 1, i) != zero) n1 = 2;
296
297 if (n1 == 1) {
298 alpha[kwtop + i] = Aw(i, i);
299 beta[kwtop + i] = Bw(i, i);
300 }
301 else {
302 auto A22 = slice(Aw, range(i, i + 2), range(i, i + 2));
303 auto B22 = slice(Bw, range(i, i + 2), range(i, i + 2));
304 lahqz_eig22(A22, B22, alpha[kwtop + i], alpha[kwtop + i + 1],
305 beta[kwtop + i], beta[kwtop + i + 1]);
306 }
307 i = i + n1;
308 }
309
310 // Reduce A back to Hessenberg form (if necessary)
311 if (s_spike != zero) {
312 // Use rotations to remove the spike
// Rotate pairs of Qc columns so that Qc(0, 1:ns) becomes zero. The same
// rotation is applied to the corresponding rows of Aw and Bw, which
// creates subdiagonal fill-in in Bw.
313 for (idx_t i = ns - 1; i > 0; i--) {
314 T t1 = conj(Qc(0, i - 1));
315 T t2 = conj(Qc(0, i));
316 real_t c;
317 T s;
318 rotg(t1, t2, c, s);
319
320 auto q1 = col(Qc, i - 1);
321 auto q2 = col(Qc, i);
322 rot(q1, q2, c, conj(s));
323 Qc(0, i) = (T)0;
324
325 auto a1 = slice(Aw, i - 1, range(0, jw));
326 auto a2 = slice(Aw, i, range(0, jw));
327 rot(a1, a2, c, s);
328
329 auto b1 = slice(Bw, i - 1, range(0, jw));
330 auto b2 = slice(Bw, i, range(0, jw));
331 rot(b1, b2, c, s);
332 }
333 // Remove fill-in from B
// Column rotations zero Bw(i, i - 1). They are mirrored on the columns of
// Aw (rows [0, ns)) and Zc.
334 for (idx_t i = ns - 1; i > 0; i--) {
335 real_t c;
336 T s;
337 rotg(Bw(i, i), Bw(i, i - 1), c, s);
338 s = -s;
339 Bw(i, i - 1) = (T)0.;
340
341 auto b1 = slice(Bw, range(0, i), i - 1);
342 auto b2 = slice(Bw, range(0, i), i);
343 rot(b1, b2, c, conj(s));
344
345 auto a1 = slice(Aw, range(0, ns), i - 1);
346 auto a2 = slice(Aw, range(0, ns), i);
347 rot(a1, a2, c, conj(s));
348
349 auto z1 = col(Zc, i - 1);
350 auto z2 = col(Zc, i);
351 rot(z1, z2, c, conj(s));
352 }
353
354 // Hessenberg-triangular reduction of the undeflated part [0, ns)
355 gghrd(true, true, 0, ns, Aw, Bw, Qc, Zc);
356 }
357
358 // Copy the deflation window back into place
// The new subdiagonal entry left of the window is s_spike * conj(Qc(0, 0)).
359 if (kwtop > 0) A(kwtop, kwtop - 1) = s_spike * conj(Qc(0, 0));
360 for (idx_t j = 0; j < jw; ++j)
361 for (idx_t i = 0; i < min(j + 2, jw); ++i)
362 A(kwtop + i, kwtop + j) = Aw(i, j);
363 for (idx_t j = 0; j < jw; ++j)
364 for (idx_t i = 0; i < min(j + 1, jw); ++i)
365 B(kwtop + i, kwtop + j) = Bw(i, j);
366
367 // Store number of deflated eigenvalues
368 nd = jw - ns;
369 ns = ns - infqz;
370
371 //
372 // Update rest of the matrix using matrix matrix multiplication
373 //
// Rows/columns outside the active block are updated only when want_s.
// The work is done in panels sized to fit the WH/WV workspace.
374 idx_t istart_m, istop_m;
375 if (want_s) {
376 istart_m = 0;
377 istop_m = n;
378 }
379 else {
380 istart_m = ilo;
381 istop_m = ihi;
382 }
383 // Update A
384 if (ihi < istop_m) {
385 idx_t i = ihi;
386 while (i < istop_m) {
387 idx_t iblock = std::min<idx_t>(istop_m - i, ncols(WH));
388 auto A_slice = slice(A, range{kwtop, ihi}, range{i, i + iblock});
389 auto WH_slice =
390 slice(WH, range{0, nrows(A_slice)}, range{0, ncols(A_slice)});
// NOTE(review): source lines 391-392 are missing from this listing.
// They presumably set WH_slice = Qc^H * A_slice (gemm) and copy the
// result back into A_slice (lacpy); confirm against the source.
393 i = i + iblock;
394 }
395 }
396 if (istart_m < kwtop) {
397 idx_t i = istart_m;
398 while (i < kwtop) {
399 idx_t iblock = std::min<idx_t>(kwtop - i, nrows(WV));
400 auto A_slice = slice(A, range{i, i + iblock}, range{kwtop, ihi});
401 auto WV_slice =
402 slice(WV, range{0, nrows(A_slice)}, range{0, ncols(A_slice)});
// NOTE(review): source lines 403-404 are missing from this listing.
// They presumably set WV_slice = A_slice * Zc (gemm) and copy the result
// back into A_slice (lacpy); confirm against the source.
405 i = i + iblock;
406 }
407 }
408 // Update B
409 if (ihi < istop_m) {
410 idx_t i = ihi;
411 while (i < istop_m) {
412 idx_t iblock = std::min<idx_t>(istop_m - i, ncols(WH));
413 auto B_slice = slice(B, range{kwtop, ihi}, range{i, i + iblock});
414 auto WH_slice =
415 slice(WH, range{0, nrows(B_slice)}, range{0, ncols(B_slice)});
// NOTE(review): source lines 416-417 are missing from this listing.
// They presumably set WH_slice = Qc^H * B_slice (gemm) and copy the
// result back into B_slice (lacpy); confirm against the source.
418 i = i + iblock;
419 }
420 }
421 if (istart_m < kwtop) {
422 idx_t i = istart_m;
423 while (i < kwtop) {
424 idx_t iblock = std::min<idx_t>(kwtop - i, nrows(WV));
425 auto B_slice = slice(B, range{i, i + iblock}, range{kwtop, ihi});
426 auto WV_slice =
427 slice(WV, range{0, nrows(B_slice)}, range{0, ncols(B_slice)});
// NOTE(review): source lines 428-429 are missing from this listing.
// They presumably set WV_slice = B_slice * Zc (gemm) and copy the result
// back into B_slice (lacpy); confirm against the source.
430 i = i + iblock;
431 }
432 }
433 // Update Q
434 if (want_q) {
435 idx_t i = 0;
436 while (i < n) {
437 idx_t iblock = std::min<idx_t>(n - i, nrows(WV));
438 auto Q_slice = slice(Q, range{i, i + iblock}, range{kwtop, ihi});
439 auto WV_slice =
440 slice(WV, range{0, nrows(Q_slice)}, range{0, ncols(Q_slice)});
// NOTE(review): source lines 441-442 are missing from this listing.
// They presumably set WV_slice = Q_slice * Qc (gemm) and copy the result
// back into Q_slice (lacpy); confirm against the source.
443 i = i + iblock;
444 }
445 }
446 // Update Z
447 if (want_z) {
448 idx_t i = 0;
449 while (i < n) {
450 idx_t iblock = std::min<idx_t>(n - i, nrows(WV));
451 auto Z_slice = slice(Z, range{i, i + iblock}, range{kwtop, ihi});
452 auto WV_slice =
453 slice(WV, range{0, nrows(Z_slice)}, range{0, ncols(Z_slice)});
// NOTE(review): source lines 454-455 are missing from this listing.
// They presumably set WV_slice = Z_slice * Zc (gemm) and copy the result
// back into Z_slice (lacpy); confirm against the source.
456 i = i + iblock;
457 }
458 }
459} // aggressive_early_deflation_generalized
460
461} // namespace tlapack
462
463#endif // TLAPACK_AED_GENERALIZED_HH
constexpr internal::LowerTriangle LOWER_TRIANGLE
Lower Triangle access.
Definition types.hpp:183
constexpr internal::GeneralAccess GENERAL
General access.
Definition types.hpp:175
constexpr internal::ConjTranspose CONJ_TRANS
conjugate transpose
Definition types.hpp:259
constexpr internal::NoTranspose NO_TRANS
no transpose
Definition types.hpp:255
constexpr T conj(const T &x) noexcept
Extends std::conj() to real datatypes.
Definition utils.hpp:100
constexpr real_type< T > abs1(const T &x)
1-norm absolute value, |Re(x)| + |Im(x)|
Definition utils.hpp:133
#define TLAPACK_SVECTOR
Macro for tlapack::concepts::SliceableVector compatible with C++17.
Definition concepts.hpp:909
#define TLAPACK_SMATRIX
Macro for tlapack::concepts::SliceableMatrix compatible with C++17.
Definition concepts.hpp:899
int multishift_qz(bool want_s, bool want_q, bool want_z, size_type< matrix_t > ilo, size_type< matrix_t > ihi, matrix_t &A, matrix_t &B, alpha_t &alpha, beta_t &beta, matrix_t &Q, matrix_t &Z, FrancisOpts &opts)
multishift_qz computes the eigenvalues of a matrix pair (H,T), where H is an upper Hessenberg matrix and T is upper triangular.
Definition multishift_qz.hpp:64
void laset(uplo_t uplo, const type_t< matrix_t > &alpha, const type_t< matrix_t > &beta, matrix_t &A)
Initializes a matrix to diagonal and off-diagonal values.
Definition laset.hpp:38
int generalized_schur_move(bool want_q, bool want_z, matrix_t &A, matrix_t &B, matrix_t &Q, matrix_t &Z, size_type< matrix_t > &ifst, size_type< matrix_t > &ilst)
generalized_schur_move reorders the generalized Schur factorization of a pencil ( S,...
Definition generalized_schur_move.hpp:53
void lacpy(uplo_t uplo, const matrixA_t &A, matrixB_t &B)
Copies a matrix from A to B.
Definition lacpy.hpp:38
int lahqz(bool want_s, bool want_q, bool want_z, size_type< matrix_t > ilo, size_type< matrix_t > ihi, matrix_t &A, matrix_t &B, alpha_t &alpha, beta_t &beta, matrix_t &Q, matrix_t &Z)
lahqz computes the eigenvalues of a matrix pair (H,T), where H is an upper Hessenberg matrix and T is upper triangular.
Definition lahqz.hpp:60
void rotg(T &a, T &b, T &c, T &s)
Construct plane rotation that eliminates b, such that:
Definition rotg.hpp:39
void rot(vectorX_t &x, vectorY_t &y, const c_type &c, const s_type &s)
Apply plane rotation:
Definition rot.hpp:44
void gemm(Op transA, Op transB, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
General matrix-matrix multiply:
Definition gemm.hpp:61
int gghrd(bool wantq, bool wantz, size_type< A_t > ilo, size_type< A_t > ihi, A_t &A, B_t &B, Q_t &Q, Z_t &Z)
Reduces a pair of real square matrices (A, B) to generalized upper Hessenberg form using unitary transformations.
Definition gghrd.hpp:42
void aggressive_early_deflation_generalized(bool want_s, bool want_q, bool want_z, size_type< matrix_t > ilo, size_type< matrix_t > ihi, size_type< matrix_t > nw, matrix_t &A, matrix_t &B, alpha_t &alpha, beta_t &beta, matrix_t &Q, matrix_t &Z, size_type< matrix_t > &ns, size_type< matrix_t > &nd, FrancisOpts &opts)
aggressive_early_deflation_generalized accepts as input an upper Hessenberg pencil (A,...
Definition aggressive_early_deflation_generalized.hpp:100
#define tlapack_check(cond)
Throw an error if cond is false.
Definition exceptionHandling.hpp:98
void lahqz_eig22(const A_t &A, const B_t &B, complex_type< T > &alpha1, complex_type< T > &alpha2, T &beta1, T &beta2)
Computes the generalized eigenvalues of a 2x2 pencil (A,B) with B upper triangular.
Definition lahqz_eig22.hpp:35
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
Options struct for multishift_qr() and multishift_qz().
Definition FrancisOpts.hpp:23