<T>LAPACK 0.1.2
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
mult_hehe.hpp
Go to the documentation of this file.
1
3//
4// Copyright (c) 2025, University of Colorado Denver. All rights reserved.
5//
6// This file is part of <T>LAPACK.
7// <T>LAPACK is free software: you can redistribute it and/or modify it under
8// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
9
10#ifndef TLAPACK_MULT_HEHE_HH
11#define TLAPACK_MULT_HEHE_HH
12
14#include "tlapack/blas/gemm.hpp"
15#include "tlapack/blas/hemm.hpp"
17
18namespace tlapack {
19
// mult_hehe: Hermitian matrix-Hermitian matrix multiply, C := alpha*A*B + beta*C
// (the 1x1 base case below applies exactly this formula).
//
// Only the `uplo` triangle of A and B is referenced. The upper branch slices
// A00/A01/A11 (no A10) and the lower branch slices A00/A10/A11 (no A01); the
// other off-diagonal block is implied by Hermitian symmetry (A10 = A01^H).
// C is treated as a general n-by-n matrix: all four blocks are sliced.
//
// Strategy: split at n0 = n/2 into 2x2 blocks and accumulate each C block.
// NOTE(review): the per-block update calls (original lines 97, 100, 103, 106,
// 110, 114, 118, 121 and 143-166) are missing from this rendering. Per the
// block comments and the cross-references they presumably use gemm / hemm2
// and possibly recursive mult_hehe calls — confirm against the full header.
//
// Parameters (per the cross-reference signature for mult_hehe on this page):
//   uplo  - UPPER_TRIANGLE selects the upper branch; any other value takes the
//           lower-triangle branch.
//   alpha - scalar multiplying A*B.
//   A     - n-by-n Hermitian matrix. If A is not square the routine returns
//           without touching C.
//   B     - Hermitian matrix; assumed n-by-n (not checked — TODO confirm).
//   beta  - scalar multiplying the incoming C.
//   C     - n-by-n result, updated in place (dimensions not checked).
48template <TLAPACK_SMATRIX matrixA_t,
49 TLAPACK_SMATRIX matrixB_t,
50 TLAPACK_SMATRIX matrixC_t,
51 TLAPACK_SCALAR alpha_t,
52 TLAPACK_SCALAR beta_t>
// NOTE(review): original line 53 (the declarator, `void mult_hehe(Uplo uplo,`)
// is missing from this rendering.
54 const alpha_t& alpha,
55 matrixA_t& A,
56 matrixB_t& B,
57 const beta_t& beta,
58 matrixC_t& C)
59{
// NOTE(review): stale commented-out duplicate of the TB alias two lines below;
// safe to delete.
60 // using TB = type_t<matrixB_t>;
// NOTE(review): TA/TB/TC are unused in every visible line; they may be used
// by the missing lines — confirm before removing.
61 using TA = type_t<matrixA_t>;
62 using TB = type_t<matrixB_t>;
63 using TC = type_t<matrixC_t>;
// NOTE(review): the declaration of idx_t (original lines 64-66) is missing
// from this rendering; presumably `using idx_t = size_type<matrixA_t>;`.
67
68 const idx_t m = nrows(A);
69 const idx_t n = ncols(A);
70
// A non-square A is silently ignored: C is left unmodified.
// NOTE(review): consider reporting this instead, and also checking that B and
// C are n-by-n.
71 if (m != n) return;
72
// Base case: diagonal entries of a Hermitian matrix are real, hence real().
// NOTE(review): n == 0 also takes this branch and reads/writes C(0, 0) on an
// empty matrix, which is out of bounds. Add `if (n == 0) return;` before it.
73 if (n <= 1) {
74 C(0, 0) = alpha * real(A(0, 0)) * real(B(0, 0)) + beta * C(0, 0);
75 return;
76 }
77
// Split point: leading block is n0-by-n0, trailing block is (n-n0)-by-(n-n0).
78 const idx_t n0 = n / 2;
79
// Upper storage: A = [A00 A01; A01^H A11], B = [B00 B01; B01^H B11], so
//   C00 <- A00*B00   + A01*B01^H
//   C01 <- A00*B01   + A01*B11
//   C10 <- A01^H*B00 + A11*B01^H
//   C11 <- A01^H*B01 + A11*B11
// Scaling by alpha/beta is presumably applied inside the missing calls.
80 if (uplo == UPPER_TRIANGLE) {
// NOTE(review): redeclares n0 with the same value, shadowing the outer n0
// (-Wshadow); redundant and can be removed.
81 const idx_t n0 = n / 2;
82
83 auto A00 = slice(A, range(0, n0), range(0, n0));
84 auto A01 = slice(A, range(0, n0), range(n0, n));
85 auto A11 = slice(A, range(n0, n), range(n0, n));
86
87 auto B00 = slice(B, range(0, n0), range(0, n0));
88 auto B01 = slice(B, range(0, n0), range(n0, n));
89 auto B11 = slice(B, range(n0, n), range(n0, n));
90
91 auto C00 = slice(C, range(0, n0), range(0, n0));
92 auto C01 = slice(C, range(0, n0), range(n0, n));
93 auto C10 = slice(C, range(n0, n), range(0, n0));
94 auto C11 = slice(C, range(n0, n), range(n0, n));
95
96 // A00*B00 = C00
98
99 // A01*B01^H + A00*B00 + C00 = C00
101
102 // A00*B01 + C01 = C01
104
105 // A00*B01 + C01 + A01B11 = C
107 C01);
108
109 // A11 * B01H + C10 = C10
111 C10);
112
113 // A01^H * B00 + A11*B01^H
115 C10);
116
117 // A11*B11
119
120 // A01^H * B01 + A11*B11
122
123 return;
124 }
125 else {
126 // uplo == LOWER_TRIANGLE
// Lower storage: A = [A00 A10^H; A10 A11], B = [B00 B10^H; B10 B11], so
//   C00 <- A00*B00   + A10^H*B10
//   C10 <- A10*B00   + A11*B10
//   C01 <- A00*B10^H + A10^H*B11
//   C11 <- A10*B10^H + A11*B11
// NOTE(review): the step comments below write A01/B01, which are not stored
// in this branch. Read "A01^H" as A10^H and "B01^H" as B10^H.
127 auto A00 = slice(A, range(0, n0), range(0, n0));
128 auto A10 = slice(A, range(n0, n), range(0, n0));
129 auto A11 = slice(A, range(n0, n), range(n0, n));
130
131 auto B00 = slice(B, range(0, n0), range(0, n0));
132 auto B10 = slice(B, range(n0, n), range(0, n0));
133 auto B11 = slice(B, range(n0, n), range(n0, n));
134
135 auto C00 = slice(C, range(0, n0), range(0, n0));
136 auto C01 = slice(C, range(0, n0), range(n0, n));
137 auto C10 = slice(C, range(n0, n), range(0, n0));
138 auto C11 = slice(C, range(n0, n), range(n0, n));
139
// NOTE(review): stray debug output left in a library routine. It prints a
// newline and flushes std::cout on every lower-triangle call. Remove it;
// <iostream> is also not among the visible includes.
140 std::cout << std::endl;
141
142 // A00*B00 = C00
144
145 // A01^H*B10 + C00 = C00
147
148 // A10*B00 + C10 = C10
150
151 // A11*B10 + C10 = C10
153
154 // A00*B01^H + C01 = C01
156 C01);
157
158 // A01^H*B11 + C01 = C01
160 C01);
161
162 // A11*B11 = C11
164
165 // A10H*B10^H + C11 = C11
167
168 return;
169 }
170}
// mult_hehe overload without a beta argument (note: no beta_t template
// parameter, unlike the overload above).
// NOTE(review): its documentation (original lines ~171-191), declarator
// (196-197) and body (199) are all missing from this rendering. StrongZero
// appears in this page's cross-references, so this presumably forwards to the
// beta overload with a strong-zero beta, i.e. C := alpha*A*B without reading
// C's prior contents. Confirm against the full header.
192template <TLAPACK_SCALAR alpha_t,
193 TLAPACK_SMATRIX matrixA_t,
194 TLAPACK_SMATRIX matrixB_t,
195 TLAPACK_SMATRIX matrixC_t>
// NOTE(review): declarator lines 196-197 missing here.
198{
200}
201} // namespace tlapack
202#endif // TLAPACK_MULT_HEHE_HH
constexpr internal::LowerTriangle LOWER_TRIANGLE
Lower Triangle access.
Definition types.hpp:188
constexpr internal::UpperTriangle UPPER_TRIANGLE
Upper Triangle access.
Definition types.hpp:186
constexpr internal::RightSide RIGHT_SIDE
right side
Definition types.hpp:296
constexpr internal::ConjTranspose CONJ_TRANS
conjugate transpose
Definition types.hpp:264
Uplo
Definition types.hpp:50
constexpr internal::NoTranspose NO_TRANS
no transpose
Definition types.hpp:260
constexpr internal::LeftSide LEFT_SIDE
left side
Definition types.hpp:294
constexpr real_type< T > real(const T &x) noexcept
Extends std::real() to real datatypes.
Definition utils.hpp:71
#define TLAPACK_SCALAR
Macro for tlapack::concepts::Scalar compatible with C++17.
Definition concepts.hpp:915
#define TLAPACK_SMATRIX
Macro for tlapack::concepts::SliceableMatrix compatible with C++17.
Definition concepts.hpp:899
void mult_hehe(Uplo uplo, const alpha_t &alpha, matrixA_t &A, matrixB_t &B, const beta_t &beta, matrixC_t &C)
Hermitian matrix-Hermitian matrix multiply:
Definition mult_hehe.hpp:53
void gemm(Op transA, Op transB, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
General matrix-matrix multiply:
Definition gemm.hpp:61
void hemm2(Side side, Uplo uplo, Op transB, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
Hermitian matrix-Hermitian matrix multiply:
Definition hemm2.hpp:75
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
Strong zero type.
Definition StrongZero.hpp:43