<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
gemm.hpp
Go to the documentation of this file.
//
// Copyright (c) 2017-2021, University of Tennessee. All rights reserved.
// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
//
// This file is part of <T>LAPACK.
// <T>LAPACK is free software: you can redistribute it and/or modify it under
// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.

11#ifndef TLAPACK_LEGACY_GEMM_HH
12#define TLAPACK_LEGACY_GEMM_HH
13
14#include "tlapack/blas/gemm.hpp"
17
18namespace tlapack {
19namespace legacy {
20
102 template <typename TA, typename TB, typename TC>
103 void gemm(Layout layout,
104 Op transA,
105 Op transB,
106 idx_t m,
107 idx_t n,
108 idx_t k,
110 TA const* A,
111 idx_t lda,
112 TB const* B,
113 idx_t ldb,
115 TC* C,
116 idx_t ldc)
117 {
118 using internal::create_matrix;
120
121 // redirect if row major
122 if (layout == Layout::RowMajor) {
123 return gemm(Layout::ColMajor, transB, transA, n, m, k, alpha, B,
124 ldb, A, lda, beta, C, ldc);
125 }
126
127 // check arguments
128 tlapack_check_false(layout != Layout::ColMajor &&
129 layout != Layout::RowMajor);
130 tlapack_check_false(transA != Op::NoTrans && transA != Op::Trans &&
131 transA != Op::ConjTrans);
132 tlapack_check_false(transB != Op::NoTrans && transB != Op::Trans &&
133 transB != Op::ConjTrans);
134 tlapack_check_false(m < 0);
135 tlapack_check_false(n < 0);
137 tlapack_check_false(lda < ((transA != Op::NoTrans) ? k : m));
138 tlapack_check_false(ldb < ((transB != Op::NoTrans) ? n : k));
140
141 // quick return
142 if (m == 0 || n == 0 ||
143 ((alpha == scalar_t(0) || k == 0) && (beta == scalar_t(1))))
144 return;
145
146 // Matrix views
147 const auto A_ = (transA == Op::NoTrans)
148 ? create_matrix<TA>((TA*)A, m, k, lda)
149 : create_matrix<TA>((TA*)A, k, m, lda);
150 const auto B_ = (transB == Op::NoTrans)
151 ? create_matrix<TB>((TB*)B, k, n, ldb)
152 : create_matrix<TB>((TB*)B, n, k, ldb);
153 auto C_ = create_matrix<TC>(C, m, n, ldc);
154
155 if (alpha == scalar_t(0)) {
156 if (beta == scalar_t(0)) {
157 for (idx_t j = 0; j < n; ++j)
158 for (idx_t i = 0; i < m; ++i)
159 C_(i, j) = TC(0);
160 }
161 else {
162 for (idx_t j = 0; j < n; ++j)
163 for (idx_t i = 0; i < m; ++i)
164 C_(i, j) *= beta;
165 }
166 }
167 else {
168 if (beta == scalar_t(0))
169 gemm(transA, transB, alpha, A_, B_, C_);
170 else
172 }
173 }
174
175} // namespace legacy
176} // namespace tlapack
177
178#endif // #ifndef TLAPACK_LEGACY_GEMM_HH
constexpr Layout layout
Layout of a matrix or vector.
Definition arrayTraits.hpp:232
Op
Definition types.hpp:222
Layout
Definition types.hpp:24
#define tlapack_check_false(cond)
Throw an error if cond is true.
Definition exceptionHandling.hpp:113
void gemm(Layout layout, Op transA, Op transB, idx_t m, idx_t n, idx_t k, scalar_type< TA, TB, TC > alpha, TA const *A, idx_t lda, TB const *B, idx_t ldb, scalar_type< TA, TB, TC > beta, TC *C, idx_t ldc)
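General matrix-matrix multiply: $C := \alpha \, \mathrm{op}(A) \, \mathrm{op}(B) + \beta C$, where $\mathrm{op}(X)$ is $X$, $X^T$, or $X^H$.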
Definition gemm.hpp:103
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113