<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
lahr2.hpp
Go to the documentation of this file.
1
//
// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
//
// This file is part of <T>LAPACK.
// <T>LAPACK is free software: you can redistribute it and/or modify it under
// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.

12#ifndef TLAPACK_LAHR2_HH
13#define TLAPACK_LAHR2_HH
14
#include "tlapack/blas/axpy.hpp"
#include "tlapack/blas/copy.hpp"
#include "tlapack/blas/gemm.hpp"
#include "tlapack/blas/gemv.hpp"
#include "tlapack/blas/scal.hpp"
#include "tlapack/blas/trmm.hpp"
#include "tlapack/blas/trmv.hpp"
#include "tlapack/lapack/lacpy.hpp"
#include "tlapack/lapack/larfg.hpp"
25
26namespace tlapack {
27
57template <TLAPACK_SMATRIX matrix_t,
58 TLAPACK_VECTOR vector_t,
59 TLAPACK_SMATRIX matrixT_t,
60 TLAPACK_SMATRIX matrixY_t>
63 matrix_t& A,
65 matrixT_t& T,
66 matrixY_t& Y)
67{
68 using TA = type_t<matrix_t>;
69 using idx_t = size_type<matrix_t>;
71 using real_t = real_type<TA>;
72
73 // constants
74 const real_t one(1);
75 const idx_t n = nrows(A);
76
77 // quick return if possible
78 if (n <= 1) return 0;
79
80 TA ei(0);
81 for (idx_t i = 0; i < nb; ++i) {
82 if (i > 0) {
83 //
84 // Update A(K+1:N,I), this rest will be updated later via
85 // level 3 BLAS.
86 //
87
88 //
89 // Update I-th column of A - Y * V**T
90 // (Application of the reflectors from the right)
91 //
92 auto Y2 = slice(Y, range{k + 1, n}, range{0, i});
93 auto Vti = slice(A, k + i, range{0, i});
94 auto b = slice(A, range{k + 1, n}, i);
95 for (idx_t j = 0; j < i; ++j)
96 Vti[j] = conj(Vti[j]);
97 gemv(NO_TRANS, -one, Y2, Vti, one, b);
98 for (idx_t j = 0; j < i; ++j)
99 Vti[j] = conj(Vti[j]);
100 //
101 // Apply I - V * T**T * V**T to this column (call it b) from the
102 // left, using the last column of T as workspace
103 //
104 // Let V = ( V1 ) and b = ( b1 ) (first i rows)
105 // ( V2 ) ( b2 )
106 //
107 // where V1 is unit lower triangular
108 //
109 auto b1 = slice(b, range{0, i});
110 auto b2 = slice(b, range{i, size(b)});
111 auto V = slice(A, range{k + 1, n}, range{0, i});
112 auto V1 = slice(V, range{0, i}, range{0, i});
113 auto V2 = slice(V, range{i, nrows(V)}, range{0, i});
114 //
115 // w := V1**T * b1
116 //
117 auto w = slice(T, range{0, i}, nb - 1);
118 copy(b1, w);
120 //
121 // w := w + V2**T * b2
122 //
123 gemv(CONJ_TRANS, one, V2, b2, one, w);
124 //
125 // w := T**T * w
126 //
127 auto T2 = slice(T, range{0, i}, range{0, i});
129 //
130 // b2 := b2 - V2*w
131 //
132 gemv(NO_TRANS, -one, V2, w, one, b2);
133 //
134 // b1 := b1 - V1*w
135 //
137 axpy(-one, w, b1);
138
139 A(k + i, i - 1) = ei;
140 }
141 auto v = slice(A, range{k + i + 1, n}, i);
143
144 // larf has been edited to not require A(k+i,i) = one
145 // this is for thread safety. Since we already modified
146 // A(k+i,i) before, this is not required here
147 ei = v[0];
148 v[0] = one;
149 //
150 // Compute Y(K+1:N,I)
151 //
152 auto A2 = slice(A, range{k + 1, n}, range{i + 1, n - k});
153 auto y = slice(Y, range{k + 1, n}, i);
154 gemv(NO_TRANS, one, A2, v, y);
155 auto t = slice(T, range{0, i}, i);
156 auto A3 = slice(A, range{k + i + 1, n}, range{0, i});
157 gemv(CONJ_TRANS, one, A3, v, t);
158 auto Y2 = slice(Y, range{k + 1, n}, range{0, i});
159 gemv(NO_TRANS, -one, Y2, t, one, y);
160 scal(tau[i], y);
161 //
162 // Compute T(0:I+1,I)
163 //
164 scal(-tau[i], t);
165 auto T2 = slice(T, range{0, i}, range{0, i});
167 T(i, i) = tau[i];
168 }
169 A(k + nb, nb - 1) = ei;
170 //
171 // Compute Y(0:k+1,0:nb)
172 //
173 auto A4 = slice(A, range{0, k + 1}, range{1, nb + 1});
174 auto Y3 = slice(Y, range{0, k + 1}, range{0, nb});
175 lacpy(GENERAL, A4, Y3);
176 auto V1 = slice(A, range{k + 1, k + nb + 1}, range{0, nb});
177 auto Y1 = slice(Y, range{0, k + 1}, range{0, nb});
179 if (k + nb + 1 < n) {
180 auto A5 = slice(A, range{0, k + 1}, range{nb + 1, n - k});
181 auto V2 = slice(A, range{k + nb + 1, n}, range{0, nb});
183 }
185
186 return 0;
187}
188
189} // namespace tlapack
190
191#endif // TLAPACK_LAHR2_HH
constexpr internal::LowerTriangle LOWER_TRIANGLE
Lower Triangle access.
Definition types.hpp:183
constexpr internal::UpperTriangle UPPER_TRIANGLE
Upper Triangle access.
Definition types.hpp:181
constexpr internal::RightSide RIGHT_SIDE
right side
Definition types.hpp:291
constexpr internal::Forward FORWARD
Forward direction.
Definition types.hpp:376
constexpr internal::UnitDiagonal UNIT_DIAG
The main diagonal is assumed to consist of 1's.
Definition types.hpp:217
constexpr internal::GeneralAccess GENERAL
General access.
Definition types.hpp:175
constexpr internal::NonUnitDiagonal NON_UNIT_DIAG
The main diagonal is not assumed to consist of 1's.
Definition types.hpp:215
constexpr internal::ConjTranspose CONJ_TRANS
conjugate transpose
Definition types.hpp:259
constexpr internal::ColumnwiseStorage COLUMNWISE_STORAGE
Columnwise storage.
Definition types.hpp:409
constexpr internal::NoTranspose NO_TRANS
no transpose
Definition types.hpp:255
constexpr T conj(const T &x) noexcept
Extends std::conj() to real datatypes.
Definition utils.hpp:100
#define TLAPACK_SMATRIX
Macro for tlapack::concepts::SliceableMatrix compatible with C++17.
Definition concepts.hpp:899
#define TLAPACK_VECTOR
Macro for tlapack::concepts::Vector compatible with C++17.
Definition concepts.hpp:906
int lahr2(size_type< matrix_t > k, size_type< matrix_t > nb, matrix_t &A, vector_t &tau, matrixT_t &T, matrixY_t &Y)
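Reduces the first nb columns of a general matrix towards upper Hessenberg form, returning the factors T and Y needed for a blocked update of the remaining columns.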
Definition lahr2.hpp:61
void larfg(storage_t storeMode, type_t< vector_t > &alpha, vector_t &x, type_t< vector_t > &tau)
Generates an elementary Householder reflection.
Definition larfg.hpp:73
void lacpy(uplo_t uplo, const matrixA_t &A, matrixB_t &B)
Copies a matrix from A to B.
Definition lacpy.hpp:38
void copy(const vectorX_t &x, vectorY_t &y)
Copy vector, $y := x$.
Definition copy.hpp:31
void axpy(const alpha_t &alpha, const vectorX_t &x, vectorY_t &y)
Add scaled vector, $y := \alpha x + y$.
Definition axpy.hpp:34
void scal(const alpha_t &alpha, vector_t &x)
Scale vector by constant, $x := \alpha x$.
Definition scal.hpp:30
void gemv(Op trans, const alpha_t &alpha, const matrixA_t &A, const vectorX_t &x, const beta_t &beta, vectorY_t &y)
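General matrix-vector multiply: $y := \alpha \, \mathrm{op}(A) \, x + \beta y$, where $\mathrm{op}(A)$ is $A$, $A^T$ or $A^H$ as selected by trans.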
Definition gemv.hpp:57
void trmv(Uplo uplo, Op trans, Diag diag, const matrixA_t &A, vectorX_t &x)
Triangular matrix-vector multiply: $x := \mathrm{op}(A) \, x$, where $A$ is triangular.
Definition trmv.hpp:60
void trmm(Side side, Uplo uplo, Op trans, Diag diag, const alpha_t &alpha, const matrixA_t &A, matrixB_t &B)
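Triangular matrix-matrix multiply: $B := \alpha \, \mathrm{op}(A) \, B$ (left side) or $B := \alpha \, B \, \mathrm{op}(A)$ (right side), where $A$ is triangular.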
Definition trmm.hpp:72
void gemm(Op transA, Op transB, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
General matrix-matrix multiply: $C := \alpha \, \mathrm{op}(A) \, \mathrm{op}(B) + \beta C$.
Definition gemm.hpp:61
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113