<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
lu_mult.hpp
Go to the documentation of this file.
1
3//
4// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
5//
6// This file is part of <T>LAPACK.
7// <T>LAPACK is free software: you can redistribute it and/or modify it under
8// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
9
10#ifndef TLAPACK_LU_MULT_HH
11#define TLAPACK_LU_MULT_HH
12
14#include "tlapack/blas/trmm.hpp"
15
16namespace tlapack {
17
/// Options struct for lu_mult().
struct LuMultOpts {
    /// Optimization parameter (must be at least 1): square matrices whose
    /// order does not exceed nx are multiplied directly by a triple loop
    /// instead of being split further by the recursive algorithm.
    size_t nx{1};
};
24
39template <TLAPACK_SMATRIX matrix_t>
40void lu_mult(matrix_t& A, const LuMultOpts& opts = {})
41{
42 using idx_t = size_type<matrix_t>;
43 using T = type_t<matrix_t>;
44 using range = pair<idx_t, idx_t>;
45 using real_t = real_type<T>;
46
47 const idx_t m = nrows(A);
48 const idx_t n = ncols(A);
49 tlapack_check(m == n);
50 tlapack_check(opts.nx >= 1);
51
52 // quick return
53 if (n == 0) return;
54
55 if (n <= (idx_t)opts.nx) { // Matrix is small, do not use recursion
56 for (idx_t i2 = n; i2 > 0; --i2) {
57 idx_t i = i2 - 1;
58 for (idx_t j2 = n; j2 > 0; --j2) {
59 idx_t j = j2 - 1;
60 T sum(0);
61 for (idx_t k = 0; k <= min(i, j); ++k) {
62 if (i == k)
63 sum += A(k, j);
64 else
65 sum += A(i, k) * A(k, j);
66 }
67 A(i, j) = sum;
68 }
69 }
70 return;
71 }
72
73 const idx_t n0 = n / 2;
74
75 /*
76 Matrix A is splitted into 4 submatrices:
77 A = [ A00 A01 ]
78 [ A10 A11 ]
79 and, hereafter,
80 L00 is the strict lower triangular part of A00, with unitary main
81 diagonal. L11 is the strict lower triangular part of A11, with unitary
82 main diagonal. U00 is the upper triangular part of A00. U11 is the upper
83 triangular part of A11.
84 */
85 auto A00 = slice(A, range(0, n0), range(0, n0));
86 auto A01 = slice(A, range(0, n0), range(n0, n));
87 auto A10 = slice(A, range(n0, n), range(0, n0));
88 auto A11 = slice(A, range(n0, n), range(n0, n));
89
90 lu_mult(A11, opts);
91
92 // A11 = A10*A01 + L11*U11
93 gemm(NO_TRANS, NO_TRANS, T(1), A10, A01, T(1), A11);
94
95 // A01 = L00*A01
96 trmm(LEFT_SIDE, LOWER_TRIANGLE, NO_TRANS, UNIT_DIAG, real_t(1), A00, A01);
97
98 // A10 = A10*U00
99 trmm(RIGHT_SIDE, UPPER_TRIANGLE, NO_TRANS, NON_UNIT_DIAG, real_t(1), A00,
100 A10);
101
102 // A00 = L00*U00
103 lu_mult(A00, opts);
104
105 return;
106}
107
108} // namespace tlapack
109
110#endif // TLAPACK_LU_MULT_HH
void lu_mult(matrix_t &A, const LuMultOpts &opts={})
In-place multiplication of a unit lower triangular matrix L and an upper triangular matrix U, both stored in the same square matrix A.
Definition lu_mult.hpp:40
void trmm(Side side, Uplo uplo, Op trans, Diag diag, const alpha_t &alpha, const matrixA_t &A, matrixB_t &B)
Triangular matrix-matrix multiply:
Definition trmm.hpp:72
void gemm(Op transA, Op transB, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
General matrix-matrix multiply:
Definition gemm.hpp:61
#define tlapack_check(cond)
Throw an error if cond is false.
Definition exceptionHandling.hpp:98
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
Options struct for lu_mult()
Definition lu_mult.hpp:19
size_t nx
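Optimization parameter: matrices of order at most nx are multiplied directly, without further recursion (must be at least 1).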
Definition lu_mult.hpp:22