<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
trmm_blocked_mixed.hpp
Go to the documentation of this file.
1
3//
4// Copyright (c) 2025, University of Colorado Denver. All rights reserved.
5//
6// This file is part of <T>LAPACK.
7// <T>LAPACK is free software: you can redistribute it and/or modify it under
8// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
9
10#ifndef TLAPACK_TRMM_BLOCKED_MIXED_HH
11#define TLAPACK_TRMM_BLOCKED_MIXED_HH
12
14#include "tlapack/blas/gemm.hpp"
15#include "tlapack/blas/trmm.hpp"
17
18namespace tlapack {
19
/// Options struct for trmm_blocked_mixed.
struct TrmmBlockedOpts {
    /// Block size: the maximum number of rows of B processed per step.
    /// Should be positive.
    size_t nb = 32;
};
26
71template <TLAPACK_SIDE side_t,
81 op_t trans,
82 diag_t diag,
84 const matrixA_t& A,
85 matrixB_t& B,
86 work_t& work,
87 const TrmmBlockedOpts& opts = {})
88{
89 {
90 // data traits
91 using idx_t = size_type<matrixA_t>;
92 using range = std::pair<idx_t, idx_t>;
93
94 // constants
95 const idx_t m = nrows(B);
96 const idx_t n = ncols(B);
97 const idx_t nb = min(opts.nb, m);
98
99 // check arguments
100 tlapack_check_false(side != Side::Left && side != Side::Right);
101 tlapack_check_false(uplo != Uplo::Lower && uplo != Uplo::Upper);
102 tlapack_check_false(trans != Op::NoTrans && trans != Op::Trans &&
103 trans != Op::ConjTrans);
104 tlapack_check_false(diag != Diag::NonUnit && diag != Diag::Unit);
105 tlapack_check_false(nrows(A) != ncols(A));
106 tlapack_check_false(nrows(A) != ((side == Side::Left) ? m : n));
107
108 // Matrix W
109 auto [W, work1] = reshape(work, nb, n);
110
111 using real_t = real_type<type_t<matrixB_t>>;
112 if (side == Side::Left) {
113 if (trans == Op::NoTrans) {
114 if (uplo == Uplo::Upper) {
115 for (idx_t i = 0; i < m; i += nb) {
116 const idx_t ib = min(nb, m - i);
117
118 const auto A0i =
119 slice(A, range(0, i), range(i, i + ib));
120 const auto Aii =
121 slice(A, range(i, i + ib), range(i, i + ib));
122
123 auto B0 = rows(B, range(0, i));
124 auto Bi = rows(B, range(i, i + ib));
125 auto BiLowPrecision = rows(W, range(0, ib));
126
127 // B0 += alpha * A0i * Bi in mixed precision
128 lacpy(GENERAL, Bi, BiLowPrecision);
129 gemm(NO_TRANS, NO_TRANS, alpha, A0i, BiLowPrecision,
130 real_t(1), B0);
131
132 // Bi = alpha * Aii * Bi in mixed precision
133 trmm(side, uplo, trans, diag, alpha, Aii, Bi);
134 }
135 }
136 else { // uplo == Uplo::Lower
137 tlapack_error(1, "Blocked version of trsm not implemented");
138 }
139 }
140 else if (trans == Op::Trans) {
141 if (uplo == Uplo::Upper) {
142 tlapack_error(1, "Blocked version of trsm not implemented");
143 }
144 else { // uplo == Uplo::Lower
145 tlapack_error(1, "Blocked version of trsm not implemented");
146 }
147 }
148 else { // trans == Op::ConjTrans
149 if (uplo == Uplo::Upper) {
150 tlapack_error(1, "Blocked version of trsm not implemented");
151 }
152 else { // uplo == Uplo::Lower
153 tlapack_error(1, "Blocked version of trsm not implemented");
154 }
155 }
156 }
157 else { // side == Side::Right
158 if (trans == Op::NoTrans) {
159 if (uplo == Uplo::Upper) {
160 tlapack_error(1, "Blocked version of trsm not implemented");
161 }
162 else { // uplo == Uplo::Lower
163 tlapack_error(1, "Blocked version of trsm not implemented");
164 }
165 }
166 else if (trans == Op::Trans) {
167 if (uplo == Uplo::Upper) {
168 tlapack_error(1, "Blocked version of trsm not implemented");
169 }
170 else { // uplo == Uplo::Lower
171 tlapack_error(1, "Blocked version of trsm not implemented");
172 }
173 }
174 else { // trans == Op::ConjTrans
175 if (uplo == Uplo::Upper) {
176 tlapack_error(1, "Blocked version of trsm not implemented");
177 }
178 else { // uplo == Uplo::Lower
179 tlapack_error(1, "Blocked version of trsm not implemented");
180 }
181 }
182 }
183 }
184}
185
186} // namespace tlapack
187
188#endif
#define TLAPACK_DIAG
Macro for tlapack::concepts::Diag compatible with C++17.
Definition concepts.hpp:945
#define TLAPACK_SIDE
Macro for tlapack::concepts::Side compatible with C++17.
Definition concepts.hpp:927
#define TLAPACK_UPLO
Macro for tlapack::concepts::Uplo compatible with C++17.
Definition concepts.hpp:942
#define TLAPACK_SMATRIX
Macro for tlapack::concepts::SliceableMatrix compatible with C++17.
Definition concepts.hpp:899
#define TLAPACK_OP
Macro for tlapack::concepts::Op compatible with C++17.
Definition concepts.hpp:933
#define TLAPACK_WORKSPACE
Macro for tlapack::concepts::Workspace compatible with C++17.
Definition concepts.hpp:912
void lacpy(uplo_t uplo, const matrixA_t &A, matrixB_t &B)
Copies a matrix from A to B.
Definition lacpy.hpp:38
void trmm(Side side, Uplo uplo, Op trans, Diag diag, const alpha_t &alpha, const matrixA_t &A, matrixB_t &B)
Triangular matrix-matrix multiply:
Definition trmm.hpp:72
void gemm(Op transA, Op transB, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
General matrix-matrix multiply:
Definition gemm.hpp:61
void trmm_blocked_mixed(side_t side, uplo_t uplo, op_t trans, diag_t diag, const scalar_type< type_t< matrixA_t >, type_t< matrixB_t > > &alpha, const matrixA_t &A, matrixB_t &B, work_t &work, const TrmmBlockedOpts &opts={})
Triangular matrix-matrix multiply using a blocked, mixed-precision algorithm.
Definition trmm_blocked_mixed.hpp:78
#define tlapack_error(info, detailedInfo)
Error handler.
Definition exceptionHandling.hpp:142
#define tlapack_check_false(cond)
Throw an error if cond is true.
Definition exceptionHandling.hpp:113
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
typename traits::scalar_type_traits< Types..., int >::type scalar_type
The common scalar type of the list of types.
Definition scalar_type_traits.hpp:250
Options struct for trmm_blocked_mixed.
Definition trmm_blocked_mixed.hpp:23
size_t nb
Block size.
Definition trmm_blocked_mixed.hpp:24