<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
gemm.hpp
Go to the documentation of this file.
1
3//
4// Copyright (c) 2017-2021, University of Tennessee. All rights reserved.
5// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
6//
7// This file is part of <T>LAPACK.
8// <T>LAPACK is free software: you can redistribute it and/or modify it under
9// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
10
11#ifndef TLAPACK_BLAS_GEMM_HH
12#define TLAPACK_BLAS_GEMM_HH
13
15
16namespace tlapack {
17
// General matrix-matrix multiply, reference (loop-based) implementation.
//
// Updates C in place as
//     C := alpha * op(A) * op(B) + beta * C,
// where op(X) is X, its transpose, or its conjugate transpose, selected by
// transA / transB (Op::NoTrans, Op::Trans, Op::ConjTrans).
//
// Parameters:
//   transA, transB  operation applied to A and to B.
//   alpha           scalar multiplying op(A) * op(B).
//   A, B            input matrices, read through A(i, j) / B(i, j).
//   beta            scalar multiplying the incoming contents of C.
//   C               m-by-n matrix, read and overwritten.
//
// Argument checks go through tlapack_check_false, which (per its summary on
// this page) throws an error if the condition is true.
//
// NOTE(review): disable_if_allow_optblas_t presumably removes this overload
// when an optimized BLAS can handle the argument types - confirm against its
// definition, which is not visible here.
//
// NOTE(review): this text is a rendered source listing. The leading numbers
// are listing line numbers, not code, and listing lines 61, 86, 90, 127 and
// 162 are absent. Each gap is flagged where it occurs below.
50template <TLAPACK_MATRIX matrixA_t,
51 TLAPACK_MATRIX matrixB_t,
52 TLAPACK_MATRIX matrixC_t,
53 TLAPACK_SCALAR alpha_t,
54 TLAPACK_SCALAR beta_t,
55 class T = type_t<matrixC_t>,
56 disable_if_allow_optblas_t<pair<matrixA_t, T>,
57 pair<matrixB_t, T>,
58 pair<matrixC_t, T>,
59 pair<alpha_t, T>,
60 pair<beta_t, T> > = 0>
// NOTE(review): listing line 61 is absent. According to the signature summary
// further down this page it reads `void gemm(Op transA,`.
62 Op transB,
63 const alpha_t& alpha,
64 const matrixA_t& A,
65 const matrixB_t& B,
66 const beta_t& beta,
67 matrixC_t& C)
68{
69 // data traits
70 using TA = type_t<matrixA_t>;
71 using TB = type_t<matrixB_t>;
72 using idx_t = size_type<matrixA_t>;
73
74 // constants
// m-by-k is the shape of op(A); n is the number of columns of op(B).
75 const idx_t m = (transA == Op::NoTrans) ? nrows(A) : ncols(A);
76 const idx_t n = (transB == Op::NoTrans) ? ncols(B) : nrows(B);
77 const idx_t k = (transA == Op::NoTrans) ? ncols(A) : nrows(A);
78
79 // check arguments
// Reject any Op value other than the three handled below.
80 tlapack_check_false(transA != Op::NoTrans && transA != Op::Trans &&
81 transA != Op::ConjTrans);
82 tlapack_check_false(transB != Op::NoTrans && transB != Op::Trans &&
83 transB != Op::ConjTrans);
// C must be m-by-n.
84 tlapack_check_false((idx_t)nrows(C) != m);
85 tlapack_check_false((idx_t)ncols(C) != n);
// NOTE(review): listing line 86 is absent. The expression on the next line is
// presumably the argument of another `tlapack_check_false(` call; it compares
// the number of rows of op(B) against k.
87 (idx_t)((transB == Op::NoTrans) ? nrows(B) : ncols(B)) != k);
88
89 if (transA == Op::NoTrans) {
// NOTE(review): listing line 90 is absent. `scalar_t` is used in this branch
// but declared nowhere in the visible text, so an alias declaration
// presumably sat here - confirm against the header.
91
// For each column j: scale C(:, j) by beta, then for each l add
// A(:, l) * (alpha * op(B)(l, j)). The alpha * B product is formed once
// per (l, j) rather than once per element.
92 if (transB == Op::NoTrans) {
93 for (idx_t j = 0; j < n; ++j) {
94 for (idx_t i = 0; i < m; ++i)
95 C(i, j) *= beta;
96 for (idx_t l = 0; l < k; ++l) {
97 const scalar_t alphaTimesblj = alpha * B(l, j);
98 for (idx_t i = 0; i < m; ++i)
99 C(i, j) += A(i, l) * alphaTimesblj;
100 }
101 }
102 }
103 else if (transB == Op::Trans) {
// Same scheme; op(B)(l, j) is read as B(j, l).
104 for (idx_t j = 0; j < n; ++j) {
105 for (idx_t i = 0; i < m; ++i)
106 C(i, j) *= beta;
107 for (idx_t l = 0; l < k; ++l) {
108 const scalar_t alphaTimesbjl = alpha * B(j, l);
109 for (idx_t i = 0; i < m; ++i)
110 C(i, j) += A(i, l) * alphaTimesbjl;
111 }
112 }
113 }
114 else { // transB == Op::ConjTrans
// Same scheme; op(B)(l, j) is read as conj(B(j, l)).
115 for (idx_t j = 0; j < n; ++j) {
116 for (idx_t i = 0; i < m; ++i)
117 C(i, j) *= beta;
118 for (idx_t l = 0; l < k; ++l) {
119 const scalar_t alphaTimesbjl = alpha * conj(B(j, l));
120 for (idx_t i = 0; i < m; ++i)
121 C(i, j) += A(i, l) * alphaTimesbjl;
122 }
123 }
124 }
125 }
126 else if (transA == Op::Trans) {
// NOTE(review): listing line 127 is absent; as above, a `scalar_t` alias
// declaration presumably sat here - confirm against the header.
128
// For each entry (i, j): accumulate the length-k product of column i of A
// with the matching row/column of op(B) in `sum`, then write
// C(i, j) = alpha * sum + beta * C(i, j).
129 if (transB == Op::NoTrans) {
130 for (idx_t j = 0; j < n; ++j) {
131 for (idx_t i = 0; i < m; ++i) {
132 scalar_t sum(0);
133 for (idx_t l = 0; l < k; ++l)
134 sum += A(l, i) * B(l, j);
135 C(i, j) = alpha * sum + beta * C(i, j);
136 }
137 }
138 }
139 else if (transB == Op::Trans) {
140 for (idx_t j = 0; j < n; ++j) {
141 for (idx_t i = 0; i < m; ++i) {
142 scalar_t sum(0);
143 for (idx_t l = 0; l < k; ++l)
144 sum += A(l, i) * B(j, l);
145 C(i, j) = alpha * sum + beta * C(i, j);
146 }
147 }
148 }
149 else { // transB == Op::ConjTrans
150 for (idx_t j = 0; j < n; ++j) {
151 for (idx_t i = 0; i < m; ++i) {
152 scalar_t sum(0);
153 for (idx_t l = 0; l < k; ++l)
154 sum += A(l, i) * conj(B(j, l));
155 C(i, j) = alpha * sum + beta * C(i, j);
156 }
157 }
158 }
159 }
160 else { // transA == Op::ConjTrans
161
// NOTE(review): listing line 162 is absent; as above, a `scalar_t` alias
// declaration presumably sat here - confirm against the header.
163
// Same per-entry accumulation as the Op::Trans case, with A(l, i)
// conjugated.
164 if (transB == Op::NoTrans) {
165 for (idx_t j = 0; j < n; ++j) {
166 for (idx_t i = 0; i < m; ++i) {
167 scalar_t sum(0);
168 for (idx_t l = 0; l < k; ++l)
169 sum += conj(A(l, i)) * B(l, j);
170 C(i, j) = alpha * sum + beta * C(i, j);
171 }
172 }
173 }
174 else if (transB == Op::Trans) {
175 for (idx_t j = 0; j < n; ++j) {
176 for (idx_t i = 0; i < m; ++i) {
177 scalar_t sum(0);
178 for (idx_t l = 0; l < k; ++l)
179 sum += conj(A(l, i)) * B(j, l);
180 C(i, j) = alpha * sum + beta * C(i, j);
181 }
182 }
183 }
184 else { // transB == Op::ConjTrans
// conj(a) * conj(b) == conj(a * b), so the unconjugated products are
// summed and conj is applied once to the total instead of twice per term.
185 for (idx_t j = 0; j < n; ++j) {
186 for (idx_t i = 0; i < m; ++i) {
187 scalar_t sum(0);
188 for (idx_t l = 0; l < k; ++l)
189 sum += A(l, i) * B(j, l); // little improvement here
190 C(i, j) = alpha * conj(sum) + beta * C(i, j);
191 }
192 }
193 }
194 }
195}
196
197#ifdef TLAPACK_USE_LAPACKPP
198
215template <TLAPACK_LEGACY_MATRIX matrixA_t,
216 TLAPACK_LEGACY_MATRIX matrixB_t,
217 TLAPACK_LEGACY_MATRIX matrixC_t,
218 TLAPACK_SCALAR alpha_t,
219 TLAPACK_SCALAR beta_t,
220 class T = type_t<matrixC_t>,
221 enable_if_allow_optblas_t<pair<matrixA_t, T>,
222 pair<matrixB_t, T>,
223 pair<matrixC_t, T>,
224 pair<alpha_t, T>,
225 pair<beta_t, T> > = 0>
226void gemm(Op transA,
227 Op transB,
228 const alpha_t alpha,
229 const matrixA_t& A,
230 const matrixB_t& B,
231 const beta_t beta,
232 matrixC_t& C)
233{
234 // Legacy objects
235 auto A_ = legacy_matrix(A);
236 auto B_ = legacy_matrix(B);
237 auto C_ = legacy_matrix(C);
238
239 // Constants to forward
240 constexpr Layout L = layout<matrixC_t>;
241 const auto& m = C_.m;
242 const auto& n = C_.n;
243 const auto& k = (transA == Op::NoTrans) ? A_.n : A_.m;
244
245 // Warnings for NaNs and Infs
246 if (alpha == alpha_t(0))
248 -3, "Infs and NaNs in A or B will not propagate to C on output");
249 if (beta == beta_t(0) && !is_same_v<beta_t, StrongZero>)
251 -6,
252 "Infs and NaNs in C on input will not propagate to C on output");
253
254 return ::blas::gemm((::blas::Layout)L, (::blas::Op)transA,
255 (::blas::Op)transB, m, n, k, alpha, A_.ptr, A_.ldim,
256 B_.ptr, B_.ldim, (T)beta, C_.ptr, C_.ldim);
257}
258
259#endif
260
292template <TLAPACK_MATRIX matrixA_t,
293 TLAPACK_MATRIX matrixB_t,
294 TLAPACK_MATRIX matrixC_t,
295 TLAPACK_SCALAR alpha_t>
297 Op transB,
298 const alpha_t& alpha,
299 const matrixA_t& A,
300 const matrixB_t& B,
301 matrixC_t& C)
302{
303 return gemm(transA, transB, alpha, A, B, StrongZero(), C);
304}
305
306} // namespace tlapack
307
308#endif // #ifndef TLAPACK_BLAS_GEMM_HH
Op
Definition types.hpp:222
constexpr T conj(const T &x) noexcept
Extends std::conj() to real datatypes.
Definition utils.hpp:100
#define TLAPACK_SCALAR
Macro for tlapack::concepts::Scalar compatible with C++17.
Definition concepts.hpp:915
#define TLAPACK_LEGACY_MATRIX
Macro for tlapack::concepts::LegacyMatrix compatible with C++17.
Definition concepts.hpp:951
#define TLAPACK_MATRIX
Macro for tlapack::concepts::Matrix compatible with C++17.
Definition concepts.hpp:896
void gemm(Op transA, Op transB, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
General matrix-matrix multiply: $C := \alpha \, op(A) \, op(B) + \beta C$, where $op(X)$ is $X$, $X^T$ or $X^H$.
Definition gemm.hpp:61
#define tlapack_check_false(cond)
Throw an error if cond is true.
Definition exceptionHandling.hpp:113
#define tlapack_warning(info, detailedInfo)
Warning handler.
Definition exceptionHandling.hpp:156
Concept for types that represent tlapack::Op.
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
StrongZero
Strong zero type.
Definition StrongZero.hpp:43