<T>LAPACK 0.1.2
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
hetf3.hpp
Go to the documentation of this file.
1
5//
6// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
7//
8// This file is part of <T>LAPACK.
9// <T>LAPACK is free software: you can redistribute it and/or modify it under
10// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
11
12#ifndef TLAPACK_HETF3_HH
13#define TLAPACK_HETF3_HH
14
16#include "tlapack/blas/copy.hpp"
17#include "tlapack/blas/gemv.hpp"
20#include "tlapack/blas/swap.hpp"
24
25namespace tlapack {
// Options struct for hetrf_blocked(); hetf3() in this file takes it as well.
// It extends the error-checking options EcOpts.
27struct BlockedLDLOpts : public EcOpts {
 // Builds the options from a set of error-checking options (empty by
 // default), leaving nb and invariant at their default values.
28 constexpr BlockedLDLOpts(const EcOpts& opts = {}) : EcOpts(opts){};
29
 // Block size.
30 size_t nb = 32;
 // Must be Op::Trans or Op::ConjTrans (hetf3 checks this); Op::ConjTrans
 // selects the Hermitian code path in hetf3.
31 Op invariant = Op::Trans;
32};
33
54template <TLAPACK_UPLO uplo_t,
59 matrix_t& A,
60 ipiv_t& ipiv,
61 work_t& work,
62 const BlockedLDLOpts& opts)
63{
64 using T = type_t<matrix_t>;
65 using real_t = real_type<T>;
66 using idx_t = size_type<matrix_t>;
68
69 // Constants
70 const idx_t n = nrows(A);
71 const idx_t nb = opts.nb;
72 const bool hermitian = Op::ConjTrans == opts.invariant;
73 // Initialize ALPHA for use in choosing pivot block size.
74 const real_t alpha = (real_t(1) + sqrt(real_t(17))) / real_t(8);
75
76 // check arguments
77 tlapack_check(uplo == Uplo::Lower || uplo == Uplo::Upper);
78 tlapack_check(nrows(A) == ncols(A));
79 tlapack_check(nrows(A) == size(ipiv));
80 tlapack_check(opts.invariant == Op::Trans ||
81 opts.invariant == Op::ConjTrans);
82
83 // Quick return
84 if (n <= 0) return 0;
85
86 // These are QoL wrappers for passing non-const r-value slice references to
87 // some functions, to avoid temporary variable declaration clutter.
88 // TODO: Add overloaded definitions to the original functions and remove
89 // these workarounds.
90 constexpr auto copy_ = [](auto&& x, auto&& y) {
91 return tlapack::copy(x, y);
92 };
93 constexpr auto gemv_ = [](auto&& trans, auto&& alpha, auto&& A, auto&& x,
94 auto&& beta, auto&& y) {
95 return tlapack::gemv(trans, alpha, A, x, beta, y);
96 };
97 constexpr auto swap_ = [](auto&& x, auto&& y) {
98 return tlapack::swap(x, y);
99 };
100 constexpr auto rscl_ = [](auto&& alpha, auto&& x) {
101 return tlapack::rscl(alpha, x);
102 };
103
// Helper that conjugates every entry of the vector x in place; it stands in
// for LACGV, which is currently unimplemented.
// The loop index takes the type returned by size(x), so the loop condition
// never compares a signed int against an unsigned length, and the local
// length no longer shadows the matrix order n of the enclosing function.
auto conjv = [](auto&& x) {
    const auto len = size(x);
    for (auto i = decltype(len){0}; i < len; ++i)
        x[i] = conj(x[i]);
};
111
112 int info = 0;
113 if (uplo == Uplo::Upper) {
114 // Factorize the trailing nb columns of A using its upper triangle and
115 // working backwards, while computing the matrix W = U12*D, which will
116 // be used to update the upper-left block A11 of A.
117 auto [W, work2] = reshape(work, n, nb);
118 int jn = min((int)nb, (int)n);
119 // s is the source column of the swap P_i
120 // We will use it as the main induction variable and it will decrement
121 // by 1 or 2 depending on the rank of each pivot.
122 int s;
123 for (s = n - 1; s > (int)(n)-jn; --s) {
124 // piv will be the target column of the swap P_i
125 int piv;
126 // j is the trailing column of the pivot
127 int j = s;
128 // jW is the column of W corresponding to the column j of A.
129 int jW = (int)(nb) + j - (int)(n);
130 auto Aj0 = slice(A, range{0, j + 1}, range{0, n});
131 auto Wj0 = slice(W, range{0, j + 1}, range{0, nb});
132 // Copy column j of A to column jW of W and update it.
133 copy_(col(Aj0, j), col(Wj0, jW));
134 if (hermitian) W(j, jW) = real(W(j, jW));
135 if (j + 1 < n) { // update
136 if (hermitian) {
137 copy_(slice(W, j, range{jW + 1, nb}),
138 slice(W, range{j + 1, n}, jW));
139 conjv(slice(W, range{j + 1, n}, jW));
140 gemv_(NO_TRANS, T(-1), cols(Aj0, range{j + 1, n}),
141 slice(W, range{j + 1, n}, jW), T(1), col(Wj0, jW));
142 W(j, jW) = real(W(j, jW));
143 }
144 else {
145 gemv_(NO_TRANS, T(-1), cols(Aj0, range{j + 1, n}),
146 slice(W, j, range{jW + 1, nb}), T(1), col(Wj0, jW));
147 }
148 }
149 // Determine the rank of the pivot and which columns and rows to
150 // swap.
151 auto abs_Ajj = abs1(W(j, jW));
152 // i_colmax is the index of the largest off-diagonal entry of the
153 // updated column j.
154 auto i_colmax = tlapack::iamax(slice(W, range{0, j}, jW));
155 auto colmax = abs1(W(i_colmax, jW));
156 if (max(colmax, abs_Ajj) == 0) {
157 piv = j;
158 info = (info == 0) ? j + 1 : info;
159 if (hermitian) A(j, j) = real(A(j, j));
160 }
161 else {
162 if (abs_Ajj >= alpha * colmax) {
163 piv = j;
164 }
165 else {
166 // Copy and update column i_colmax into column jW-1 of W.
167 copy_(slice(A, range{0, i_colmax + 1}, i_colmax),
168 slice(W, range{0, i_colmax + 1}, jW - 1));
169 copy_(slice(A, i_colmax, range{i_colmax + 1, j + 1}),
170 slice(W, range{i_colmax + 1, j + 1}, jW - 1));
171 if (hermitian) {
172 W(i_colmax, jW - 1) = real(W(i_colmax, jW - 1));
173 conjv(slice(W, range{i_colmax + 1, j + 1}, jW - 1));
174 }
175 if (j + 1 < n) {
176 if (hermitian) {
177 copy_(slice(W, i_colmax, range{jW + 1, nb}),
178 slice(W, range{j + 1, n}, jW - 1));
179 conjv(slice(W, range{j + 1, n}, jW - 1));
180 gemv_(NO_TRANS, T(-1), cols(Aj0, range{j + 1, n}),
181 slice(W, range{j + 1, n}, jW - 1), T(1),
182 col(Wj0, jW - 1));
183 W(i_colmax, jW - 1) = real(W(i_colmax, jW - 1));
184 }
185 else {
186 gemv_(NO_TRANS, T(-1), cols(Aj0, range{j + 1, n}),
187 slice(W, i_colmax, range{jW + 1, nb}), T(1),
188 col(Wj0, jW - 1));
189 }
190 }
191 // i_rowmax is the index of the largest off-diagonal entry
192 // of the updated row i_colmax.
193 auto i_rowmax = i_colmax + 1 +
194 tlapack::iamax(slice(
195 W, range{i_colmax + 1, j + 1}, jW - 1));
196 auto rowmax = abs1(W(i_rowmax, jW - 1));
197 if (i_colmax > 0) {
199 slice(W, range{0, i_colmax}, jW - 1));
200 rowmax = max(rowmax, abs1(W(i_rowmax, jW - 1)));
201 }
202 if (abs_Ajj >= alpha * colmax * colmax / rowmax) {
203 piv = j;
204 }
205 else if (abs1(W(i_colmax, jW - 1)) >= alpha * rowmax) {
206 // We will use updated column i_colmax as a rank 1
207 // pivot. Copy it over column jW of W.
208 piv = i_colmax;
209 copy_(col(Wj0, jW - 1), col(Wj0, jW));
210 }
211 else {
212 // We will use updated column i_colmax as a rank 2
213 // pivot. Decrement s to index the new leading column of
214 // the pivot.
215 piv = i_colmax;
216 --s;
217 }
218 }
219 if (piv != s) {
220 // Swap rows and columns s and piv.
221 // Their symmetric storage intersects so care must be taken
222 // not to overwrite elements out of order. First copy
223 // non-updated column s of A to column piv of A.
224 A(piv, piv) = hermitian ? real(A(s, s)) : A(s, s);
225 copy_(slice(A, range{piv + 1, s}, s),
226 slice(A, piv, range{piv + 1, s}));
227 if (hermitian) conjv(slice(A, piv, range{piv + 1, s}));
228 if (piv > 0)
229 copy_(slice(A, range{0, piv}, s),
230 slice(A, range{0, piv}, piv));
231 // Swap the non-updated rows,
232 // except the block diagonal which will be later
233 // overwritten.
234 if (j + 1 < n)
235 swap_(slice(A, piv, range{j + 1, n}),
236 slice(A, s, range{j + 1, n}));
237 // Swap the updated rows in W.
238 int sW = (int)(nb) + s - (int)(n);
239 swap_(slice(W, piv, range{sW, nb}),
240 slice(W, s, range{sW, nb}));
241 }
242 if (j == s) {
243 // Rank 1 pivot: column jW of W
244 // now holds the factor U_jD_j.
245 // copy the column from W to the column j of A,
246 // and rescale the diagonal by D_j.
247 copy_(col(Wj0, jW), col(Aj0, j));
248 if (j > 0)
249 rscl_(hermitian ? real(A(j, j)) : A(j, j),
250 slice(A, range{0, j}, j));
251 }
252 else {
253 // Rank 2 pivot: columns jW-1:jW of W
254 // now hold the factor W_j = U_jD_j,
255 // where D_j is a symmetric/Hermitian 2-by-2 block.
256 // Write W_jD_j^{-1} onto the columns j-1:j of A,
257 // and copy the upper triangle of D_j onto the diagonal of
258 // A. We use an optimized 2-by-2 inversion algorithm that
259 // minimizes the number of operations by pulling out factors
260 // to set the antidiagonals to -1.
261 T D21 = W(j - 1, jW);
262 T D11 = W(j, jW) / (hermitian ? conj(D21) : D21);
263 T D22 = W(j - 1, jW - 1) / D21;
264 if (hermitian) {
265 real_t d = real_t(1) / (real(D11 * D22) - real_t(1));
266 D21 = d / D21;
267 }
268 else {
269 T d = T(1) / (D11 * D22 - T(1));
270 D21 = d / D21;
271 }
272 for (int k = 0; k < s; ++k) {
273 A(k, j - 1) = D21 * (D11 * W(k, jW - 1) - W(k, jW));
274 A(k, j) = (hermitian ? conj(D21) : D21) *
275 (D22 * W(k, jW) - W(k, jW - 1));
276 }
277 A(j, j) = W(j, jW);
278 A(j - 1, j) = W(j - 1, jW);
279 A(j - 1, j - 1) = W(j - 1, jW - 1);
280 }
281 }
282 // Update ipiv record
283 if (j == s) {
284 // rank 1 pivot swapping j with piv
285 ipiv[j] = piv;
286 }
287 else {
288 // Rank 2 pivot swapping s with piv and fixing j.
289 // Offset by -1 to avoid overlapping piv == 0 == -0.
290 ipiv[j] = ipiv[s] = (-piv) - 1;
291 }
292 }
293 if (s >= (int)n - jn) {
294 // The last column to be pivoted was not the last column of the
295 // block. Indicate its index in the leading block position of ipiv.
296 ipiv[n - jn] = n + s;
297 }
298 if (s >= 0) {
299 // Update A11 with the Schur complement of D:
300 // $A11 = A11 - A12 D^{-1} A12^{op}$
301 // $= A11 - W U12^{op}$.
302 auto sW = (int)(nb) + s - (int)(n);
303 const auto& U12 = slice(A, range{0, s + 1}, range{s + 1, n});
304 const auto& W12 = slice(W, range{0, s + 1}, range{sW + 1, nb});
305 auto A11 = slice(A, range{0, s + 1}, range{0, s + 1});
306 if (hermitian)
308 real_t(1), A11);
309 else
311 real_t(1), A11);
312 }
313 // Put U12 in standard form by partially undoing the swaps done to its
314 // rows, in the trailing columns of A, by looping back through them in
315 // reverse order.
316 for (int j = s + 1; j < n; ++j) {
317 int s = j;
318 int piv = ipiv[j];
319 if (piv < 0) {
320 piv = -piv - 1;
321 ++j;
322 }
323 // Swap the trailing columns of rows s and piv of A,
324 // excluding the block-diagonal pivot block.
325 if ((piv != s) & (j + 1 < n)) {
326 swap_(slice(A, piv, range{j + 1, n}),
327 slice(A, s, range{j + 1, n}));
328 }
329 }
330 }
331 else {
332 // Factorize the leading nb columns of A using its lower triangle and
333 // working forwards, while computing the matrix W = L21*D, which will be
334 // used to update the lower-right block A22 of A. We proceed in exactly
335 // the same way as the Upper case, if A is considered as reflected
336 // across both diagonals. There is no jW index because columns of A and
337 // W now count from the same base.
338 auto [W, work2] = reshape(work, n, nb);
339 int jn = min(nb, n);
340 int s;
341 for (s = 0; s < jn - 1; ++s) {
342 int piv;
343 int j = s;
344 auto Aj0 = slice(A, range{j, n}, range{0, n});
345 auto Wj0 = slice(W, range{j, n}, range{0, nb});
346 copy_(col(Aj0, j), col(Wj0, j));
347 if (hermitian) W(j, j) = real(W(j, j));
348 if (j > 0) {
349 if (hermitian) {
350 copy_(slice(W, j, range{0, j}), slice(W, range{0, j}, j));
351 conjv(slice(W, range{0, j}, j));
352 gemv_(NO_TRANS, T(-1), cols(Aj0, range{0, j}),
353 slice(W, range{0, j}, j), T(1), col(Wj0, j));
354 W(j, j) = real(W(j, j));
355 }
356 else {
357 gemv_(NO_TRANS, T(-1), cols(Aj0, range{0, j}),
358 slice(W, j, range{0, j}), T(1), col(Wj0, j));
359 }
360 }
361 auto abs_Ajj = abs1(W(j, j));
362 auto i_colmax =
363 j + 1 + tlapack::iamax(slice(W, range{j + 1, n}, j));
364 auto colmax = abs1(W(i_colmax, j));
365 if (max(colmax, abs_Ajj) == 0) {
366 piv = j;
367 info = (info == 0) ? j + 1 : info;
368 A(j, j) = real(A(j, j));
369 }
370 else {
371 if (abs_Ajj >= alpha * colmax) {
372 piv = j;
373 }
374 else {
375 copy_(slice(A, i_colmax, range{j, i_colmax}),
376 slice(W, range{j, i_colmax}, j + 1));
377 copy_(slice(A, range{i_colmax, n}, i_colmax),
378 slice(W, range{i_colmax, n}, j + 1));
379 if (hermitian) {
380 W(i_colmax, j + 1) = real(W(i_colmax, j + 1));
381 conjv(slice(W, range{j, i_colmax}, j + 1));
382 }
383 if (j > 0) {
384 if (hermitian) {
385 copy_(slice(W, i_colmax, range{0, j}),
386 slice(W, range{0, j}, j + 1));
387 conjv(slice(W, range{0, j}, j + 1));
388 gemv_(NO_TRANS, T(-1), cols(Aj0, range{0, j}),
389 slice(W, range{0, j}, j + 1), T(1),
390 col(Wj0, j + 1));
391 W(i_colmax, j + 1) = real(W(i_colmax, j + 1));
392 }
393 else {
394 gemv_(NO_TRANS, T(-1), cols(Aj0, range{0, j}),
395 slice(W, i_colmax, range{0, j}), T(1),
396 col(Wj0, j + 1));
397 }
398 }
399 auto i_rowmax =
400 j + tlapack::iamax(slice(W, range{j, i_colmax}, j + 1));
401 auto rowmax = abs1(W(i_rowmax, j + 1));
402 if (i_colmax + 1 < n) {
403 i_rowmax = i_colmax + 1 +
405 slice(W, range{i_colmax + 1, n}, j + 1));
406 rowmax = max(rowmax, abs1(W(i_rowmax, j + 1)));
407 }
408 if (abs_Ajj >= alpha * colmax * colmax / rowmax) {
409 piv = j;
410 }
411 else if (abs1(W(i_colmax, j + 1)) >= alpha * rowmax) {
412 piv = i_colmax;
413 copy_(col(Wj0, j + 1), col(Wj0, j));
414 }
415 else {
416 ++s;
417 piv = i_colmax;
418 }
419 }
420 if (piv != s) {
421 A(piv, piv) = A(s, s);
422 copy_(slice(A, range{s + 1, piv}, s),
423 slice(A, piv, range{s + 1, piv}));
424 if (hermitian) conjv(slice(A, piv, range{s + 1, piv}));
425 if (piv + 1 < n)
426 copy_(slice(A, range{piv + 1, n}, s),
427 slice(A, range{piv + 1, n}, piv));
428 if (j > 0)
429 swap_(slice(A, piv, range{0, j}),
430 slice(A, s, range{0, j}));
431 swap_(slice(W, piv, range{0, s + 1}),
432 slice(W, s, range{0, s + 1}));
433 }
434 if (j == s) {
435 copy_(col(Wj0, j), col(Aj0, j));
436 if (j + 1 < n)
437 rscl_(hermitian ? real(W(j, j)) : W(j, j),
438 slice(A, range{j + 1, n}, j));
439 }
440 else {
441 T D21 = W(j + 1, j);
442 T D22 = W(j, j) / (hermitian ? conj(D21) : D21);
443 T D11 = W(j + 1, j + 1) / D21;
444 if (hermitian) {
445 real_t d = real_t(1) / (real(D11 * D22) - real_t(1));
446 D21 = d / D21;
447 }
448 else {
449 T d = T(1) / (D11 * D22 - T(1));
450 D21 = d / D21;
451 }
452 for (int k = j + 2; k < n; ++k) {
453 A(k, j) = (hermitian ? conj(D21) : D21) *
454 (D11 * W(k, j) - W(k, j + 1));
455 A(k, j + 1) = D21 * (D22 * W(k, j + 1) - W(k, j));
456 }
457 A(j, j) = W(j, j);
458 A(j + 1, j) = W(j + 1, j);
459 A(j + 1, j + 1) = W(j + 1, j + 1);
460 }
461 }
462 if (j == s) {
463 ipiv[j] = piv;
464 }
465 else {
466 ipiv[j] = ipiv[j + 1] = (-piv) - 1;
467 }
468 }
469 if (s < jn) {
470 ipiv[jn - 1] = n + s;
471 }
472 if (s < n) {
473 const auto& L21 = slice(A, range{s, n}, range{0, s});
474 const auto& W21 = slice(W, range{s, n}, range{0, s});
475 auto A22 = slice(A, range{s, n}, range{s, n});
476 if (hermitian)
478 real_t(1), A22);
479 else
481 real_t(1), A22);
482 }
483 for (int j = s - 1; j > 0; --j) {
484 int s = j;
485 int piv = ipiv[j];
486 if (piv < 0) {
487 piv = (-piv) - 1;
488 --j;
489 }
490 if ((piv != s) && (j > 0)) {
491 swap_(slice(A, piv, range{0, j}), slice(A, s, range{0, j}));
492 }
493 }
494 }
495 return info;
496}
497
498} // namespace tlapack
499
500#endif // TLAPACK_HETF3_HH
Op
Definition types.hpp:227
constexpr internal::NoTranspose NO_TRANS
no transpose
Definition types.hpp:260
constexpr real_type< T > real(const T &x) noexcept
Extends std::real() to real datatypes.
Definition utils.hpp:71
constexpr T conj(const T &x) noexcept
Extends std::conj() to real datatypes.
Definition utils.hpp:100
constexpr real_type< T > abs1(const T &x)
1-norm absolute value, |Re(x)| + |Im(x)|
Definition utils.hpp:133
#define TLAPACK_UPLO
Macro for tlapack::concepts::Uplo compatible with C++17.
Definition concepts.hpp:942
#define TLAPACK_SMATRIX
Macro for tlapack::concepts::SliceableMatrix compatible with C++17.
Definition concepts.hpp:899
#define TLAPACK_WORKSPACE
Macro for tlapack::concepts::Workspace compatible with C++17.
Definition concepts.hpp:912
#define TLAPACK_VECTOR
Macro for tlapack::concepts::Vector compatible with C++17.
Definition concepts.hpp:906
void rscl(const alpha_t &alpha, vector_t &x)
Scale vector by the reciprocal of a constant, $x := x / \alpha$.
Definition rscl.hpp:22
void copy(const vectorX_t &x, vectorY_t &y)
Copy vector, $y := x$.
Definition copy.hpp:31
void swap(vectorX_t &x, vectorY_t &y)
Swap vectors, $x \leftrightarrow y$.
Definition swap.hpp:31
size_type< vector_t > iamax(const vector_t &x, const IamaxOpts< abs_f > &opts)
Return $\arg\max_{i} |x_i|$, the index of the entry of x with the largest absolute value.
Definition iamax.hpp:234
void gemv(Op trans, const alpha_t &alpha, const matrixA_t &A, const vectorX_t &x, const beta_t &beta, vectorY_t &y)
General matrix-vector multiply: $y := \alpha\,\mathrm{op}(A)\,x + \beta y$.
Definition gemv.hpp:57
void syr2k(Uplo uplo, Op trans, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
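Symmetric rank-2k update: $C := \alpha A B^T + \alpha B A^T + \beta C$ or $C := \alpha A^T B + \alpha B^T A + \beta C$, depending on trans.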
Definition syr2k.hpp:64
void her2k(Uplo uplo, Op trans, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
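Hermitian rank-2k update: $C := \alpha A B^H + \bar{\alpha} B A^H + \beta C$ or $C := \alpha A^H B + \bar{\alpha} B^H A + \beta C$, depending on trans.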
Definition her2k.hpp:72
int hetf3(uplo_t uplo, matrix_t &A, ipiv_t &ipiv, work_t &work, const BlockedLDLOpts &opts)
Computes the partial factorization of a symmetric or Hermitian matrix A using the Bunch-Kaufman diago...
Definition hetf3.hpp:58
#define tlapack_check(cond)
Throw an error if cond is false.
Definition exceptionHandling.hpp:98
Computes the Bunch-Kaufman factorization of a symmetric or Hermitian matrix A using a blocked algorit...
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
Options struct for hetrf_blocked()
Definition hetf3.hpp:27
size_t nb
Block size.
Definition hetf3.hpp:30
Options for error checking.
Definition exceptionHandling.hpp:76