11#ifndef TLAPACK_BLAS_GEMM_HH
12#define TLAPACK_BLAS_GEMM_HH
55 class T = type_t<matrixC_t>,
56 disable_if_allow_optblas_t<pair<matrixA_t, T>,
60 pair<beta_t, T> > = 0>
75 const idx_t m = (
transA == Op::NoTrans) ? nrows(
A) : ncols(
A);
76 const idx_t n = (
transB == Op::NoTrans) ? ncols(
B) : nrows(
B);
77 const idx_t
k = (
transA == Op::NoTrans) ? ncols(
A) : nrows(
A);
87 (idx_t)((
transB == Op::NoTrans) ? nrows(
B) : ncols(
B)) !=
k);
89 if (
transA == Op::NoTrans) {
92 if (
transB == Op::NoTrans) {
93 for (idx_t j = 0; j < n; ++j) {
94 for (idx_t i = 0; i < m; ++i)
96 for (idx_t
l = 0;
l <
k; ++
l) {
98 for (idx_t i = 0; i < m; ++i)
103 else if (
transB == Op::Trans) {
104 for (idx_t j = 0; j < n; ++j) {
105 for (idx_t i = 0; i < m; ++i)
107 for (idx_t
l = 0;
l <
k; ++
l) {
109 for (idx_t i = 0; i < m; ++i)
115 for (idx_t j = 0; j < n; ++j) {
116 for (idx_t i = 0; i < m; ++i)
118 for (idx_t
l = 0;
l <
k; ++
l) {
120 for (idx_t i = 0; i < m; ++i)
126 else if (
transA == Op::Trans) {
129 if (
transB == Op::NoTrans) {
130 for (idx_t j = 0; j < n; ++j) {
131 for (idx_t i = 0; i < m; ++i) {
133 for (idx_t
l = 0;
l <
k; ++
l)
139 else if (
transB == Op::Trans) {
140 for (idx_t j = 0; j < n; ++j) {
141 for (idx_t i = 0; i < m; ++i) {
143 for (idx_t
l = 0;
l <
k; ++
l)
150 for (idx_t j = 0; j < n; ++j) {
151 for (idx_t i = 0; i < m; ++i) {
153 for (idx_t
l = 0;
l <
k; ++
l)
164 if (
transB == Op::NoTrans) {
165 for (idx_t j = 0; j < n; ++j) {
166 for (idx_t i = 0; i < m; ++i) {
168 for (idx_t
l = 0;
l <
k; ++
l)
174 else if (
transB == Op::Trans) {
175 for (idx_t j = 0; j < n; ++j) {
176 for (idx_t i = 0; i < m; ++i) {
178 for (idx_t
l = 0;
l <
k; ++
l)
185 for (idx_t j = 0; j < n; ++j) {
186 for (idx_t i = 0; i < m; ++i) {
188 for (idx_t
l = 0;
l <
k; ++
l)
197#ifdef TLAPACK_USE_LAPACKPP
220 class T = type_t<matrixC_t>,
221 enable_if_allow_optblas_t<pair<matrixA_t, T>,
225 pair<beta_t, T> > = 0>
235 auto A_ = legacy_matrix(A);
236 auto B_ = legacy_matrix(B);
237 auto C_ = legacy_matrix(C);
240 constexpr Layout L = layout<matrixC_t>;
241 const auto& m = C_.m;
242 const auto& n = C_.n;
243 const auto& k = (transA == Op::NoTrans) ? A_.n : A_.m;
246 if (alpha == alpha_t(0))
248 -3,
"Infs and NaNs in A or B will not propagate to C on output");
249 if (beta == beta_t(0) && !is_same_v<beta_t, StrongZero>)
252 "Infs and NaNs in C on input will not propagate to C on output");
254 return ::blas::gemm((::blas::Layout)L, (::blas::Op)transA,
255 (::blas::Op)transB, m, n, k, alpha, A_.ptr, A_.ldim,
256 B_.ptr, B_.ldim, (T)beta, C_.ptr, C_.ldim);
261#if defined(TLAPACK_USE_BF16BF16FP32_GEMM) && __has_include(<stdfloat>) && __cplusplus > 202002L
263 #include <mkl_cblas.h>
281template <
class idx_t, Layout L>
291 auto A_ = legacy_matrix(A);
292 auto B_ = legacy_matrix(B);
293 auto C_ = legacy_matrix(C);
296 const CBLAS_LAYOUT
layout =
297 (L == Layout::ColMajor) ? CblasColMajor : CblasRowMajor;
298 const auto& m = C_.m;
299 const auto& n = C_.n;
300 const auto& k = (transA == Op::NoTrans) ? A_.n : A_.m;
302 assert(transA == Op::NoTrans);
303 assert(transB == Op::NoTrans);
305 cblas_gemm_bf16bf16f32(layout, CblasNoTrans, CblasNoTrans, m, n, k, alpha,
306 reinterpret_cast<const uint16_t*
>(A_.ptr), A_.ldim,
307 reinterpret_cast<const uint16_t*
>(B_.ptr), B_.ldim,
308 beta, C_.ptr, C_.ldim);
constexpr Layout layout
Layout of a matrix or vector.
Definition arrayTraits.hpp:232
Op
Definition types.hpp:227
constexpr T conj(const T &x) noexcept
Extends std::conj() to real datatypes.
Definition utils.hpp:100
#define TLAPACK_SCALAR
Macro for tlapack::concepts::Scalar compatible with C++17.
Definition concepts.hpp:915
#define TLAPACK_LEGACY_MATRIX
Macro for tlapack::concepts::LegacyMatrix compatible with C++17.
Definition concepts.hpp:951
#define TLAPACK_MATRIX
Macro for tlapack::concepts::Matrix compatible with C++17.
Definition concepts.hpp:896
void gemm(Op transA, Op transB, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
General matrix-matrix multiply:
Definition gemm.hpp:61
#define tlapack_check_false(cond)
Throw an error if cond is true.
Definition exceptionHandling.hpp:113
#define tlapack_warning(info, detailedInfo)
Warning handler.
Definition exceptionHandling.hpp:156
Concept for types that represent tlapack::Op.
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
Strong zero type.
Definition StrongZero.hpp:43