<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
rot_sequence.hpp
Go to the documentation of this file.
1
5//
6// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
7//
8// This file is part of <T>LAPACK.
9// <T>LAPACK is free software: you can redistribute it and/or modify it under
10// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
11
12#ifndef TLAPACK_ROT_SEQUENCE_HH
13#define TLAPACK_ROT_SEQUENCE_HH
14
16#include "tlapack/blas/rot.hpp"
17
18namespace tlapack {
19
76template <TLAPACK_SIDE side_t,
77 TLAPACK_DIRECTION direction_t,
82 side_t side, direction_t direction, const C_t& c, const S_t& s, A_t& A)
83{
84 using T = type_t<A_t>;
85 using idx_t = size_type<A_t>;
86
87 // constants
88 const idx_t m = nrows(A);
89 const idx_t n = ncols(A);
90 const idx_t k = (side == Side::Left) ? m - 1 : n - 1;
91
92 // Check dimensions
93 tlapack_check((idx_t)size(c) == k);
94 tlapack_check((idx_t)size(s) == k);
95
96 // quick return
97 if (k < 1) return 0;
98
99 if constexpr (layout<A_t> == Layout::ColMajor) {
100 if (direction == Direction::Forward) {
101 if (side == Side::Left) {
102 for (idx_t j = 0; j < n; ++j) {
103 for (idx_t i2 = k; i2 > 0; --i2) {
104 idx_t i = i2 - 1;
105 T temp = c[i] * A(i, j) + s[i] * A(i + 1, j);
106 A(i + 1, j) =
107 -conj(s[i]) * A(i, j) + c[i] * A(i + 1, j);
108 A(i, j) = temp;
109 }
110 }
111 }
112 else { // Side::Right
113 // Manual unrolling of loop, applying 3 rotations at a time
114 // This allows some parts of the vector to remain in register
115 idx_t ii = k % 3;
116 for (idx_t i2 = k; i2 > ii; i2 = i2 - 3) {
117 idx_t i = i2 - 1;
118
119 for (idx_t j = 0; j < m; ++j) {
120 T temp = A(j, i + 1);
121 T temp0 = A(j, i);
122 T temp1 = A(j, i - 1);
123
124 // Apply first rotation
125 A(j, i + 1) = -s[i] * temp0 + c[i] * temp;
126 temp0 = c[i] * temp0 + conj(s[i]) * temp;
127
128 // Apply second rotation
129 A(j, i) = -s[i - 1] * temp1 + c[i - 1] * temp0;
130 temp1 = c[i - 1] * temp1 + conj(s[i - 1]) * temp0;
131
132 // Apply third rotation
133 A(j, i - 1) =
134 -s[i - 2] * A(j, i - 2) + c[i - 2] * temp1;
135 A(j, i - 2) =
136 c[i - 2] * A(j, i - 2) + conj(s[i - 2]) * temp1;
137 }
138 }
139 // If the amount of rotations is not divisible by 3, apply the
140 // final ones one by one
141 for (idx_t i2 = ii; i2 > 0; --i2) {
142 idx_t i = i2 - 1;
143 for (idx_t j = 0; j < m; ++j) {
144 T temp = c[i] * A(j, i) + conj(s[i]) * A(j, i + 1);
145 A(j, i + 1) = -s[i] * A(j, i) + c[i] * A(j, i + 1);
146 A(j, i) = temp;
147 }
148 }
149 }
150 }
151 else { // Direction::Backward
152 if (side == Side::Left) {
153 for (idx_t j = 0; j < n; ++j) {
154 for (idx_t i = 0; i < k; ++i) {
155 T temp = c[i] * A(i, j) + s[i] * A(i + 1, j);
156 A(i + 1, j) =
157 -conj(s[i]) * A(i, j) + c[i] * A(i + 1, j);
158 A(i, j) = temp;
159 }
160 }
161 }
162 else { // Side::Right
163
164 // Manual unrolling of loop, applying 3 rotations at a time
165 // This allows some parts of the vector to remain in register
166 idx_t ii = k - (k % 3);
167 for (idx_t i = 0; i + 1 < ii; i = i + 3) {
168 for (idx_t j = 0; j < m; ++j) {
169 T temp = A(j, i);
170 T temp0 = A(j, i + 1);
171 T temp1 = A(j, i + 2);
172
173 // Apply first rotation
174 A(j, i) = c[i] * temp + conj(s[i]) * temp0;
175 temp0 = -s[i] * temp + c[i] * temp0;
176
177 // Apply second rotation
178 A(j, i + 1) = c[i + 1] * temp0 + conj(s[i + 1]) * temp1;
179 temp1 = -s[i + 1] * temp0 + c[i + 1] * temp1;
180
181 // Apply third rotation
182 A(j, i + 2) =
183 c[i + 2] * temp1 + conj(s[i + 2]) * A(j, i + 3);
184 A(j, i + 3) =
185 -s[i + 2] * temp1 + c[i + 2] * A(j, i + 3);
186 }
187 }
188 // If the amount of rotations is not divisible by 3, apply the
189 // final ones one by one
190 for (idx_t i = ii; i < k; ++i) {
191 for (idx_t j = 0; j < m; ++j) {
192 T temp = c[i] * A(j, i) + conj(s[i]) * A(j, i + 1);
193 A(j, i + 1) = -s[i] * A(j, i) + c[i] * A(j, i + 1);
194 A(j, i) = temp;
195 }
196 }
197 }
198 }
199 }
200 else {
201 if (direction == Direction::Forward) {
202 if (side == Side::Left) {
203 // Manual unrolling of loop, applying 3 rotations at a time
204 // This allows some parts of the vector to remain in register
205 idx_t ii = k % 3;
206 for (idx_t i2 = k; i2 > ii; i2 = i2 - 3) {
207 idx_t i = i2 - 1;
208
209 for (idx_t j = 0; j < n; ++j) {
210 T temp = A(i + 1, j);
211 T temp0 = A(i, j);
212 T temp1 = A(i - 1, j);
213
214 // Apply first rotation
215 A(i + 1, j) = -conj(s[i]) * temp0 + c[i] * temp;
216 temp0 = c[i] * temp0 + s[i] * temp;
217
218 // Apply second rotation
219 A(i, j) = -conj(s[i - 1]) * temp1 + c[i - 1] * temp0;
220 temp1 = c[i - 1] * temp1 + s[i - 1] * temp0;
221
222 // Apply third rotation
223 A(i - 1, j) =
224 -conj(s[i - 2]) * A(i - 2, j) + c[i - 2] * temp1;
225 A(i - 2, j) = c[i - 2] * A(i - 2, j) + s[i - 2] * temp1;
226 }
227 }
228 // If the amount of rotations is not divisible by 3, apply the
229 // final ones one by one
230 for (idx_t i2 = ii; i2 > 0; --i2) {
231 idx_t i = i2 - 1;
232 for (idx_t j = 0; j < n; ++j) {
233 T temp = c[i] * A(i, j) + s[i] * A(i + 1, j);
234 A(i + 1, j) =
235 -conj(s[i]) * A(i, j) + c[i] * A(i + 1, j);
236 A(i, j) = temp;
237 }
238 }
239 }
240 else {
241 for (idx_t j = 0; j < m; ++j) {
242 for (idx_t i2 = k; i2 > 0; --i2) {
243 idx_t i = i2 - 1;
244 T temp = c[i] * A(j, i) + conj(s[i]) * A(j, i + 1);
245 A(j, i + 1) = -s[i] * A(j, i) + c[i] * A(j, i + 1);
246 A(j, i) = temp;
247 }
248 }
249 }
250 }
251 else {
252 if (side == Side::Left) {
253 // Manual unrolling of loop, applying 3 rotations at a time
254 // This allows some parts of the vector to remain in register
255 idx_t ii = k - (k % 3);
256 for (idx_t i = 0; i + 1 < ii; i = i + 3) {
257 for (idx_t j = 0; j < n; ++j) {
258 T temp = A(i, j);
259 T temp0 = A(i + 1, j);
260 T temp1 = A(i + 2, j);
261
262 // Apply first rotation
263 A(i, j) = c[i] * temp + s[i] * temp0;
264 temp0 = -conj(s[i]) * temp + c[i] * temp0;
265
266 // Apply second rotation
267 A(i + 1, j) = c[i + 1] * temp0 + s[i + 1] * temp1;
268 temp1 = -conj(s[i + 1]) * temp0 + c[i + 1] * temp1;
269
270 // Apply third rotation
271 A(i + 2, j) = c[i + 2] * temp1 + s[i + 2] * A(i + 3, j);
272 A(i + 3, j) =
273 -conj(s[i + 2]) * temp1 + c[i + 2] * A(i + 3, j);
274 }
275 }
276 // If the amount of rotations is not divisible by 3, apply the
277 // final ones one by one
278 for (idx_t i = ii; i < k; ++i) {
279 for (idx_t j = 0; j < n; ++j) {
280 T temp = c[i] * A(i, j) + s[i] * A(i + 1, j);
281 A(i + 1, j) =
282 -conj(s[i]) * A(i, j) + c[i] * A(i + 1, j);
283 A(i, j) = temp;
284 }
285 }
286 }
287 else {
288 for (idx_t j = 0; j < m; ++j) {
289 for (idx_t i = 0; i < k; ++i) {
290 T temp = c[i] * A(j, i) + conj(s[i]) * A(j, i + 1);
291 A(j, i + 1) = -s[i] * A(j, i) + c[i] * A(j, i + 1);
292 A(j, i) = temp;
293 }
294 }
295 }
296 }
297 }
298
299 return 0;
300}
301
302} // namespace tlapack
303
304#endif // TLAPACK_ROT_SEQUENCE_HH
constexpr T conj(const T &x) noexcept
Extends std::conj() to real datatypes.
Definition utils.hpp:100
#define TLAPACK_SVECTOR
Macro for tlapack::concepts::SliceableVector compatible with C++17.
Definition concepts.hpp:909
#define TLAPACK_SIDE
Macro for tlapack::concepts::Side compatible with C++17.
Definition concepts.hpp:927
#define TLAPACK_SMATRIX
Macro for tlapack::concepts::SliceableMatrix compatible with C++17.
Definition concepts.hpp:899
#define TLAPACK_DIRECTION
Macro for tlapack::concepts::Direction compatible with C++17.
Definition concepts.hpp:930
int rot_sequence(side_t side, direction_t direction, const C_t &c, const S_t &s, A_t &A)
Applies a sequence of plane rotations to an (m-by-n) matrix.
Definition rot_sequence.hpp:81
#define tlapack_check(cond)
Throws an error if cond is false.
Definition exceptionHandling.hpp:98
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113