<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
larfb.hpp
Go to the documentation of this file.
1
5//
6// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
7//
8// This file is part of <T>LAPACK.
9// <T>LAPACK is free software: you can redistribute it and/or modify it under
10// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
11
12#ifndef TLAPACK_LARFB_HH
13#define TLAPACK_LARFB_HH
14
16#include "tlapack/blas/gemm.hpp"
17#include "tlapack/blas/trmm.hpp"
19
20namespace tlapack {
21
/**
 * @brief Workspace query for larfb().
 *
 * With m = nrows(C), n = ncols(C) and k = nrows(Tmatrix), returns
 * WorkInfo(k, n) when side == Side::Left and WorkInfo(m, k) otherwise.
 * When T is not the same type as type_t<work_t>, returns WorkInfo(0).
 *
 * NOTE(review): this listing was extracted from a rendered documentation
 * page. The declarator line, the `side`, `trans` and `storeMode` parameters
 * (doxygen lines 73, 74, 76) and the declaration of `work_t` (line 82) are
 * missing here, so the block does not compile as shown.
 */
65template <class T,
66 TLAPACK_SMATRIX matrixV_t,
67 TLAPACK_MATRIX matrixT_t,
68 TLAPACK_SMATRIX matrixC_t,
69 TLAPACK_SIDE side_t,
70 TLAPACK_OP trans_t,
71 TLAPACK_DIRECTION direction_t,
72 TLAPACK_STOREV storage_t>
75 direction_t direction,
77 const matrixV_t& V,
78 const matrixT_t& Tmatrix,
79 const matrixC_t& C)
80{
81 using idx_t = size_type<matrixC_t>;
83
84 // constants
85 const idx_t m = nrows(C);
86 const idx_t n = ncols(C);
87 const idx_t k = nrows(Tmatrix);
88
// Only a query for scalar type type_t<work_t> reports a non-empty size:
// k-by-n for the left side, m-by-k for the right side.
89 if constexpr (is_same_v<T, type_t<work_t>>)
90 return (side == Side::Left) ? WorkInfo(k, n) : WorkInfo(m, k);
91 else
92 return WorkInfo(0);
93}
94
/**
 * @brief Applies the block reflector described by V and Tmatrix to the
 * m-by-n matrix C, from the left or the right, using caller-provided
 * workspace. `trans` selects whether op(Tmatrix) is Tmatrix, its transpose,
 * or its conjugate transpose.
 *
 * NOTE(review): this listing was extracted from a rendered documentation
 * page. Several source lines are missing: the declarator line and the
 * `trans`/`storeMode` parameters (doxygen lines 111, 112, 114), the opening
 * lines of two argument checks (134, 141), and the first line of most
 * trmm/gemm calls. The block therefore does not compile as shown.
 *
 * @param side      Side::Left or Side::Right; anything else is rejected by
 *                  tlapack_check_false.
 * @param trans     Op::NoTrans or Op::ConjTrans. Op::Trans is accepted only
 *                  when type_t<matrixV_t> is not complex.
 * @param direction Direction::Forward or Direction::Backward.
 * @param storeMode StoreV::Columnwise (V has k columns) or StoreV::Rowwise
 *                  (V has k rows).
 * @param V         Expected to be m-by-k (columnwise) or k-by-m (rowwise)
 *                  for Side::Left, and n-by-k or k-by-n for Side::Right.
 * @param Tmatrix   Square matrix; k = nrows(Tmatrix).
 * @param C         m-by-n matrix, overwritten in place.
 * @param work      Workspace from which W is reshaped: k-by-n for
 *                  Side::Left, m-by-k for Side::Right.
 * @return 0, both on quick return and on completion.
 */
