<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
gghd3.hpp
Go to the documentation of this file.
1
5//
6// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
7//
8// This file is part of <T>LAPACK.
9// <T>LAPACK is free software: you can redistribute it and/or modify it under
10// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
11
12#ifndef TLAPACK_GGHD3_HH
13#define TLAPACK_GGHD3_HH
14
16#include "tlapack/blas/rot.hpp"
17#include "tlapack/blas/rotg.hpp"
20
21namespace tlapack {
22
// Options struct for gghd3.
26struct Gghd3Opts {
// Block size. gghd3 reads it as `nb`: it caps the number of columns reduced
// per panel and sizes the rotation workspaces (nb columns) and the
// 2*nb-wide scratch matrices. Defaults to 32.
27 size_t nb = 32;
28};
29
// gghd3: blocked reduction of the square matrix pair (A, B) using plane
// rotations. Within rows/columns ilo..ihi-1, each column of A has its entries
// below the first subdiagonal eliminated by rotations acting on adjacent
// rows; the same left rotations are pushed through B by hessenberg_rq, which
// yields right rotations that are then applied to B and A. Rotations are
// accumulated into small dense factors (Qt) and applied to the remaining
// parts of A and B and to Q and Z with gemm.
//
// NOTE: this listing omits some original source lines (49-51 and 54-55). Per
// the signature shown in the symbol index of this page, `ilo` and `ihi` are
// `size_type<A_t>` parameters sitting between `wantz` and `A`; the omitted
// template lines presumably declare B_t, Q_t and Z_t.
//
// Parameters (as used by the visible body):
//   wantq, wantz  NOTE(review): never read in the visible body; Q and Z are
//                 always size-checked and always updated. Confirm intended.
//   ilo, ihi      active range; must satisfy ilo < n and ilo < ihi <= n.
//   A, B, Q, Z    must all be n-by-n with n = ncols(A); modified in place.
//   opts          opts.nb is the panel width (block size).
// Returns 0 on every visible return path; argument violations are reported
// through tlapack_check.
48template <TLAPACK_SMATRIX A_t,
52int gghd3(bool wantq,
53 bool wantz,
56 A_t& A,
57 B_t& B,
58 Q_t& Q,
59 Z_t& Z,
60 const Gghd3Opts& opts = {})
61{
62 using idx_t = size_type<A_t>;
63 using range = pair<idx_t, idx_t>;
64 using T = type_t<A_t>;
65 using real_t = real_type<T>;
66 using r_matrix_t = real_type<A_t>;
67
// Factories used below to wrap std::vector storage as matrices of the
// scalar type (new_matrix) and of the real type (new_real_matrix).
68 Create<A_t> new_matrix;
69 Create<r_matrix_t> new_real_matrix;
70
71 // constants
72 const idx_t n = ncols(A);
73 const idx_t nb = opts.nb;
// NOTE(review): nh is computed before ilo/ihi are validated below; if idx_t
// is unsigned (presumably, given the `(idx_t)-1` loop sentinel further
// down) an invalid pair wraps here, but tlapack_check is expected to reject
// it before nh is used. Confirm.
74 const idx_t nh = ihi - ilo - 1;
75
76 // check arguments
77 tlapack_check(ilo >= 0 && ilo < n);
78 tlapack_check(ihi > ilo && ihi <= n);
79 tlapack_check(n == nrows(A));
80 tlapack_check(n == ncols(B));
81 tlapack_check(n == nrows(B));
82 tlapack_check(n == ncols(Q));
83 tlapack_check(n == nrows(Q));
84 tlapack_check(n == ncols(Z));
85 tlapack_check(n == nrows(Z));
86
87 // Zero out lower triangle of B
// Covers the strictly lower triangle of all n columns, not only ilo..ihi-1,
// and runs before the quick return, so B is modified even when nh <= 1.
88 for (idx_t j = 0; j < n; ++j)
89 for (idx_t i = j + 1; i < n; ++i)
90 B(i, j) = (T)0;
91
92 // quick return
93 if (nh <= 1) return 0;
94
95 // Locally allocate workspace for now
// Cl/Sl hold the (c, s) pairs of the left rotations generated by rotg;
// Cr/Sr hold the right rotations filled in by hessenberg_rq. Each has
// nh-1 rows (one per rotation in a column sweep) and nb columns (one per
// column of the current panel).
96 std::vector<real_t> Cl_;
97 auto Cl = new_real_matrix(Cl_, nh - 1, nb);
98 std::vector<T> Sl_;
99 auto Sl = new_matrix(Sl_, nh - 1, nb);
100 std::vector<real_t> Cr_;
101 auto Cr = new_real_matrix(Cr_, nh - 1, nb);
102 std::vector<T> Sr_;
103 auto Sr = new_matrix(Sr_, nh - 1, nb);
104
// Qt receives the accumulated rotations of one block (at most 2*nb wide).
105 std::vector<T> Qt_;
106 auto Qt = new_matrix(Qt_, 2 * nb, 2 * nb);
// NOTE(review): C (2*nb-by-n) and D (n-by-2*nb) are both created over the
// same vector C_, so they presumably alias one scratch buffer. Every use
// below copies the gemm result out with lacpy before the next gemm, which
// is consistent with that. Confirm against Create's semantics.
107 std::vector<T> C_;
108 auto C = new_matrix(C_, 2 * nb, n);
109 auto D = new_matrix(C_, n, 2 * nb);
110
// Main loop over panels of up to nb columns, starting at column ilo.
111 for (idx_t j = ilo; j + 2 < ihi; j = j + nb) {
112 // Number of columns to be reduced
113 idx_t nnb = std::min<idx_t>(nb, ihi - 2 - j);
114 // Number of 2*nnb x 2*nnb orthogonal factors
115 idx_t n2nb = (ihi - j - 2) / nnb - 1;
116 // Size of the last orthogonal factor
117 idx_t nblst = ihi - j - 1 - n2nb * nnb;
118
119 //
120 // Reduce panel j:j+nb
121 //
122 for (idx_t jb = 0; jb < nnb; ++jb) {
123 // Update jb-th column of the block
// Replays, on column j+jb of A, the left rotations stored for the
// earlier panel columns jbb < jb, in the same bottom-to-top order in
// which they were generated.
124 for (idx_t jbb = 0; jbb < jb; ++jbb) {
125 for (idx_t i = ihi - 1; i > j + jbb + 1; --i) {
126 real_t c = Cl(i - ilo - 2, jbb);
127 T s = Sl(i - ilo - 2, jbb);
// Rotation on rows (i-1, i): [c s; -conj(s) c].
128 T temp = c * A(i - 1, j + jb) + s * A(i, j + jb);
129 A(i, j + jb) =
130 -conj(s) * A(i - 1, j + jb) + c * A(i, j + jb);
131 A(i - 1, j + jb) = temp;
132 }
133 }
134 // Reduce column in A
// Bottom-up sweep: each rotg call combines A(i-1, j+jb) and A(i, j+jb),
// stores the rotation in (Cl, Sl), and the eliminated entry is then set
// to exactly zero.
135 for (idx_t i = ihi - 1; i > j + jb + 1; --i) {
136 rotg(A(i - 1, j + jb), A(i, j + jb), Cl(i - ilo - 2, jb),
137 Sl(i - ilo - 2, jb));
138 A(i, j + jb) = (T)0;
139 }
140
141 // Apply rotations to B and remove fill-in
// clv/slv are the left rotations just generated for this column; crv/srv
// are the matching slots for the right rotations that hessenberg_rq
// produces (presumably restoring the triangular form of B2 - see the
// hessenberg_rq entry in the symbol index; confirm).
142 auto B2 = slice(B, range(j + jb + 1, ihi), range(j + jb + 1, ihi));
143 auto clv = slice(Cl, range(j - ilo + jb, ihi - ilo - 2), jb);
144 auto slv = slice(Sl, range(j - ilo + jb, ihi - ilo - 2), jb);
145 auto crv = slice(Cr, range(j - ilo + jb, ihi - ilo - 2), jb);
146 auto srv = slice(Sr, range(j - ilo + jb, ihi - ilo - 2), jb);
147 hessenberg_rq(B2, clv, slv, crv, srv);
// Rows j..j+jb of B (those above B2 inside the panel) receive the same
// right rotations.
148 auto B3 = slice(B, range(j, j + jb + 1), range(j + jb + 1, ihi));
149 rot_sequence(RIGHT_SIDE, FORWARD, crv, srv, B3);
150 // Apply rotations to A
// Only rows j..ihi-1 are updated here; rows above j are handled later
// through the accumulated right factors.
151 auto A2 = slice(A, range(j, ihi), range(j + jb + 1, ihi));
152 rot_sequence(RIGHT_SIDE, FORWARD, crv, srv, A2);
153 }
154
155 //
156 // Accumulate the left rotations into unitary matrices and use those to
157 // apply the rotations efficiently.
158 //
159 {
160 //
161 // Last block is treated separately
162 //
// Qt2 (nblst-by-nblst) is initialised with off-diagonal 0 and diagonal 1
// and then has the stored left rotations of the last block applied to
// its columns.
163 auto Qt2 = slice(Qt, range(0, nblst), range(0, nblst));
164 laset(GENERAL, (T)0, (T)1, Qt2);
165
166 for (idx_t jb = 0; jb < nnb; ++jb) {
167 for (idx_t i = nblst - 1; i > jb; --i) {
168 auto q1 = slice(Qt2, range(i - 1 - jb, nblst), i - 1);
169 auto q2 = slice(Qt2, range(i - 1 - jb, nblst), i);
170 rot(q1, q2, Cl(j - ilo + nnb * n2nb + i - 1, jb),
171 conj(Sl(j - ilo + nnb * n2nb + i - 1, jb)));
172 }
173 }
174
// A(last nblst rows, columns right of the panel) := Qt2^H * A(...),
// computed into scratch C2 and copied back.
175 auto A2 = slice(A, range(ihi - nblst, ihi), range(j + nnb, n));
176 auto C2 = slice(C, range(0, nblst), range(j + nnb, n));
177 gemm(CONJ_TRANS, NO_TRANS, (T)1, Qt2, A2, C2);
178 lacpy(GENERAL, C2, A2);
179
// For B only the columns at and beyond ihi still need the left factor;
// the block inside ilo..ihi-1 was already updated by hessenberg_rq.
180 if (ihi < n) {
181 auto B2 = slice(B, range(ihi - nblst, ihi), range(ihi, n));
182 auto C3 = slice(C, range(0, nblst), range(ihi, n));
183 gemm(CONJ_TRANS, NO_TRANS, (T)1, Qt2, B2, C3);
184 lacpy(GENERAL, C3, B2);
185 }
186
// Q(:, last nblst columns) := Q(...) * Qt2, executed regardless of wantq.
187 auto Q2 = cols(Q, range(ihi - nblst, ihi));
188 auto D2 = cols(D, range(0, nblst));
189 gemm(NO_TRANS, NO_TRANS, (T)1, Q2, Qt2, D2);
190 lacpy(GENERAL, D2, Q2);
191 }
// Remaining 2*nnb-by-2*nnb left factors, processed from the bottom block
// upwards. The loop counts ib down to 0 and stops when the decrement
// yields (idx_t)-1; when n2nb == 0 the initial value is already that
// sentinel and the loop body does not run.
192 for (idx_t ib = n2nb - 1; ib != (idx_t)-1; ib--) {
193 auto Qt2 = slice(Qt, range(0, 2 * nnb), range(0, 2 * nnb));
194 laset(GENERAL, (T)0, (T)1, Qt2);
195 for (idx_t jb = 0; jb < nnb; ++jb) {
196 for (idx_t i = nnb + jb; i > jb; --i) {
197 auto q1 =
198 slice(Qt2, range(i - 1 - jb, nnb + jb + 1), i - 1);
199 auto q2 = slice(Qt2, range(i - 1 - jb, nnb + jb + 1), i);
200 rot(q1, q2, Cl(j - ilo + ib * nnb + i - 1, jb),
201 conj(Sl(j - ilo + ib * nnb + i - 1, jb)));
202 }
203 }
204
// Rows j+1+nnb*ib .. j+1+nnb*ib+2*nnb-1 of A, columns right of the panel.
205 auto A2 =
206 slice(A, range(j + 1 + nnb * ib, j + 1 + nnb * ib + 2 * nnb),
207 range(j + nnb, n));
208 auto C2 = slice(C, range(0, 2 * nnb), range(j + nnb, n));
209 gemm(CONJ_TRANS, NO_TRANS, (T)1, Qt2, A2, C2);
210 lacpy(GENERAL, C2, A2);
211
212 if (ihi < n) {
213 auto B2 = slice(
214 B, range(j + 1 + nnb * ib, j + 1 + nnb * ib + 2 * nnb),
215 range(ihi, n));
216 auto C3 = slice(C, range(0, 2 * nnb), range(ihi, n));
217 gemm(CONJ_TRANS, NO_TRANS, (T)1, Qt2, B2, C3);
218 lacpy(GENERAL, C3, B2);
219 }
220
221 auto Q2 =
222 cols(Q, range(j + 1 + nnb * ib, j + 1 + nnb * ib + 2 * nnb));
223 auto D2 = cols(D, range(0, 2 * nnb));
224 gemm(NO_TRANS, NO_TRANS, (T)1, Q2, Qt2, D2);
225 lacpy(GENERAL, D2, Q2);
226 }
227
228 //
229 // Accumulate the right rotations into unitary matrices and use those to
230 // apply the rotations efficiently.
231 //
// Same structure as the left-rotation pass above, but reading (Cr, Sr)
// and multiplying from the right (no conjugate transpose).
232 {
233 // Last block is treated separately
234 auto Qt2 = slice(Qt, range(0, nblst), range(0, nblst));
235 laset(GENERAL, (T)0, (T)1, Qt2);
236
237 for (idx_t jb = 0; jb < nnb; ++jb) {
238 for (idx_t i = nblst - 1; i > jb; --i) {
239 auto q1 = slice(Qt2, range(i - 1 - jb, nblst), i - 1);
240 auto q2 = slice(Qt2, range(i - 1 - jb, nblst), i);
241 rot(q1, q2, Cr(j - ilo + nnb * n2nb + i - 1, jb),
242 conj(Sr(j - ilo + nnb * n2nb + i - 1, jb)));
243 }
244 }
245
// Rows 0..j-1 of A and B were not touched by rot_sequence during the
// panel reduction, so they get the accumulated right factor here. The
// guard skips the update when that row range is empty.
246 if (j > 0) {
247 auto A2 = slice(A, range(0, j), range(ihi - nblst, ihi));
248 auto D2 = slice(D, range(0, j), range(0, nblst));
249 gemm(NO_TRANS, NO_TRANS, (T)1, A2, Qt2, D2);
250 lacpy(GENERAL, D2, A2);
251
252 auto B2 = slice(B, range(0, j), range(ihi - nblst, ihi));
253 gemm(NO_TRANS, NO_TRANS, (T)1, B2, Qt2, D2);
254 lacpy(GENERAL, D2, B2);
255 }
256
// Z(:, last nblst columns) := Z(...) * Qt2, executed regardless of wantz.
257 auto Z2 = cols(Z, range(ihi - nblst, ihi));
258 auto D2 = cols(D, range(0, nblst));
259 gemm(NO_TRANS, NO_TRANS, (T)1, Z2, Qt2, D2);
260 lacpy(GENERAL, D2, Z2);
261 }
// Remaining 2*nnb-wide right factors, bottom block first; same
// count-down-to-sentinel loop as for the left factors.
262 for (idx_t ib = n2nb - 1; ib != (idx_t)-1; ib--) {
263 auto Qt2 = slice(Qt, range(0, 2 * nnb), range(0, 2 * nnb));
264 laset(GENERAL, (T)0, (T)1, Qt2);
265 for (idx_t jb = 0; jb < nnb; ++jb) {
266 for (idx_t i = nnb + jb; i > jb; --i) {
267 auto q1 =
268 slice(Qt2, range(i - 1 - jb, nnb + jb + 1), i - 1);
269 auto q2 = slice(Qt2, range(i - 1 - jb, nnb + jb + 1), i);
270 rot(q1, q2, Cr(j - ilo + ib * nnb + i - 1, jb),
271 conj(Sr(j - ilo + ib * nnb + i - 1, jb)));
272 }
273 }
274
275 if (j > 0) {
276 auto A2 =
277 slice(A, range(0, j),
278 range(j + 1 + nnb * ib, j + 1 + nnb * ib + 2 * nnb));
279 auto D2 = slice(D, range(0, j), range(0, 2 * nnb));
280 gemm(NO_TRANS, NO_TRANS, (T)1, A2, Qt2, D2);
281 lacpy(GENERAL, D2, A2);
282
283 auto B2 =
284 slice(B, range(0, j),
285 range(j + 1 + nnb * ib, j + 1 + nnb * ib + 2 * nnb));
286 gemm(NO_TRANS, NO_TRANS, (T)1, B2, Qt2, D2);
287 lacpy(GENERAL, D2, B2);
288 }
289
290 auto Z2 =
291 cols(Z, range(j + 1 + nnb * ib, j + 1 + nnb * ib + 2 * nnb));
292 auto D2 = cols(D, range(0, 2 * nnb));
293 gemm(NO_TRANS, NO_TRANS, (T)1, Z2, Qt2, D2);
294 lacpy(GENERAL, D2, Z2);
295 }
296 }
297
298 return 0;
299}
300
301} // namespace tlapack
302
303#endif // TLAPACK_GGHD3_HH
constexpr T conj(const T &x) noexcept
Extends std::conj() to real datatypes.
Definition utils.hpp:100
#define TLAPACK_SMATRIX
Macro for tlapack::concepts::SliceableMatrix compatible with C++17.
Definition concepts.hpp:899
void laset(uplo_t uplo, const type_t< matrix_t > &alpha, const type_t< matrix_t > &beta, matrix_t &A)
Initializes a matrix to diagonal and off-diagonal values.
Definition laset.hpp:38
void lacpy(uplo_t uplo, const matrixA_t &A, matrixB_t &B)
Copies a matrix from A to B.
Definition lacpy.hpp:38
void rotg(T &a, T &b, T &c, T &s)
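Construct plane rotation that eliminates b, such that: $\begin{bmatrix} c & s \\ -\bar{s} & c \end{bmatrix} \begin{bmatrix} a \\ b \end{bmatrix} = \begin{bmatrix} r \\ 0 \end{bmatrix}$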
Definition rotg.hpp:39
void rot(vectorX_t &x, vectorY_t &y, const c_type &c, const s_type &s)
Apply plane rotation:
Definition rot.hpp:44
void gemm(Op transA, Op transB, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
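General matrix-matrix multiply: $C := \alpha \, op(A) \, op(B) + \beta C$, where $op(X)$ is $X$, $X^T$ or $X^H$ as selected by transA / transB.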
Definition gemm.hpp:61
void hessenberg_rq(T_t &T, CL_t &cl, SL_t &sl, CR_t &cr, SR_t &sr)
Applies a sequence of rotations to an upper triangular matrix T from the left (making it an upper Hessenberg matrix)...
Definition hessenberg_rq.hpp:49
int rot_sequence(side_t side, direction_t direction, const C_t &c, const S_t &s, A_t &A)
Applies a sequence of plane rotations to an (m-by-n) matrix.
Definition rot_sequence.hpp:81
int gghd3(bool wantq, bool wantz, size_type< A_t > ilo, size_type< A_t > ihi, A_t &A, B_t &B, Q_t &Q, Z_t &Z, const Gghd3Opts &opts={})
Reduces a pair of real square matrices (A, B) to generalized upper Hessenberg form using unitary transformations...
Definition gghd3.hpp:52
#define tlapack_check(cond)
Throw an error if cond is false.
Definition exceptionHandling.hpp:98
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
Options struct for gghd3.
Definition gghd3.hpp:26
size_t nb
Block size.
Definition gghd3.hpp:27