<T>LAPACK 0.1.2
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
pbtrf_with_workspace.hpp
1
4//
5// Copyright (c) 2025, University of Colorado Denver. All rights reserved.
6//
7// This file is part of <T>LAPACK.
8// <T>LAPACK is free software: you can redistribute it and/or modify it under
9// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
10
11#ifndef TLAPACK_PBTRF_WITH_WORKSPACE_HH
12#define TLAPACK_PBTRF_WITH_WORKSPACE_HH
13
14#include "tlapack/blas/trsm.hpp"
18
19namespace tlapack {
20
23 constexpr BlockedBandedCholeskyOpts(const EcOpts& opts = {})
24 : EcOpts(opts){};
25
26 size_t nb = 32; // Block size
27};
51template <typename uplo_t, typename matrix_t>
53 matrix_t& A,
54 size_t kd,
56{
58 using idx_t = tlapack::size_type<matrix_t>;
61
62 // Argument checks. kd is a size_t, so the kd >= 0 check below always holds.
64 tlapack_check(nrows(A) == ncols(A));
65 tlapack_check(kd < nrows(A));
66 tlapack_check(kd >= 0);
67
69
70 const idx_t nb = opts.nb;
71
72 idx_t n = nrows(A);
73
74 std::vector<T> work_;
75 auto work = new_matrix(work_, nb, nb);
76
78
80 for (idx_t i = 0; i < n; i += nb) {
81 idx_t ib = (n < nb + i) ? n - i : nb;
82
83 auto A00 = slice(A, range(i, min(ib + i, n)),
84 range(i, std::min(i + ib, n)));
85
86 potrf(uplo, A00);
87
88 if (i + ib < n) {
89 // i2 = min(kd - ib, n - i - ib)
90 idx_t i2 = (kd + i < n) ? kd - ib : n - i - ib;
91 // i3 = min(ib, n-i-kd)
92 idx_t i3 = (n > i + kd) ? min(ib, n - i - kd) : 0;
93
94 if (i2 > 0) {
95 auto A01 = slice(A, range(i, ib + i),
96 range(i + ib, std::min(i + ib + i2, n)));
97
100 real_t(1), A00, A01);
101
102 auto A11 = slice(A, range(i + ib, std::min(i + kd, n)),
103 range(i + ib, std::min(i + kd, n)));
104
106 real_t(-1), A01, real_t(1), A11);
107 }
108
109 if (i3 > 0) {
110 auto A02 =
111 slice(A, range(i, i + ib), range(i + kd, i + kd + i3));
112
113 auto work02 = slice(work, range(0, ib), range(0, i3));
114
115 for (idx_t jj = 0; jj < i3; jj++)
116 for (idx_t ii = jj; ii < ib; ++ii)
117 work02(ii, jj) = A02(ii, jj);
118
121 real_t(1), A00, work02);
122
123 auto A12 = slice(A, range(i + ib, i + kd),
124 range(i + kd, std::min(i + kd + i3, n)));
125
126 auto A01 = slice(A, range(i, ib + i),
127 range(i + ib, std::min(i + ib + i2, n)));
128
130 real_t(-1), A01, work02, real_t(1), A12);
131
132 auto A22 = slice(A, range(i + kd, std::min(i + kd + i3, n)),
133 range(i + kd, std::min(i + kd + i3, n)));
134
136 real_t(-1), work02, real_t(1), A22);
137
138 for (idx_t jj = 0; jj < i3; ++jj) {
139 for (idx_t ii = jj; ii < ib; ++ii) {
140 A02(ii, jj) = work02(ii, jj);
141 }
142 }
143 }
144 }
145 }
146 }
147 else { // uplo == Lower
148
149 for (idx_t i = 0; i < n; i += nb) {
150 idx_t ib = (nb + i < n) ? ib = nb : n - i;
151
152 auto A00 =
153 slice(A, range(i, i + ib), range(i, std::min(ib + i, n)));
154
156
157 if (i + ib <= n) {
158 // i2 = min(kd - ib, n - i - ib)
159 idx_t i2 = (kd + i < n) ? kd - ib : n - i - ib;
160 // i3 = min(ib, n-i-kd)
161 idx_t i3 = (n > i + kd) ? min(ib, n - i - kd) : 0;
162
163 if (i2 > 0) {
164 auto A10 =
165 slice(A, range(ib + i, ib + i2 + i), range(i, ib + i));
166
169 real_t(1), A00, A10);
170
171 auto A11 = slice(A, range(ib + i, ib + i2 + i),
172 range(i + ib, i + ib + i2));
173
175 A11);
176 }
177
178 if (i3 > 0) {
179 auto A10 =
180 slice(A, range(ib + i, ib + i2 + i), range(i, ib + i));
181
182 auto A20 = slice(A, range(kd + i, min(kd + i3 + i, n)),
183 range(i, i + ib));
184
185 auto work20 = slice(work, range(0, i3), range(0, ib));
186
187 for (idx_t jj = 0; jj < ib; jj++) {
188 idx_t iiend = min(jj + 1, i3);
189 for (idx_t ii = 0; ii < iiend; ++ii) {
190 work20(ii, jj) = A20(ii, jj);
191 }
192 }
193
196
197 auto A21 = slice(A, range(kd + i, kd + i + i3),
198 range(i + ib, i + ib + i2));
199
201 real_t(-1), work20, A10, real_t(1), A21);
202
203 auto A22 = slice(A, range(kd + i, kd + i + i3),
204 range(kd + i, kd + i + i3));
205
207 real_t(1), A22);
208
209 for (idx_t jj = 0; jj < ib; jj++) {
210 idx_t iiend = min(jj + 1, i3);
211 for (idx_t ii = 0; ii < iiend; ++ii) {
212 A20(ii, jj) = work20(ii, jj);
213 }
214 }
215 }
216 }
217 }
218 }
219}
220
221} // namespace tlapack
222
223#endif // TLAPACK_PBTRF_WITH_WORKSPACE_HH
void laset(uplo_t uplo, const type_t< matrix_t > &alpha, const type_t< matrix_t > &beta, matrix_t &A)
Initializes a matrix to diagonal and off-diagonal values.
Definition laset.hpp:38
void pbtrf_with_workspace(uplo_t uplo, matrix_t &A, size_t kd, const BlockedBandedCholeskyOpts &opts={})
Cholesky factorization of a full, banded, n by n matrix.
Definition pbtrf_with_workspace.hpp:52
void herk(Uplo uplo, Op trans, const alpha_t &alpha, const matrixA_t &A, const beta_t &beta, matrixC_t &C)
Hermitian rank-k update: $C := \alpha A A^H + \beta C$ or $C := \alpha A^H A + \beta C$, depending on trans.
Definition herk.hpp:68
void gemm(Op transA, Op transB, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
General matrix-matrix multiply: $C := \alpha \, op(A) \, op(B) + \beta C$.
Definition gemm.hpp:61
void trsm(Side side, Uplo uplo, Op trans, Diag diag, const alpha_t &alpha, const matrixA_t &A, matrixB_t &B)
Solve the triangular matrix-matrix equation.
Definition trsm.hpp:76
int potf2(uplo_t uplo, matrix_t &A)
Computes the Cholesky factorization of a Hermitian positive definite matrix A using a level-2 algorithm.
Definition potf2.hpp:58
#define tlapack_check(cond)
Throw an error if cond is false.
Definition exceptionHandling.hpp:98
int potrf(uplo_t uplo, matrix_t &A, const PotrfOpts &opts={})
Computes the Cholesky factorization of a Hermitian positive definite matrix A.
Definition potrf.hpp:78
Computes the Cholesky factorization of a Hermitian positive definite matrix A.
Sort the numbers in D in increasing order (if ID = 'I') or in decreasing order (if ID = 'D' ).
Definition arrayTraits.hpp:15
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
@ NonUnit
The main diagonal is not assumed to consist of 1's.
@ Right
right side
@ Left
left side
@ NoTrans
no transpose
@ ConjTrans
conjugate transpose
@ General
0 <= i <= m, 0 <= j <= n.
@ Upper
0 <= i <= j, 0 <= j <= n.
@ Lower
0 <= i <= m, 0 <= j <= i.
Options struct for pbtrf_with_workspace()
Definition pbtrf_with_workspace.hpp:22
Options for error checking.
Definition exceptionHandling.hpp:76