103template <TLAPACK_SMATRIX matrixV_t,
104 TLAPACK_MATRIX matrixT_t,
105 TLAPACK_SMATRIX matrixC_t,
106 TLAPACK_WORKSPACE work_t,
107 TLAPACK_SIDE side_t,
108 TLAPACK_OP trans_t,
109 TLAPACK_DIRECTION direction_t,
110 TLAPACK_STOREV storage_t>
113 direction_t direction,
115 const matrixV_t& V,
116 const matrixT_t& Tmatrix,
117 matrixC_t& C,
118 work_t& work)
119{
120 using idx_t = size_type<matrixC_t>;
121 using T = type_t<work_t>;
122 using real_t = real_type<T>;
123
125
126 // constants
127 const real_t one(1);
128 const idx_t m = nrows(C);
129 const idx_t n = ncols(C);
130 const idx_t k = nrows(Tmatrix);
131
132 // check arguments
133 tlapack_check_false(side != Side::Left && side != Side::Right);
135 trans != Op::NoTrans && trans != Op::ConjTrans &&
136 ((trans != Op::Trans) || is_complex<type_t<matrixV_t>>));
137 tlapack_check_false(direction != Direction::Backward &&
138 direction != Direction::Forward);
139 tlapack_check_false(storeMode != StoreV::Columnwise &&
140 storeMode != StoreV::Rowwise);
// NOTE(review): in the dimension check below, `&&` binds tighter than `?:`.
// So `(ncols(V) == k) && (side == Side::Left) ? A : B` parses as
// `((ncols(V) == k) && (side == Side::Left)) ? A : B`. Two consequences:
//   - For Side::Right the k-dimension of V is never checked.
//   - For Side::Left, a V with the wrong k passes whenever its other
//     dimension equals n.
// The Rowwise branch has the same problem. The intended form is presumably
// `(ncols(V) == k) && ((side == Side::Left) ? (nrows(V) == m)
//                                            : (nrows(V) == n))`.
142 (storeMode == StoreV::Columnwise)
143 ? ((ncols(V) == k) && (side == Side::Left) ? (nrows(V) == m)
144 : (nrows(V) == n))
145 : ((nrows(V) == k) && (side == Side::Left) ? (ncols(V) == m)
146 : (ncols(V) == n)));
147 tlapack_check(nrows(Tmatrix) == ncols(Tmatrix));
148
149 // Quick return: C is empty or the reflector block is empty (k == 0)
150 if (m <= 0 || n <= 0 || k <= 0) return 0;
151
152 // Matrix W: k-by-n for Side::Left, m-by-k for Side::Right, taken from
// work. work1 (the remainder) is not referenced in the visible lines.
153 auto [W, work1] =
154 (side == Side::Left) ? reshape(work, k, n) : reshape(work, m, k);
155
// Every branch below follows the same steps, spelled out in its comments:
// 1. Copy into W the k-slice of C that pairs with the triangular k-slice
//    of V: the leading k rows/columns for Forward, the trailing k for
//    Backward.
// 2. Multiply W by that slice of V.
// 3. If the rest of C is non-empty (m > k or n > k), accumulate its
//    contribution with a gemm.
// 4. Apply op(Tmatrix).
// 5. Update the rest of C.
// 6. Add the remaining correction back into the k-slice.
156 if (storeMode == StoreV::Columnwise) {
157 if (direction == Direction::Forward) {
158 if (side == Side::Left) {
159 // W is a k-by-n matrix
160 // V is an m-by-k matrix
161
162 // Matrix views
163 const auto V1 = rows(V, range{0, k});
164 const auto V2 = rows(V, range{k, m});
165 auto C1 = rows(C, range{0, k});
166 auto C2 = rows(C, range{k, m});
167
168 // W := C1
169 lacpy(GENERAL, C1, W);
170 // W := V1^H W
172 W);
173 if (m > k)
174 // W := W + V2^H C2
176 // W := op(Tmatrix) W
178 Tmatrix, W);
179 if (m > k)
180 // C2 := C2 - V2 W
181 gemm(NO_TRANS, NO_TRANS, -one, V2, W, one, C2);
182 // W := - V1 W
184 W);
185
186 // C1 := C1 + W
187 for (idx_t j = 0; j < n; ++j)
188 for (idx_t i = 0; i < k; ++i)
189 C1(i, j) += W(i, j);
190 }
191 else { // side == Side::Right
192 // W is an m-by-k matrix
193 // V is an n-by-k matrix
194
195 // Matrix views
196 const auto V1 = rows(V, range{0, k});
197 const auto V2 = rows(V, range{k, n});
198 auto C1 = cols(C, range{0, k});
199 auto C2 = cols(C, range{k, n});
200
201 // W := C1
202 lacpy(GENERAL, C1, W);
203 // W := W V1
205 W);
206 if (n > k)
207 // W := W + C2 V2
209 // W := W op(Tmatrix)
211 Tmatrix, W);
212 if (n > k)
213 // C2 := C2 - W V2^H
215 // W := - W V1^H
217 V1, W);
218
219 // C1 := C1 + W
220 for (idx_t j = 0; j < k; ++j)
221 for (idx_t i = 0; i < m; ++i)
222 C1(i, j) += W(i, j);
223 }
224 }
225 else { // direction == Direction::Backward
226 if (side == Side::Left) {
227 // W is a k-by-n matrix
228 // V is an m-by-k matrix
229
230 // Matrix views
231 const auto V1 = rows(V, range{0, m - k});
232 const auto V2 = rows(V, range{m - k, m});
233 auto C1 = rows(C, range{0, m - k});
234 auto C2 = rows(C, range{m - k, m});
235
236 // W := C2
237 lacpy(GENERAL, C2, W);
238 // W := V2^H W
240 W);
241 if (m > k)
242 // W := W + V1^H C1
244 // W := op(Tmatrix) W
246 Tmatrix, W);
247 if (m > k)
248 // C1 := C1 - V1 W
249 gemm(NO_TRANS, NO_TRANS, -one, V1, W, one, C1);
250 // W := - V2 W
252 W);
253
254 // C2 := C2 + W
255 for (idx_t j = 0; j < n; ++j)
256 for (idx_t i = 0; i < k; ++i)
257 C2(i, j) += W(i, j);
258 }
259 else { // side == Side::Right
260 // W is an m-by-k matrix
261 // V is an n-by-k matrix
262
263 // Matrix views
264 const auto V1 = rows(V, range{0, n - k});
265 const auto V2 = rows(V, range{n - k, n});
266 auto C1 = cols(C, range{0, n - k});
267 auto C2 = cols(C, range{n - k, n});
268
269 // W := C2
270 lacpy(GENERAL, C2, W);
271 // W := W V2
273 W);
274 if (n > k)
275 // W := W + C1 V1
277 // W := W op(Tmatrix)
279 Tmatrix, W);
280 if (n > k)
281 // C1 := C1 - W V1^H
283 // W := - W V2^H
285 V2, W);
286
287 // C2 := C2 + W
288 for (idx_t j = 0; j < k; ++j)
289 for (idx_t i = 0; i < m; ++i)
290 C2(i, j) += W(i, j);
291 }
292 }
293 }
294 else { // storeMode == StoreV::Rowwise
295 if (direction == Direction::Forward) {
296 if (side == Side::Left) {
297 // W is a k-by-n matrix
298 // V is a k-by-m matrix
299
300 // Matrix views
301 const auto V1 = cols(V, range{0, k});
302 const auto V2 = cols(V, range{k, m});
303 auto C1 = rows(C, range{0, k});
304 auto C2 = rows(C, range{k, m});
305
306 // W := C1
307 lacpy(GENERAL, C1, W);
308 // W := V1 W
310 W);
311 if (m > k)
312 // W := W + V2 C2
314 // W := op(Tmatrix) W
316 Tmatrix, W);
317 if (m > k)
318 // C2 := C2 - V2^H W
320 // W := - V1^H W
322 W);
323
324 // C1 := C1 + W
325 for (idx_t j = 0; j < n; ++j)
326 for (idx_t i = 0; i < k; ++i)
327 C1(i, j) += W(i, j);
328 }
329 else { // side == Side::Right
330 // W is an m-by-k matrix
331 // V is a k-by-n matrix
332
333 // Matrix views
334 const auto V1 = cols(V, range{0, k});
335 const auto V2 = cols(V, range{k, n});
336 auto C1 = cols(C, range{0, k});
337 auto C2 = cols(C, range{k, n});
338
339 // W := C1
340 lacpy(GENERAL, C1, W);
341 // W := W V1^H
343 W);
344 if (n > k)
345 // W := W + C2 V2^H
347 // W := W op(Tmatrix)
349 Tmatrix, W);
350 if (n > k)
351 // C2 := C2 - W V2
352 gemm(NO_TRANS, NO_TRANS, -one, W, V2, one, C2);
353 // W := - W V1
355 W);
356
357 // C1 := C1 + W
358 for (idx_t j = 0; j < k; ++j)
359 for (idx_t i = 0; i < m; ++i)
360 C1(i, j) += W(i, j);
361 }
362 }
363 else { // direction == Direction::Backward
364 if (side == Side::Left) {
365 // W is a k-by-n matrix
366 // V is a k-by-m matrix
367
368 // Matrix views
369 const auto V1 = cols(V, range{0, m - k});
370 const auto V2 = cols(V, range{m - k, m});
371 auto C1 = rows(C, range{0, m - k});
372 auto C2 = rows(C, range{m - k, m});
373
374 // W := C2
375 lacpy(GENERAL, C2, W);
376 // W := V2 W
378 W);
379 if (m > k)
380 // W := W + V1 C1
382 // W := op(Tmatrix) W
384 Tmatrix, W);
385 if (m > k)
386 // C1 := C1 - V1^H W
388 // W := - V2^H W
390 W);
391
392 // C2 := C2 + W
393 for (idx_t j = 0; j < n; ++j)
394 for (idx_t i = 0; i < k; ++i)
395 C2(i, j) += W(i, j);
396 }
397 else { // side == Side::Right
398 // W is an m-by-k matrix
399 // V is a k-by-n matrix
400
401 // Matrix views
402 const auto V1 = cols(V, range{0, n - k});
403 const auto V2 = cols(V, range{n - k, n});
404 auto C1 = cols(C, range{0, n - k});
405 auto C2 = cols(C, range{n - k, n});
406
407 // W := C2
408 lacpy(GENERAL, C2, W);
409 // W := W V2^H
411 W);
412 if (n > k)
413 // W := W + C1 V1^H
415 // W := W op(Tmatrix)
417 Tmatrix, W);
418 if (n > k)
419 // C1 := C1 - W V1
420 gemm(NO_TRANS, NO_TRANS, -one, W, V1, one, C1);
421 // W := - W V2
423 W);
424
425 // C2 := C2 + W
426 for (idx_t j = 0; j < k; ++j)
427 for (idx_t i = 0; i < m; ++i)
428 C2(i, j) += W(i, j);
429 }
430 }
431 }
432
433 return 0;
434}
435
/**
 * @brief Allocating wrapper around larfb_work().
 *
 * Steps:
 * 1. Validate `side`. All other arguments are validated by larfb_work(),
 *    to which they are forwarded unchanged.
 * 2. Return 0 early when C is empty or k = nrows(Tmatrix) is 0.
 * 3. Allocate a workspace of workinfo.m-by-workinfo.n entries of type T,
 *    backed by a local std::vector<T>.
 * 4. Return the result of larfb_work().
 *
 * NOTE(review): this listing was extracted from a rendered documentation
 * page. The following lines are missing here, so the block does not compile
 * as shown:
 *   - the declarator line and the `side`, `trans` and `storeMode` parameters
 *     (doxygen lines 512, 513, 515);
 *   - the declaration of `work_t` (521);
 *   - the declaration of the `new_matrix` functor (525);
 *   - the computation of `workinfo` (539-540). This is presumably a call to
 *     larfb_worksize<T>(...) — confirm against the full source.
 */
