<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
gemm.hpp
Go to the documentation of this file.
1
//
// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
//
// This file is part of <T>LAPACK.
// <T>LAPACK is free software: you can redistribute it and/or modify it under
// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.

10#ifndef TLAPACK_STARPU_GEMM_HH
11#define TLAPACK_STARPU_GEMM_HH
12
16
17namespace tlapack {
18
20template <class TA, class TB, class TC, class alpha_t, class beta_t>
22 Op transB,
23 const alpha_t& alpha,
24 const starpu::Matrix<TA>& A,
25 const starpu::Matrix<TB>& B,
26 const beta_t& beta,
28
29{
30 using starpu::idx_t;
31
32 // constants
33 const real_type<TC> one(1);
34 const idx_t m = C.nrows();
35 const idx_t n = C.ncols();
36 const idx_t k = (transA == Op::NoTrans) ? A.ncols() : A.nrows();
37 const idx_t nx = C.get_nx();
38 const idx_t ny = C.get_ny();
39 const idx_t nz = (transA == Op::NoTrans) ? A.get_ny() : A.get_nx();
40
41 // quick return
42 if (m == 0 || n == 0) return;
43 if (k == 0 && beta == one) return;
44
45 // check arguments
46 tlapack_check(transA == Op::NoTrans || transA == Op::Trans ||
47 transA == Op::ConjTrans);
48 tlapack_check(transB == Op::NoTrans || transB == Op::Trans ||
49 transB == Op::ConjTrans);
50 tlapack_check(m == (transA == Op::NoTrans ? A.nrows() : A.ncols()));
51 tlapack_check(nx == (transA == Op::NoTrans ? A.get_nx() : A.get_ny()));
52 tlapack_check(n == (transB == Op::NoTrans ? B.ncols() : B.nrows()));
53 tlapack_check(k == (transB == Op::NoTrans ? B.nrows() : B.ncols()));
54 tlapack_check(ny == (transB == Op::NoTrans ? B.get_ny() : B.get_nx()));
55 tlapack_check(nz == (transB == Op::NoTrans ? B.get_nx() : B.get_ny()));
56
57 // Remove const type from A and B
58 auto& A_ = const_cast<starpu::Matrix<TA>&>(A);
59 auto& B_ = const_cast<starpu::Matrix<TB>&>(B);
60
61 for (idx_t ix = 0; ix < nx; ++ix) {
62 for (idx_t iy = 0; iy < ny; ++iy) {
63 if (transA == Op::NoTrans) {
64 if (transB == Op::NoTrans) {
65 starpu::insert_task_gemm<TA, TB, TC>(
66 transA, transB, alpha, A_.tile(ix, 0), B_.tile(0, iy),
67 beta, C.tile(ix, iy));
68 for (idx_t iz = 1; iz < nz; ++iz)
69 starpu::insert_task_gemm<TA, TB, TC>(
70 transA, transB, alpha, A_.tile(ix, iz),
71 B_.tile(iz, iy), one, C.tile(ix, iy));
72 }
73 else {
74 starpu::insert_task_gemm<TA, TB, TC>(
75 transA, transB, alpha, A_.tile(ix, 0), B_.tile(iy, 0),
76 beta, C.tile(ix, iy));
77 for (idx_t iz = 1; iz < nz; ++iz)
78 starpu::insert_task_gemm<TA, TB, TC>(
79 transA, transB, alpha, A_.tile(ix, iz),
80 B_.tile(iy, iz), one, C.tile(ix, iy));
81 }
82 }
83 else {
84 if (transB == Op::NoTrans) {
85 starpu::insert_task_gemm<TA, TB, TC>(
86 transA, transB, alpha, A_.tile(0, ix), B_.tile(0, iy),
87 beta, C.tile(ix, iy));
88 for (idx_t iz = 1; iz < nz; ++iz)
89 starpu::insert_task_gemm<TA, TB, TC>(
90 transA, transB, alpha, A_.tile(iz, ix),
91 B_.tile(iz, iy), one, C.tile(ix, iy));
92 }
93 else {
94 starpu::insert_task_gemm<TA, TB, TC>(
95 transA, transB, alpha, A_.tile(0, ix), B_.tile(iy, 0),
96 beta, C.tile(ix, iy));
97 for (idx_t iz = 1; iz < nz; ++iz)
98 starpu::insert_task_gemm<TA, TB, TC>(
99 transA, transB, alpha, A_.tile(iz, ix),
100 B_.tile(iy, iz), one, C.tile(ix, iy));
101 }
102 }
103 }
104 }
105}
106
107} // namespace tlapack
108
109#endif // TLAPACK_STARPU_GEMM_HH
Op
Definition types.hpp:222
Class for representing a matrix in StarPU that is split into tiles.
Definition Matrix.hpp:133
void gemm(Op transA, Op transB, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
General matrix-matrix multiply: $C := \alpha\, op(A)\, op(B) + \beta\, C$.
Definition gemm.hpp:61
#define tlapack_check(cond)
Throw an error if cond is false.
Definition exceptionHandling.hpp:98
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
Task insertion functions.