<T>LAPACK 0.1.2
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
gemmtr.hpp
Go to the documentation of this file.
1
3//
4// Copyright (c) 2025, University of Colorado Denver. All rights reserved.
5//
6// This file is part of <T>LAPACK.
7// <T>LAPACK is free software: you can redistribute it and/or modify it under
8// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
9
10#ifndef TLAPACK_GEMMTR_HH
11#define TLAPACK_GEMMTR_HH
12
14
15namespace tlapack {
16
/**
 * General triangular matrix-matrix multiply. Updates one triangle of C with
 *
 *     C := alpha * op(A) * op(B) + beta * C,
 *
 * where op(X) is X, X^T or X^H as selected by transA / transB.
 *
 * Only the entries C(i, j) with i <= j (upper) or i >= j (lower) are read and
 * written; the opposite triangle of C is never accessed.
 *
 * With n = number of columns of op(B) and k = number of columns of op(A),
 * op(A) must be n-by-k, op(B) must be k-by-n and C must be n-by-n.
 *
 * @param[in]     uplo   Uplo::Upper or Uplo::Lower: which triangle of C is
 *                       updated.
 * @param[in]     transA Op::NoTrans, Op::Trans or Op::ConjTrans, applied to A.
 * @param[in]     transB Op::NoTrans, Op::Trans or Op::ConjTrans, applied to B.
 * @param[in]     alpha  Scalar multiplying op(A) * op(B).
 * @param[in]     A      Left factor, accessed as A(row, col).
 * @param[in]     B      Right factor, accessed as B(row, col).
 * @param[in]     beta   Scalar multiplying the existing entries of C.
 * @param[in,out] C      n-by-n matrix; the selected triangle is overwritten.
 *
 * Invalid enum values or inconsistent dimensions are reported through
 * tlapack_check_false, which raises an error when its condition is true.
 *
 * NOTE(review): this listing omits a few source lines (for example the line
 * that names the function and declares `uplo`, and the declarations of
 * `scalar_t` used inside each transA branch). Those are presumably present
 * in the real header -- confirm against the repository before applying.
 */
55template <TLAPACK_MATRIX matrixA_t,
56 TLAPACK_MATRIX matrixB_t,
57 TLAPACK_MATRIX matrixC_t,
58 TLAPACK_SCALAR alpha_t,
59 TLAPACK_SCALAR beta_t,
60 class T = type_t<matrixC_t>>
62 Op transA,
63 Op transB,
64 const alpha_t& alpha,
65 const matrixA_t& A,
66 const matrixB_t& B,
67 const beta_t& beta,
68 matrixC_t& C)
69{
70 // data traits
71 using TA = type_t<matrixA_t>;
72 using TB = type_t<matrixB_t>;
73 using idx_t = size_type<matrixA_t>;
74
75 // constants
    // n: order of C (columns of op(B)); k: inner dimension (columns of op(A)).
76 const idx_t n = (transB == Op::NoTrans) ? ncols(B) : nrows(B);
77 const idx_t k = (transA == Op::NoTrans) ? ncols(A) : nrows(A);
78
79 // check arguments
80 tlapack_check_false(uplo != Uplo::Upper && uplo != Uplo::Lower);
81 tlapack_check_false(transA != Op::NoTrans && transA != Op::Trans &&
82 transA != Op::ConjTrans);
83 tlapack_check_false(transB != Op::NoTrans && transB != Op::Trans &&
84 transB != Op::ConjTrans);
    // C must be square of order n, so reject it when EITHER dimension differs.
    // (The check previously used `&&`, which let a C with a single wrong
    // dimension through even though the loops index C(i, j) for all i, j < n.)
    // The four checks that follow verify op(A) is n-by-k and op(B) is k-by-n;
    // the first and last of them repeat the definitions of k and n above.
85 tlapack_check_false((idx_t)ncols(C) != n || (idx_t)nrows(C) != n);
87 (idx_t)((transA == Op::NoTrans) ? ncols(A) : nrows(A)) != k);
89 (idx_t)((transB == Op::NoTrans) ? nrows(B) : ncols(B)) != k);
91 (idx_t)((transA == Op::NoTrans) ? nrows(A) : ncols(A)) != n);
93 (idx_t)((transB == Op::NoTrans) ? ncols(B) : nrows(B)) != n);
94
95 // Upper Triangular
    // Every loop nest in this half visits rows i = 0..j of column j.
96 if (uplo == UPPER_TRIANGLE) {
97 if (transA == Op::NoTrans) {
        // op(A) = A: scale column j of the triangle by beta, then add
        // A(:, l) * (alpha * op(B)(l, j)) for each l.
99 if (transB == Op::NoTrans) {
100 for (idx_t j = 0; j < n; ++j) {
101 for (idx_t i = 0; i <= j; ++i)
102 C(i, j) *= beta;
103 for (idx_t l = 0; l < k; ++l) {
104 const scalar_t alphaTimesblj = alpha * B(l, j);
105 for (idx_t i = 0; i <= j; ++i)
106 C(i, j) += A(i, l) * alphaTimesblj;
107 }
108 }
109 }
110 else if (transB == Op::Trans) {
111 for (idx_t j = 0; j < n; ++j) {
112 for (idx_t i = 0; i <= j; ++i)
113 C(i, j) *= beta;
114 for (idx_t l = 0; l < k; ++l) {
115 const scalar_t alphaTimesbjl = alpha * B(j, l);
116 for (idx_t i = 0; i <= j; ++i)
117 C(i, j) += A(i, l) * alphaTimesbjl;
118 }
119 }
120 }
121 else { // transB == Op::ConjTrans
122 for (idx_t j = 0; j < n; ++j) {
123 for (idx_t i = 0; i <= j; ++i)
124 C(i, j) *= beta;
125 for (idx_t l = 0; l < k; ++l) {
126 const scalar_t alphaTimesbjl = alpha * conj(B(j, l));
127 for (idx_t i = 0; i <= j; ++i)
128 C(i, j) += A(i, l) * alphaTimesbjl;
129 }
130 }
131 }
132 }
133 else if (transA == Op::Trans) {
        // op(A) = A^T: each C(i, j) is formed from a length-k dot product of
        // column i of A with the matching row/column of B.
135
136 if (transB == Op::NoTrans) {
137 for (idx_t j = 0; j < n; ++j) {
138 for (idx_t i = 0; i <= j; ++i) {
139 scalar_t sum(0);
140 for (idx_t l = 0; l < k; ++l)
141 sum += A(l, i) * B(l, j);
142 C(i, j) = alpha * sum + beta * C(i, j);
143 }
144 }
145 }
146 else if (transB == Op::Trans) {
147 for (idx_t j = 0; j < n; ++j) {
148 for (idx_t i = 0; i <= j; ++i) {
149 scalar_t sum(0);
150 for (idx_t l = 0; l < k; ++l)
151 sum += A(l, i) * B(j, l);
152 C(i, j) = alpha * sum + beta * C(i, j);
153 }
154 }
155 }
156 else { // transB == Op::ConjTrans
157 for (idx_t j = 0; j < n; ++j) {
158 for (idx_t i = 0; i <= j; ++i) {
159 scalar_t sum(0);
160 for (idx_t l = 0; l < k; ++l)
161 sum += A(l, i) * conj(B(j, l));
162 C(i, j) = alpha * sum + beta * C(i, j);
163 }
164 }
165 }
166 }
167 else { // transA == Op::ConjTrans
168
170
171 if (transB == Op::NoTrans) {
172 for (idx_t j = 0; j < n; ++j) {
173 for (idx_t i = 0; i <= j; ++i) {
174 scalar_t sum(0);
175 for (idx_t l = 0; l < k; ++l)
176 sum += conj(A(l, i)) * B(l, j);
177 C(i, j) = alpha * sum + beta * C(i, j);
178 }
179 }
180 }
181 else if (transB == Op::Trans) {
182 for (idx_t j = 0; j < n; ++j) {
183 for (idx_t i = 0; i <= j; ++i) {
184 scalar_t sum(0);
185 for (idx_t l = 0; l < k; ++l)
186 sum += conj(A(l, i)) * B(j, l);
187 C(i, j) = alpha * sum + beta * C(i, j);
188 }
189 }
190 }
191 else { // transB == Op::ConjTrans
            // conj(a) * conj(b) == conj(a * b): accumulate the plain
            // products and conjugate the sum once at the end.
192 for (idx_t j = 0; j < n; ++j) {
193 for (idx_t i = 0; i <= j; ++i) {
194 scalar_t sum(0);
195 for (idx_t l = 0; l < k; ++l)
196 sum += A(l, i) * B(j, l);
197 C(i, j) = alpha * conj(sum) + beta * C(i, j);
198 }
199 }
200 }
201 }
202 }
203 else { // uplo == Uplo::Lower
    // Same nine transA/transB cases as above, but every loop nest visits
    // rows i = j..n-1 of column j.
204 if (transA == Op::NoTrans) {
206 if (transB == Op::NoTrans) {
207 for (idx_t j = 0; j < n; ++j) {
208 for (idx_t i = j; i < n; ++i)
209 C(i, j) *= beta;
210 for (idx_t l = 0; l < k; ++l) {
211 const scalar_t alphaTimesblj = alpha * B(l, j);
212 for (idx_t i = j; i < n; ++i)
213 C(i, j) += A(i, l) * alphaTimesblj;
214 }
215 }
216 }
217 else if (transB == Op::Trans) {
218 for (idx_t j = 0; j < n; ++j) {
219 for (idx_t i = j; i < n; ++i)
220 C(i, j) *= beta;
221 for (idx_t l = 0; l < k; ++l) {
222 const scalar_t alphaTimesbjl = alpha * B(j, l);
223 for (idx_t i = j; i < n; ++i)
224 C(i, j) += A(i, l) * alphaTimesbjl;
225 }
226 }
227 }
228 else { // transB == Op::ConjTrans
229 for (idx_t j = 0; j < n; ++j) {
230 for (idx_t i = j; i < n; ++i)
231 C(i, j) *= beta;
232 for (idx_t l = 0; l < k; ++l) {
233 const scalar_t alphaTimesbjl = alpha * conj(B(j, l));
234 for (idx_t i = j; i < n; ++i)
235 C(i, j) += A(i, l) * alphaTimesbjl;
236 }
237 }
238 }
239 }
240 else if (transA == Op::Trans) {
242
243 if (transB == Op::NoTrans) {
244 for (idx_t j = 0; j < n; ++j) {
245 for (idx_t i = j; i < n; ++i) {
246 scalar_t sum(0);
247 for (idx_t l = 0; l < k; ++l)
248 sum += A(l, i) * B(l, j);
249 C(i, j) = alpha * sum + beta * C(i, j);
250 }
251 }
252 }
253 else if (transB == Op::Trans) {
254 for (idx_t j = 0; j < n; ++j) {
255 for (idx_t i = j; i < n; ++i) {
256 scalar_t sum(0);
257 for (idx_t l = 0; l < k; ++l)
258 sum += A(l, i) * B(j, l);
259 C(i, j) = alpha * sum + beta * C(i, j);
260 }
261 }
262 }
263 else { // transB == Op::ConjTrans
264 for (idx_t j = 0; j < n; ++j) {
265 for (idx_t i = j; i < n; ++i) {
266 scalar_t sum(0);
267 for (idx_t l = 0; l < k; ++l)
268 sum += A(l, i) * conj(B(j, l));
269 C(i, j) = alpha * sum + beta * C(i, j);
270 }
271 }
272 }
273 }
274 else { // transA == Op::ConjTrans
275
277
278 if (transB == Op::NoTrans) {
279 for (idx_t j = 0; j < n; ++j) {
280 for (idx_t i = j; i < n; ++i) {
281 scalar_t sum(0);
282 for (idx_t l = 0; l < k; ++l)
283 sum += conj(A(l, i)) * B(l, j);
284 C(i, j) = alpha * sum + beta * C(i, j);
285 }
286 }
287 }
288 else if (transB == Op::Trans) {
289 for (idx_t j = 0; j < n; ++j) {
290 for (idx_t i = j; i < n; ++i) {
291 scalar_t sum(0);
292 for (idx_t l = 0; l < k; ++l)
293 sum += conj(A(l, i)) * B(j, l);
294 C(i, j) = alpha * sum + beta * C(i, j);
295 }
296 }
297 }
298 else { // transB == Op::ConjTrans
            // conj(a) * conj(b) == conj(a * b): accumulate the plain
            // products and conjugate the sum once at the end.
299 for (idx_t j = 0; j < n; ++j) {
300 for (idx_t i = j; i < n; ++i) {
301 scalar_t sum(0);
302 for (idx_t l = 0; l < k; ++l)
303 sum += A(l, i) * B(j, l);
304 C(i, j) = alpha * conj(sum) + beta * C(i, j);
305 }
306 }
307 }
308 }
309 }
310}
311
312} // namespace tlapack
313
314#endif // TLAPACK_GEMMTR_HH
constexpr internal::UpperTriangle UPPER_TRIANGLE
Upper Triangle access.
Definition types.hpp:186
Op
Definition types.hpp:227
Uplo
Definition types.hpp:50
constexpr T conj(const T &x) noexcept
Extends std::conj() to real datatypes.
Definition utils.hpp:100
#define TLAPACK_SCALAR
Macro for tlapack::concepts::Scalar compatible with C++17.
Definition concepts.hpp:915
#define TLAPACK_MATRIX
Macro for tlapack::concepts::Matrix compatible with C++17.
Definition concepts.hpp:896
void gemmtr(Uplo uplo, Op transA, Op transB, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
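General triangular matrix-matrix multiply: $C := \alpha \, op(A) \, op(B) + \beta C$, where only the upper or lower triangular part of the n-by-n matrix C (selected by uplo) is updated.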
Definition gemmtr.hpp:61
#define tlapack_check_false(cond)
Throw an error if cond is true.
Definition exceptionHandling.hpp:113
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113