<T>LAPACK 0.1.2
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
geqrt3.hpp
Go to the documentation of this file.
1
7// Copyright (c) 2025, University of Colorado Denver. All rights reserved.
8//
9// This file is part of <T>LAPACK.
10// <T>LAPACK is free software: you can redistribute it and/or modify it under
11// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
12
13#ifndef TLAPACK_GEQRT3_HH
14#define TLAPACK_GEQRT3_HH
15
16#include "tlapack/blas/gemm.hpp"
17#include "tlapack/blas/trmm.hpp"
20
35namespace tlapack {
/**
 * Recursive QR factorization of A that also fills the triangular factor
 * Tmatrix of a compact WY Householder representation.
 *
 * Structure visible in this listing:
 *  - m < n: prints "Error: m < n" to std::cout and returns; A and Tmatrix
 *    are left untouched and no status is returned to the caller.
 *  - n == 1: column 0 of A is taken as a vector and, per the comment, an
 *    elementary reflector is generated into Tmatrix.
 *  - otherwise: the columns are split into a left panel of n1 = n / 2
 *    columns and a right panel of n2 = n - n1 columns. The left panel is
 *    factorized recursively, the right panel is updated (steps 2-8), the
 *    trailing part of the right panel is factorized recursively (step 9),
 *    and the off-diagonal block T12 is assembled (steps 10-14).
 *
 * NOTE(review): this listing is incomplete. The embedded source numbering
 * skips lines 42, 64, 92, 95-96, 107, 117, 140, 147 and 151, so several
 * statements are missing or only partly visible. Notes below mark each one.
 *
 * @tparam matrix_a Type of A, constrained by TLAPACK_MATRIX.
 * @tparam matrix_h Type of Tmatrix, constrained by TLAPACK_MATRIX.
 *
 * @param[in,out] A m-by-n matrix with m >= n, updated in place.
 *      Presumably R ends up in the upper triangle and the Householder
 *      vectors below the diagonal, as in LAPACK xGEQRT3 -- TODO confirm.
 * @param[in,out] Tmatrix Only its leading n-by-n part is addressed here,
 *      through the blocks T11, T12 and T22. The block below T11 is not
 *      referenced in the visible code.
 */
36template <TLAPACK_MATRIX matrix_a, TLAPACK_MATRIX matrix_h>
37
38void geqrt3(matrix_a& A, matrix_h& Tmatrix)
39{
40 using std::size_t;
41 using idx_t = size_type<matrix_a>;
// NOTE(review): `range` is used by the slices below, but no declaration of
// it appears in the visible lines (source line 42 is absent).
43 using T = type_t<matrix_a>;
44
45 // constants: problem dimensions
46 const idx_t m = nrows(A);
47 const idx_t n = ncols(A);
48
// Argument check: the factorization requires at least as many rows as
// columns. The failure is only reported on std::cout; `info` is a local
// flag used to trigger the early return below.
49 auto info = 0;
50 if (m < n) {
51 std::cout << "Error: m < n" << std::endl;
52 info = -1;
53 }
54
55 if (info != 0) {
56 return;
57 }
58
// NOTE(review): n == 0 is not handled separately. It passes the m < n check
// and falls into the `else` branch with n1 == 0, where the first recursive
// call receives a zero-column A1 and would take the same path again.
// Confirm whether callers guarantee n >= 1.
59 if (n == 1) {
60 // Base case: turn the single column into a vector
61 auto a_vector = col(A, 0);
62
63 // Populate matrix T with an elementary reflector
// NOTE(review): the statement that generates the reflector (source line 64)
// is absent from this listing.
65 }
66 else {
67 // Define slice sizes: n1/n2 are the column counts of the left/right
68 idx_t n1 = n / 2;
69 idx_t n2 = n - n1;
// panels; m1, m2 and m3 are row offsets. m2 == n and m3 == m, so the rows
// split into [0, n1), [n1, n) and [n, m).
70 idx_t m1 = n1;
71 idx_t m2 = n2 + n1;
72 idx_t m3 = m;
73
74 // slices
// Block layout addressed below (rows x columns):
//   A1     : m      x n1   whole left panel
//   A11    : n1     x n1   A12 : n1     x n2
//   A21    : n2     x n1   A22 : n2     x n2
//   A31    : (m-n)  x n1   A32 : (m-n)  x n2
//   A22_32 : (m-n1) x n2   A22 stacked on A32
//   T11    : n1 x n1, T12 : n1 x n2, T22 : n2 x n2
75 auto A1 = slice(A, range(0, m), range(0, n1));
76 auto A11 = slice(A, range(0, m1), range(0, n1));
77 auto A12 = slice(A, range(0, m1), range(n1, n));
78 auto A21 = slice(A, range(m1, m2), range(0, n1));
79 auto A22 = slice(A, range(m1, m2), range(n1, n));
80 auto A22_32 = slice(A, range(m1, m3), range(n1, n));
81 auto A31 = slice(A, range(m2, m3), range(0, n1));
82 auto A32 = slice(A, range(m2, m3), range(n1, n));
83 auto T11 = slice(Tmatrix, range(0, n1), range(0, n1));
84 auto T12 = slice(Tmatrix, range(0, n1), range(n1, n));
85 auto T22 = slice(Tmatrix, range(n1, n), range(n1, n));
86
87 // step 1: Compute the QR factorization of A1 (recursive call, fills T11)
88 geqrt3(A1, T11);
89
90 // step 2: Copy A12 into T12
91 // no additional flops, just copy
// NOTE(review): the copy statement itself (source line 92) is absent from
// this listing.
93
94 // step 3: T12 = A11ᴴ * T12
95
// NOTE(review): only the tail of this call is visible. The line that opens
// it (callee and leading arguments, source line 96) is absent.
97 T12);
98
99 // step 4: T12 = T12 + (A21ᴴ * A22)
100
101 gemm(Op::ConjTrans, Op::NoTrans, T(1.0), A21, A22, T(1.0), T12);
102
103 // T12 = T12 + (A31ᴴ * A32)
104 gemm(Op::ConjTrans, Op::NoTrans, T(1.0), A31, A32, T(1.0), T12);
105
106 // step 5: T12 = T11ᴴ * T12
// NOTE(review): the opening line of this call (source line 107) is absent.
108 T12);
109
110 // step 6: A22 = A22 - (A21 * T12)
111 gemm(Op::NoTrans, Op::NoTrans, T(-1.0), A21, T12, T(1.0), A22);
112
113 // A32 = A32 - (A31 * T12)
114 gemm(Op::NoTrans, Op::NoTrans, T(-1.0), A31, T12, T(1.0), A32);
115
116 // step 7: T12 = A11 * T12
// NOTE(review): the opening line of this call (source line 117) is absent.
118 T12);
119
120 // step 8: A12 = A12 - T12 (element-wise over the n1-by-n2 block)
121 for (idx_t j = 0; j < n2; ++j) {
122 for (idx_t i = 0; i < m1; ++i) {
123 A12(i, j) -= T12(i, j);
124 }
125 }
126 // step 9: Compute the QR factorization of A22_32 (recursive call, fills T22)
127 geqrt3(A22_32, T22);
128
129 // step 10: manually compute T12 = A21ᴴ
// T12(i, j) takes the conjugate of A21(j, i) for complex T and the plain
// transposed entry otherwise; the previous contents of T12 are overwritten.
130 for (idx_t j = 0; j < n2; ++j) {
131 for (idx_t i = 0; i < m1; ++i) {
132 if constexpr (is_complex<T>)
133 T12(i, j) = std::conj(A21(j, i));
134 else
135 T12(i, j) = A21(j, i);
136 }
137 }
138
139 // step 11: T12 = T12 * T22ᴴ
// NOTE(review): the opening line of this call (source line 140) is absent,
// so the operand and operation named in this comment cannot be checked.
141 T12);
142
143 // step 12: T12 = T12 + A31ᴴ * A32
144 gemm(Op::ConjTrans, Op::NoTrans, T(1.0), A31, A32, T(1.0), T12);
145
146 // step 13: T12 = T12 * T11
// NOTE(review): the opening line of this call (source line 147) is absent.
// T12 is n1-by-n2 and T11 is n1-by-n1, so a right-multiplication by T11 is
// only conformable when n1 == n2. The operand order stated in this comment
// may not match the actual call -- confirm against the full source.
148 T12);
149
150 // step 14: T12 = T12 * T22
// NOTE(review): the opening line of this call (source line 151) is absent.
152 T12);
153 }
154}
155} // namespace tlapack
156#endif // TLAPACK_GEQRT3_HH
void larfg(storage_t storeMode, type_t< vector_t > &alpha, vector_t &x, type_t< vector_t > &tau)
Generates an elementary Householder reflection.
Definition larfg.hpp:73
void lacpy(uplo_t uplo, const matrixA_t &A, matrixB_t &B)
Copies a matrix from A to B.
Definition lacpy.hpp:38
void trmm(Side side, Uplo uplo, Op trans, Diag diag, const alpha_t &alpha, const matrixA_t &A, matrixB_t &B)
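Triangular matrix-matrix multiply: $B := \alpha\, op(A)\, B$ (side = Left) or $B := \alpha\, B\, op(A)$ (side = Right), where $A$ is triangular and $op(A)$ is $A$, $A^T$ or $A^H$.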
Definition trmm.hpp:72
void gemm(Op transA, Op transB, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
General matrix-matrix multiply: $C := \alpha\, op(A)\, op(B) + \beta\, C$, where $op(X)$ is $X$, $X^T$ or $X^H$.
Definition gemm.hpp:61
Recursive QR factorization using compact WY Householder representation.
Definition arrayTraits.hpp:15
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
@ Unit
The main diagonal is assumed to consist of 1's.
@ NonUnit
The main diagonal is not assumed to consist of 1's.
@ Right
right side
@ Left
left side
@ Forward
Forward direction.
@ NoTrans
no transpose
@ ConjTrans
conjugate transpose
@ Columnwise
Columnwise storage.
@ General
0 <= i <= m, 0 <= j <= n.
@ Upper
0 <= i <= j, 0 <= j <= n.
@ Lower
0 <= i <= m, 0 <= j <= i.