505template <TLAPACK_SMATRIX matrixV_t,
506 TLAPACK_MATRIX matrixT_t,
507 TLAPACK_SMATRIX matrixC_t,
508 TLAPACK_SIDE side_t,
509 TLAPACK_OP trans_t,
510 TLAPACK_DIRECTION direction_t,
511 TLAPACK_STOREV storage_t>
514 direction_t direction,
516 const matrixV_t& V,
517 const matrixT_t& Tmatrix,
518 matrixC_t& C)
519{
520 using idx_t = size_type<matrixC_t>;
522 using T = type_t<work_t>;
523
524 // Functor (used below as new_matrix to wrap the std::vector buffer)
526
527 // constants
528 const idx_t m = nrows(C);
529 const idx_t n = ncols(C);
530 const idx_t k = nrows(Tmatrix);
531
532 // check arguments
533 tlapack_check_false(side != Side::Left && side != Side::Right);
534
535 // Quick return: C is empty or the reflector block is empty (k == 0)
536 if (m <= 0 || n <= 0 || k <= 0) return 0;
537
538 // Allocates workspace: a workinfo.m-by-workinfo.n matrix owned by work_
541 std::vector<T> work_;
542 auto work = new_matrix(work_, workinfo.m, workinfo.n);
543
544 return larfb_work(side, trans, direction, storeMode, V, Tmatrix, C, work);
545}
546
547} // namespace tlapack
548
549#endif // TLAPACK_LARFB_HH
constexpr internal::LowerTriangle LOWER_TRIANGLE
Lower Triangle access.
Definition types.hpp:183
constexpr internal::UpperTriangle UPPER_TRIANGLE
Upper Triangle access.
Definition types.hpp:181
constexpr internal::RightSide RIGHT_SIDE
right side
Definition types.hpp:291
constexpr internal::UnitDiagonal UNIT_DIAG
The main diagonal is assumed to consist of 1's.
Definition types.hpp:217
constexpr internal::GeneralAccess GENERAL
General access.
Definition types.hpp:175
constexpr internal::NonUnitDiagonal NON_UNIT_DIAG
The main diagonal is not assumed to consist of 1's.
Definition types.hpp:215
constexpr internal::ConjTranspose CONJ_TRANS
conjugate transpose
Definition types.hpp:259
constexpr internal::NoTranspose NO_TRANS
no transpose
Definition types.hpp:255
constexpr internal::LeftSide LEFT_SIDE
left side
Definition types.hpp:289
#define TLAPACK_STOREV
Macro for tlapack::concepts::StoreV compatible with C++17.
Definition concepts.hpp:936
#define TLAPACK_SIDE
Macro for tlapack::concepts::Side compatible with C++17.
Definition concepts.hpp:927
#define TLAPACK_SMATRIX
Macro for tlapack::concepts::SliceableMatrix compatible with C++17.
Definition concepts.hpp:899
#define TLAPACK_DIRECTION
Macro for tlapack::concepts::Direction compatible with C++17.
Definition concepts.hpp:930
#define TLAPACK_OP
Macro for tlapack::concepts::Op compatible with C++17.
Definition concepts.hpp:933
#define TLAPACK_WORKSPACE
Macro for tlapack::concepts::Workspace compatible with C++17.
Definition concepts.hpp:912
#define TLAPACK_MATRIX
Macro for tlapack::concepts::Matrix compatible with C++17.
Definition concepts.hpp:896
int larfb(side_t side, trans_t trans, direction_t direction, storage_t storeMode, const matrixV_t &V, const matrixT_t &Tmatrix, matrixC_t &C)
Applies a block reflector or its conjugate transpose to an m-by-n matrix C, from either the left or ...
Definition larfb.hpp:512
int larfb_work(side_t side, trans_t trans, direction_t direction, storage_t storeMode, const matrixV_t &V, const matrixT_t &Tmatrix, matrixC_t &C, work_t &work)
Applies a block reflector or its conjugate transpose to an m-by-n matrix C, from either the left or ...
Definition larfb.hpp:111
void lacpy(uplo_t uplo, const matrixA_t &A, matrixB_t &B)
Copies a matrix from A to B.
Definition lacpy.hpp:38
void trmm(Side side, Uplo uplo, Op trans, Diag diag, const alpha_t &alpha, const matrixA_t &A, matrixB_t &B)
Triangular matrix-matrix multiply:
Definition trmm.hpp:72
void gemm(Op transA, Op transB, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
General matrix-matrix multiply:
Definition gemm.hpp:61
#define tlapack_check(cond)
Throw an error if cond is false.
Definition exceptionHandling.hpp:98
#define tlapack_check_false(cond)
Throw an error if cond is true.
Definition exceptionHandling.hpp:113
constexpr WorkInfo larfb_worksize(side_t side, trans_t trans, direction_t direction, storage_t storeMode, const matrixV_t &V, const matrixT_t &Tmatrix, const matrixC_t &C)
Workspace query of larfb()
Definition larfb.hpp:73
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
constexpr bool is_complex
True if T is a complex scalar type.
Definition scalar_type_traits.hpp:192
Output information in the workspace query.
Definition workspace.hpp:16