<T>LAPACK 0.1.2
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
pbtrf_with_workspace.hpp
1
4//
5// Copyright (c) 2025, University of Colorado Denver. All rights reserved.
6//
7// This file is part of <T>LAPACK.
8// <T>LAPACK is free software: you can redistribute it and/or modify it under
9// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
10
11#ifndef TLAPACK_PBTRF_WITH_WORKSPACE_HH
12#define TLAPACK_PBTRF_WITH_WORKSPACE_HH
13
14#include "tlapack/blas/trsm.hpp"
18
19namespace tlapack {
20
23 constexpr BlockedBandedCholeskyOpts(const EcOpts& opts = {})
24 : EcOpts(opts){};
25
26 size_t nb = 32; // Block size
27};
51template <typename uplo_t, typename matrix_t>
53 matrix_t& A,
54 size_t kd,
56{
58 using idx_t = tlapack::size_type<matrix_t>;
61
62 // Argument checks. Note: kd is a size_t, so the kd >= 0 check below can never fail.
64 tlapack_check(nrows(A) == ncols(A));
65 tlapack_check(kd < nrows(A));
66 tlapack_check(kd >= 0);
67
69
70 const idx_t nb = opts.nb;
71
72 idx_t n = nrows(A);
73
74 std::vector<T> work_;
75 auto work = new_matrix(work_, nb, nb);
76
78
80 for (idx_t i = 0; i < n; i += nb) {
81 idx_t ib = (n < nb + i) ? n - i : nb;
82
83 auto A00 = slice(A, range(i, min(ib + i, n)),
84 range(i, std::min(i + ib, n)));
85
86 potrf(uplo, A00);
87
88 if (i + ib < n) {
89 // i2 = min(kd - ib, n - i - ib)
90 idx_t i2 = (kd + i < n) ? kd - ib : n - i - ib;
91 // i3 = min(ib, n-i-kd)
92 idx_t i3 = (n > i + kd) ? min<idx_t>(ib, n - i - kd) : 0;
93
94 if (i2 > 0) {
95 auto A01 = slice(A, range(i, ib + i),
96 range(i + ib, std::min(i + ib + i2, n)));
97
100 real_t(1), A00, A01);
101
102 auto A11 =
103 slice(A, range(i + ib, std::min<idx_t>(i + kd, n)),
104 range(i + ib, std::min<idx_t>(i + kd, n)));
105
107 real_t(-1), A01, real_t(1), A11);
108 }
109
110 if (i3 > 0) {
111 auto A02 =
112 slice(A, range(i, i + ib), range(i + kd, i + kd + i3));
113
114 auto work02 = slice(work, range(0, ib), range(0, i3));
115
116 for (idx_t jj = 0; jj < i3; jj++)
117 for (idx_t ii = jj; ii < ib; ++ii)
118 work02(ii, jj) = A02(ii, jj);
119
122 real_t(1), A00, work02);
123
124 auto A12 =
125 slice(A, range(i + ib, i + kd),
126 range(i + kd, std::min<idx_t>(i + kd + i3, n)));
127
128 auto A01 =
129 slice(A, range(i, ib + i),
130 range(i + ib, std::min<idx_t>(i + ib + i2, n)));
131
133 real_t(-1), A01, work02, real_t(1), A12);
134
135 auto A22 =
136 slice(A, range(i + kd, std::min<idx_t>(i + kd + i3, n)),
137 range(i + kd, std::min<idx_t>(i + kd + i3, n)));
138
140 real_t(-1), work02, real_t(1), A22);
141
142 for (idx_t jj = 0; jj < i3; ++jj) {
143 for (idx_t ii = jj; ii < ib; ++ii) {
144 A02(ii, jj) = work02(ii, jj);
145 }
146 }
147 }
148 }
149 }
150 }
151 else { // uplo == Lower
152
153 for (idx_t i = 0; i < n; i += nb) {
154 idx_t ib = (nb + i < n) ? ib = nb : n - i;
155
156 auto A00 =
157 slice(A, range(i, i + ib), range(i, std::min(ib + i, n)));
158
160
161 if (i + ib <= n) {
162 // i2 = min(kd - ib, n - i - ib)
163 idx_t i2 = (kd + i < n) ? kd - ib : n - i - ib;
164 // i3 = min(ib, n-i-kd)
165 idx_t i3 = (n > i + kd) ? min<idx_t>(ib, n - i - kd) : 0;
166
167 if (i2 > 0) {
168 auto A10 =
169 slice(A, range(ib + i, ib + i2 + i), range(i, ib + i));
170
173 real_t(1), A00, A10);
174
175 auto A11 = slice(A, range(ib + i, ib + i2 + i),
176 range(i + ib, i + ib + i2));
177
179 A11);
180 }
181
182 if (i3 > 0) {
183 auto A10 =
184 slice(A, range(ib + i, ib + i2 + i), range(i, ib + i));
185
186 auto A20 =
187 slice(A, range(kd + i, min<idx_t>(kd + i3 + i, n)),
188 range(i, i + ib));
189
190 auto work20 = slice(work, range(0, i3), range(0, ib));
191
192 for (idx_t jj = 0; jj < ib; jj++) {
193 idx_t iiend = min(jj + 1, i3);
194 for (idx_t ii = 0; ii < iiend; ++ii) {
195 work20(ii, jj) = A20(ii, jj);
196 }
197 }
198
201
202 auto A21 = slice(A, range(kd + i, kd + i + i3),
203 range(i + ib, i + ib + i2));
204
206 real_t(-1), work20, A10, real_t(1), A21);
207
208 auto A22 = slice(A, range(kd + i, kd + i + i3),
209 range(kd + i, kd + i + i3));
210
212 real_t(1), A22);
213
214 for (idx_t jj = 0; jj < ib; jj++) {
215 idx_t iiend = min(jj + 1, i3);
216 for (idx_t ii = 0; ii < iiend; ++ii) {
217 A20(ii, jj) = work20(ii, jj);
218 }
219 }
220 }
221 }
222 }
223 }
224}
225
226} // namespace tlapack
227
228#endif // TLAPACK_PBTRF_WITH_WORKSPACE_HH
void laset(uplo_t uplo, const type_t< matrix_t > &alpha, const type_t< matrix_t > &beta, matrix_t &A)
Initializes a matrix to diagonal and off-diagonal values.
Definition laset.hpp:38
void pbtrf_with_workspace(uplo_t uplo, matrix_t &A, size_t kd, const BlockedBandedCholeskyOpts &opts={})
Cholesky factorization of a full, banded, n by n matrix.
Definition pbtrf_with_workspace.hpp:52
void herk(Uplo uplo, Op trans, const alpha_t &alpha, const matrixA_t &A, const beta_t &beta, matrixC_t &C)
Hermitian rank-k update: $C := \alpha A A^H + \beta C$ or $C := \alpha A^H A + \beta C$, depending on trans.
Definition herk.hpp:68
void gemm(Op transA, Op transB, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
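General matrix-matrix multiply: $C := \alpha\, op(A)\, op(B) + \beta C$, where $op(X)$ is $X$, $X^T$, or $X^H$ as selected by transA and transB.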
Definition gemm.hpp:61
void trsm(Side side, Uplo uplo, Op trans, Diag diag, const alpha_t &alpha, const matrixA_t &A, matrixB_t &B)
Solve the triangular matrix-matrix equation.
Definition trsm.hpp:76
int potf2(uplo_t uplo, matrix_t &A)
Computes the Cholesky factorization of a Hermitian positive definite matrix A using a level-2 algorithm.
Definition potf2.hpp:58
#define tlapack_check(cond)
Throw an error if cond is false.
Definition exceptionHandling.hpp:98
int potrf(uplo_t uplo, matrix_t &A, const PotrfOpts &opts={})
Computes the Cholesky factorization of a Hermitian positive definite matrix A.
Definition potrf.hpp:78
Computes the Cholesky factorization of a Hermitian positive definite matrix A.
Sort the numbers in D in increasing order (if ID = 'I') or in decreasing order (if ID = 'D' ).
Definition arrayTraits.hpp:15
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
@ NonUnit
The main diagonal is not assumed to consist of 1's.
@ Right
right side
@ Left
left side
@ NoTrans
no transpose
@ ConjTrans
conjugate transpose
@ General
0 <= i <= m, 0 <= j <= n.
@ Upper
0 <= i <= j, 0 <= j <= n.
@ Lower
0 <= i <= m, 0 <= j <= i.
Options struct for pbtrf_with_workspace()
Definition pbtrf_with_workspace.hpp:22
Options for error checking.
Definition exceptionHandling.hpp:76