<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
mdspan.hpp
Go to the documentation of this file.
1
3//
4// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
5//
6// This file is part of <T>LAPACK.
7// <T>LAPACK is free software: you can redistribute it and/or modify it under
8// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
9
10#ifndef TLAPACK_MDSPAN_HH
11#define TLAPACK_MDSPAN_HH
12
13#include <cassert>
14#include <experimental/mdspan> // Use mdspan from
15 // https://github.com/kokkos/mdspan because we
16 // need the `submdspan` functionality
17
20
21namespace tlapack {
22
23// -----------------------------------------------------------------------------
24// Helpers
25
26namespace traits {
27 namespace internal {
28 template <class ET, class Exts, class LP, class AP>
29 std::true_type is_mdspan_type_f(
30 const std::experimental::mdspan<ET, Exts, LP, AP>*);
31
32 std::false_type is_mdspan_type_f(const void*);
33 } // namespace internal
34
37 template <class T>
38 constexpr bool is_mdspan_type =
39 decltype(internal::is_mdspan_type_f(std::declval<T*>()))::value;
40} // namespace traits
41
42// -----------------------------------------------------------------------------
43// Data traits
44
45namespace traits {
// Layout specializations for rank-2 / rank-1 mdspan types.
// NOTE(review): the embedded source numbering jumps over lines 48, 54 and 61,
// so the line naming the trait being specialized is missing from each of the
// three specializations below; presumably it is the layout trait ("Trait to
// determine the layout of a given data structure") — confirm against the
// original header.
//
// Rank-2 mdspan with layout_left is reported as Layout::ColMajor.
47 template <class ET, class Exts, class AP>
49 std::experimental::mdspan<ET, Exts, std::experimental::layout_left, AP>,
50 std::enable_if_t<Exts::rank() == 2, int>> {
51 static constexpr Layout value = Layout::ColMajor;
52 };
// Rank-2 mdspan with layout_right is reported as Layout::RowMajor.
53 template <class ET, class Exts, class AP>
55 std::experimental::
56 mdspan<ET, Exts, std::experimental::layout_right, AP>,
57 std::enable_if_t<Exts::rank() == 2, int>> {
58 static constexpr Layout value = Layout::RowMajor;
59 };
// Rank-1 mdspan whose layout mapping is always strided is reported as
// Layout::Strided.
60 template <class ET, class Exts, class LP, class AP>
62 std::experimental::mdspan<ET, Exts, LP, AP>,
63 std::enable_if_t<(Exts::rank() == 1) &&
64 LP::template mapping<Exts>::is_always_strided(),
65 int>> {
66 static constexpr Layout value = Layout::Strided;
67 };
68
69 template <class ET, class Exts, class LP, class AP>
70 struct real_type_traits<std::experimental::mdspan<ET, Exts, LP, AP>, int> {
71 using type = std::experimental::mdspan<real_type<ET>, Exts, LP, AP>;
72 };
73
74 template <class ET, class Exts, class LP, class AP>
75 struct complex_type_traits<std::experimental::mdspan<ET, Exts, LP, AP>,
76 int> {
77 using type = std::experimental::mdspan<complex_type<ET>, Exts, LP, AP>;
78 };
79
81 template <class ET, class Exts, class LP, class AP>
82 struct CreateFunctor<std::experimental::mdspan<ET, Exts, LP, AP>,
83 std::enable_if_t<Exts::rank() == 1, int>> {
84 using idx_t =
85 typename std::experimental::mdspan<ET, Exts, LP>::size_type;
86 using extents_t = std::experimental::dextents<idx_t, 1>;
87
88 template <class T>
89 constexpr auto operator()(std::vector<T>& v, idx_t n) const
90 {
91 assert(n >= 0);
92 v.resize(n); // Allocates space in memory
93 return std::experimental::mdspan<T, extents_t>(v.data(), n);
94 }
95 };
96
98 template <class ET, class Exts, class LP, class AP, int n>
99 struct CreateStaticFunctor<std::experimental::mdspan<ET, Exts, LP, AP>,
100 n,
101 -1,
102 std::enable_if_t<Exts::rank() == 1, int>> {
103 static_assert(n >= 0);
104 using idx_t =
105 typename std::experimental::mdspan<ET, Exts, LP>::size_type;
106 using extents_t = std::experimental::extents<idx_t, n>;
107
108 template <typename T>
109 constexpr auto operator()(T* v) const
110 {
111 return std::experimental::mdspan<T, extents_t>(v);
112 }
113 };
114
116 template <class ET, class Exts, class LP, class AP>
117 struct CreateFunctor<std::experimental::mdspan<ET, Exts, LP, AP>,
118 std::enable_if_t<Exts::rank() == 2, int>> {
119 using idx_t =
120 typename std::experimental::mdspan<ET, Exts, LP>::size_type;
121 using extents_t = std::experimental::dextents<idx_t, 2>;
122
123 template <class T>
124 constexpr auto operator()(std::vector<T>& v, idx_t m, idx_t n) const
125 {
126 assert(m >= 0 && n >= 0);
127 v.resize(m * n); // Allocates space in memory
128 return std::experimental::mdspan<T, extents_t>(v.data(), m, n);
129 }
130 };
131
133 template <class ET, class Exts, class LP, class AP, int m, int n>
134 struct CreateStaticFunctor<std::experimental::mdspan<ET, Exts, LP, AP>,
135 m,
136 n,
137 std::enable_if_t<Exts::rank() == 2, int>> {
138 static_assert(m >= 0 && n >= 0);
139 using idx_t =
140 typename std::experimental::mdspan<ET, Exts, LP>::size_type;
141 using extents_t = std::experimental::extents<idx_t, m, n>;
142
143 template <typename T>
144 constexpr auto operator()(T* v) const
145 {
146 return std::experimental::mdspan<T, extents_t>(v);
147 }
148 };
149} // namespace traits
150
151// -----------------------------------------------------------------------------
152// Data descriptors
153
154// Size
155template <class ET, class Exts, class LP, class AP>
156constexpr auto size(const std::experimental::mdspan<ET, Exts, LP, AP>& x)
157{
158 return x.size();
159}
160// Number of rows
161template <class ET, class Exts, class LP, class AP>
162constexpr auto nrows(const std::experimental::mdspan<ET, Exts, LP, AP>& x)
163{
164 return x.extent(0);
165}
166// Number of columns
167template <class ET, class Exts, class LP, class AP>
168constexpr auto ncols(const std::experimental::mdspan<ET, Exts, LP, AP>& x)
169{
170 return x.extent(1);
171}
172
173// -----------------------------------------------------------------------------
174// Block operations
175
// isSlice(SliceSpec) is true when SliceSpec converts to
// std::tuple<std::size_t, std::size_t>, i.e. when it is a two-index slice
// specifier (presumably a [begin, end) pair as interpreted by submdspan)
// rather than a single index. The macro is #undef'd after the reshape
// overloads.
176#define isSlice(SliceSpec) \
177 std::is_convertible<SliceSpec, std::tuple<std::size_t, std::size_t>>::value
178
179// Slice
180template <
181 class ET,
182 class Exts,
183 class LP,
184 class AP,
185 class SliceSpecRow,
186 class SliceSpecCol,
187 std::enable_if_t<isSlice(SliceSpecRow) || isSlice(SliceSpecCol), int> = 0>
188constexpr auto slice(const std::experimental::mdspan<ET, Exts, LP, AP>& A,
189 SliceSpecRow&& rows,
190 SliceSpecCol&& cols) noexcept
191{
192 return std::experimental::submdspan(A, std::forward<SliceSpecRow>(rows),
193 std::forward<SliceSpecCol>(cols));
194}
195
196// Rows
197template <class ET,
198 class Exts,
199 class LP,
200 class AP,
201 class SliceSpec,
202 std::enable_if_t<isSlice(SliceSpec), int> = 0>
203constexpr auto rows(const std::experimental::mdspan<ET, Exts, LP, AP>& A,
204 SliceSpec&& rows) noexcept
205{
206 return std::experimental::submdspan(A, std::forward<SliceSpec>(rows),
207 std::experimental::full_extent);
208}
209
210// Row
211template <class ET, class Exts, class LP, class AP>
212constexpr auto row(const std::experimental::mdspan<ET, Exts, LP, AP>& A,
213 std::size_t rowIdx) noexcept
214{
215 return std::experimental::submdspan(A, rowIdx,
216 std::experimental::full_extent);
217}
218
219// Columns
220template <class ET,
221 class Exts,
222 class LP,
223 class AP,
224 class SliceSpec,
225 std::enable_if_t<isSlice(SliceSpec), int> = 0>
226constexpr auto cols(const std::experimental::mdspan<ET, Exts, LP, AP>& A,
227 SliceSpec&& cols) noexcept
228{
229 return std::experimental::submdspan(A, std::experimental::full_extent,
230 std::forward<SliceSpec>(cols));
231}
232
233// Column
234template <class ET, class Exts, class LP, class AP>
235constexpr auto col(const std::experimental::mdspan<ET, Exts, LP, AP>& A,
236 std::size_t colIdx) noexcept
237{
238 return std::experimental::submdspan(A, std::experimental::full_extent,
239 colIdx);
240}
241
242// Slice
243template <class ET,
244 class Exts,
245 class LP,
246 class AP,
247 class SliceSpec,
248 std::enable_if_t<isSlice(SliceSpec) && (Exts::rank() == 1), int> = 0>
249constexpr auto slice(const std::experimental::mdspan<ET, Exts, LP, AP>& v,
250 SliceSpec&& rows) noexcept
251{
252 return std::experimental::submdspan(v, std::forward<SliceSpec>(rows));
253}
254
255// Extract a diagonal from a matrix
256template <class ET,
257 class Exts,
258 class LP,
259 class AP,
260 std::enable_if_t<
261 /* Requires: */
262 LP::template mapping<Exts>::is_always_strided(),
263 bool> = true>
264constexpr auto diag(const std::experimental::mdspan<ET, Exts, LP, AP>& A,
265 int diagIdx = 0)
266{
267 using std::array;
268 using std::min;
269 using std::experimental::layout_stride;
270
271 using size_type =
272 typename std::experimental::mdspan<ET, Exts, LP, AP>::size_type;
273 using extents_t = std::experimental::dextents<size_type, 1>;
274 using mapping = typename layout_stride::template mapping<extents_t>;
275
276 // mdspan components
277 auto ptr = A.accessor().offset(A.data(), (diagIdx >= 0)
278 ? A.mapping()(0, diagIdx)
279 : A.mapping()(-diagIdx, 0));
280 auto map = mapping(
281 extents_t(
282 (diagIdx >= 0)
283 ? min(A.extent(0) + diagIdx, A.extent(1)) - (size_type)diagIdx
284 : min(A.extent(0), A.extent(1) - diagIdx) + (size_type)diagIdx),
285 array<size_type, 1>{A.stride(0) + A.stride(1)});
286 auto acc_pol = typename AP::offset_policy(A.accessor());
287
288 // return
289 return std::experimental::mdspan<ET, extents_t, layout_stride,
290 typename AP::offset_policy>(
291 std::move(ptr), std::move(map), std::move(acc_pol));
292}
293
294// Transpose View
295template <class ET, class Exts, class AP>
296constexpr auto transpose_view(
297 const std::experimental::
298 mdspan<ET, Exts, std::experimental::layout_left, AP>& A) noexcept
299{
300 using matrix_t =
301 std::experimental::mdspan<ET, Exts, std::experimental::layout_left, AP>;
302 using idx_t = typename matrix_t::size_type;
303 using extents_t =
304 std::experimental::extents<idx_t, matrix_t::static_extent(1),
305 matrix_t::static_extent(0)>;
306
307 using std::experimental::layout_right;
308 using mapping_t = typename layout_right::template mapping<extents_t>;
309
310 mapping_t map(extents_t(A.extent(1), A.extent(0)));
311 return std::experimental::mdspan<ET, extents_t, layout_right, AP>(
312 A.data(), std::move(map));
313}
314template <class ET, class Exts, class AP>
315constexpr auto transpose_view(
316 const std::experimental::
317 mdspan<ET, Exts, std::experimental::layout_right, AP>& A) noexcept
318{
319 using matrix_t =
320 std::experimental::mdspan<ET, Exts, std::experimental::layout_right,
321 AP>;
322 using idx_t = typename matrix_t::size_type;
323 using extents_t =
324 std::experimental::extents<idx_t, matrix_t::static_extent(1),
325 matrix_t::static_extent(0)>;
326
327 using std::experimental::layout_left;
328 using mapping_t = typename layout_left::template mapping<extents_t>;
329
330 mapping_t map(extents_t(A.extent(1), A.extent(0)));
331 return std::experimental::mdspan<ET, extents_t, layout_left, AP>(
332 A.data(), std::move(map));
333}
334template <class ET, class Exts, class AP>
335constexpr auto transpose_view(
336 const std::experimental::
337 mdspan<ET, Exts, std::experimental::layout_stride, AP>& A) noexcept
338{
339 using matrix_t =
340 std::experimental::mdspan<ET, Exts, std::experimental::layout_stride,
341 AP>;
342 using idx_t = typename matrix_t::size_type;
343 using extents_t =
344 std::experimental::extents<idx_t, matrix_t::static_extent(1),
345 matrix_t::static_extent(0)>;
346
347 using std::experimental::layout_stride;
348 using mapping_t = typename layout_stride::template mapping<extents_t>;
349
350 mapping_t map(extents_t(A.extent(1), A.extent(0)),
351 std::array<idx_t, 2>{A.stride(1), A.stride(0)});
352 return std::experimental::mdspan<ET, extents_t, layout_stride, AP>(
353 A.data(), std::move(map));
354}
355
356// Reshape to matrix
357template <
358 class ET,
359 class Exts,
360 class LP,
361 class AP,
362 std::enable_if_t<(std::is_same_v<LP, std::experimental::layout_right> ||
363 std::is_same_v<LP, std::experimental::layout_left>),
364 int> = 0>
365auto reshape(std::experimental::mdspan<ET, Exts, LP, AP>& A,
366 std::size_t m,
367 std::size_t n)
368{
369 using idx_t = typename std::experimental::mdspan<ET, Exts, LP>::size_type;
370 using extents1_t = std::experimental::dextents<idx_t, 1>;
371 using extents2_t = std::experimental::dextents<idx_t, 2>;
372 using vector_t = std::experimental::mdspan<ET, extents1_t, LP>;
373 using matrix_t = std::experimental::mdspan<ET, extents2_t, LP>;
374 using mapping1_t = typename LP::template mapping<extents1_t>;
375 using mapping2_t = typename LP::template mapping<extents2_t>;
376
377 // constants
378 const idx_t size = A.size();
379 const idx_t new_size = m * n;
380
381 // Check arguments
382 if (new_size > size)
383 throw std::domain_error("New size is larger than current size");
384
385 return std::make_pair(
386 matrix_t(A.data(), mapping2_t(extents2_t(m, n))),
387 vector_t(A.data() + new_size, mapping1_t(extents1_t(size - new_size))));
388}
389template <class ET,
390 class Exts,
391 class AP,
392 std::enable_if_t<Exts::rank() == 2, int> = 0>
393auto reshape(
394 std::experimental::mdspan<ET, Exts, std::experimental::layout_stride, AP>&
395 A,
396 std::size_t m,
397 std::size_t n)
398{
399 using LP = std::experimental::layout_stride;
400 using idx_t = typename std::experimental::mdspan<ET, Exts, LP>::size_type;
401 using extents_t = std::experimental::dextents<idx_t, 2>;
402 using matrix_t = std::experimental::mdspan<ET, extents_t, LP>;
403 using mapping_t = typename LP::template mapping<extents_t>;
404
405 // constants
406 const idx_t size = A.size();
407 const idx_t new_size = m * n;
408 const bool is_contiguous =
409 (size <= 1) ||
410 (A.stride(0) == 1 &&
411 (A.stride(1) == A.extent(0) || A.extent(1) <= 1)) ||
412 (A.stride(1) == 1 && (A.stride(0) == A.extent(1) || A.extent(0) <= 1));
413
414 // Check arguments
415 if (new_size > size)
416 throw std::domain_error("New size is larger than current size");
417 if (A.stride(0) != 1 && A.stride(1) != 1)
418 throw std::domain_error(
419 "Reshaping is not available for matrices with both strides "
420 "different from 1.");
421
422 if (is_contiguous) {
423 const idx_t s = size - new_size;
424 if (A.stride(0) == 1)
425 return std::make_pair(
426 matrix_t(A.data(), mapping_t(extents_t(m, n),
427 std::array<idx_t, 2>{1, m})),
428 matrix_t(
429 A.data() + new_size,
430 mapping_t(extents_t(s, 1), std::array<idx_t, 2>{1, s})));
431 else
432 return std::make_pair(
433 matrix_t(A.data(), mapping_t(extents_t(m, n),
434 std::array<idx_t, 2>{n, 1})),
435 matrix_t(
436 A.data() + new_size,
437 mapping_t(extents_t(1, s), std::array<idx_t, 2>{s, 1})));
438 }
439 else {
440 std::array<idx_t, 2> strides{A.stride(0), A.stride(1)};
441
442 if (m == A.extent(0) || n == 0) {
443 return std::make_pair(
444 matrix_t(A.data(), mapping_t(extents_t(m, n), strides)),
445 matrix_t(A.data() + n * A.stride(1),
446 mapping_t(extents_t(m, A.extent(1) - n), strides)));
447 }
448 else if (n == A.extent(1) || m == 0) {
449 return std::make_pair(
450 matrix_t(A.data(), mapping_t(extents_t(m, n), strides)),
451 matrix_t(A.data() + m * A.stride(0),
452 mapping_t(extents_t(A.extent(0) - m, n), strides)));
453 }
454 else {
455 throw std::domain_error(
456 "Cannot reshape to non-contiguous matrix if the number of rows "
457 "and "
458 "columns are different.");
459 }
460 }
461}
462template <class ET,
463 class Exts,
464 class AP,
465 std::enable_if_t<Exts::rank() == 1, int> = 0>
466auto reshape(
467 std::experimental::mdspan<ET, Exts, std::experimental::layout_stride, AP>&
468 v,
469 std::size_t m,
470 std::size_t n)
471{
472 using LP = std::experimental::layout_stride;
473 using idx_t = typename std::experimental::mdspan<ET, Exts, LP>::size_type;
474 using extents1_t = std::experimental::dextents<idx_t, 1>;
475 using extents2_t = std::experimental::dextents<idx_t, 2>;
476 using vector_t = std::experimental::mdspan<ET, extents1_t, LP>;
477 using matrix_t = std::experimental::mdspan<ET, extents2_t, LP>;
478 using mapping1_t = typename LP::template mapping<extents1_t>;
479 using mapping2_t = typename LP::template mapping<extents2_t>;
480
481 // constants
482 const idx_t size = v.size();
483 const idx_t new_size = m * n;
484 const idx_t s = size - new_size;
485 const idx_t stride = v.stride(0);
486 const bool is_contiguous = (size <= 1 || stride == 1);
487
488 // Check arguments
489 if (new_size > size)
490 throw std::domain_error("New size is larger than current size");
491 if (!is_contiguous && m > 1 && n > 1)
492 throw std::domain_error(
493 "New sizes are not compatible with the current vector.");
494
495 if (is_contiguous) {
496 return std::make_pair(
497 matrix_t(v.data(),
498 mapping2_t(extents2_t(m, n), std::array<idx_t, 2>{1, m})),
499 vector_t(v.data() + new_size,
500 mapping1_t(extents1_t(s), std::array<idx_t, 1>{1})));
501 }
502 else {
503 return std::make_pair(
504 matrix_t(v.data(),
505 mapping2_t(extents2_t(m, n),
506 (m <= 1) ? std::array<idx_t, 2>{1, stride}
507 : std::array<idx_t, 2>{stride, 1})),
508 vector_t(v.data() + new_size * stride,
509 mapping1_t(extents1_t(s), std::array<idx_t, 1>{stride})));
510 }
511}
512
513// Reshape to vector
514template <
515 class ET,
516 class Exts,
517 class LP,
518 class AP,
519 std::enable_if_t<(std::is_same_v<LP, std::experimental::layout_right> ||
520 std::is_same_v<LP, std::experimental::layout_left>),
521 int> = 0>
522auto reshape(std::experimental::mdspan<ET, Exts, LP, AP>& A, std::size_t n)
523{
524 using idx_t = typename std::experimental::mdspan<ET, Exts, LP>::size_type;
525 using extents_t = std::experimental::dextents<idx_t, 1>;
526 using vector_t = std::experimental::mdspan<ET, extents_t, LP>;
527 using mapping_t = typename LP::template mapping<extents_t>;
528
529 // constants
530 const idx_t size = A.size();
531
532 // Check arguments
533 if (n > size)
534 throw std::domain_error("New size is larger than current size");
535
536 return std::make_pair(
537 vector_t(A.data(), mapping_t(extents_t(n))),
538 vector_t(A.data() + n, mapping_t(extents_t(size - n))));
539}
540template <class ET,
541 class Exts,
542 class AP,
543 std::enable_if_t<Exts::rank() == 2, int> = 0>
544auto reshape(
545 std::experimental::mdspan<ET, Exts, std::experimental::layout_stride, AP>&
546 A,
547 std::size_t n)
548{
549 using LP = std::experimental::layout_stride;
550 using idx_t = typename std::experimental::mdspan<ET, Exts, LP>::size_type;
551 using extents1_t = std::experimental::dextents<idx_t, 1>;
552 using extents2_t = std::experimental::dextents<idx_t, 2>;
553 using vector_t = std::experimental::mdspan<ET, extents1_t, LP>;
554 using matrix_t = std::experimental::mdspan<ET, extents2_t, LP>;
555 using mapping1_t = typename LP::template mapping<extents1_t>;
556 using mapping2_t = typename LP::template mapping<extents2_t>;
557
558 // constants
559 const idx_t size = A.size();
560 const idx_t s = size - n;
561 const bool is_contiguous =
562 (size <= 1) ||
563 (A.stride(0) == 1 &&
564 (A.stride(1) == A.extent(0) || A.extent(1) <= 1)) ||
565 (A.stride(1) == 1 && (A.stride(0) == A.extent(1) || A.extent(0) <= 1));
566
567 // Check arguments
568 if (n > size)
569 throw std::domain_error("New size is larger than current size");
570 if (A.stride(0) != 1 && A.stride(1) != 1)
571 throw std::domain_error(
572 "Reshaping is not available for matrices with both strides "
573 "different from 1.");
574
575 if (is_contiguous) {
576 return std::make_pair(
577 vector_t(A.data(),
578 mapping1_t(extents1_t(n), std::array<idx_t, 1>{1})),
579 matrix_t(
580 A.data() + n,
581 (A.stride(0) == 1)
582 ? mapping2_t(extents2_t(s, 1), std::array<idx_t, 2>{1, s})
583 : mapping2_t(extents2_t(1, s),
584 std::array<idx_t, 2>{s, 1})));
585 }
586 else {
587 std::array<idx_t, 2> strides{A.stride(0), A.stride(1)};
588
589 if (n == 0) {
590 return std::make_pair(
591 vector_t(A.data(),
592 mapping1_t(extents1_t(0), std::array<idx_t, 1>{1})),
593 matrix_t(A.data(), mapping2_t(A.extents(), strides)));
594 }
595 else if (n == A.extent(0)) {
596 return std::make_pair(
597 vector_t(A.data(),
598 mapping1_t(extents1_t(n),
599 std::array<idx_t, 1>{A.stride(0)})),
600 matrix_t(A.data() + A.stride(1),
601 mapping2_t(extents2_t(A.extent(0), A.extent(1) - 1),
602 strides)));
603 }
604 else if (n == A.extent(1)) {
605 return std::make_pair(
606 vector_t(A.data(),
607 mapping1_t(extents1_t(n),
608 std::array<idx_t, 1>{A.stride(1)})),
609 matrix_t(A.data() + A.stride(0),
610 mapping2_t(extents2_t(A.extent(0) - 1, A.extent(1)),
611 strides)));
612 }
613 else {
614 throw std::domain_error(
615 "Cannot reshape to non-contiguous matrix if the number of rows "
616 "and "
617 "columns are different.");
618 }
619 }
620}
621template <class ET,
622 class Exts,
623 class AP,
624 std::enable_if_t<Exts::rank() == 1, int> = 0>
625auto reshape(
626 std::experimental::mdspan<ET, Exts, std::experimental::layout_stride, AP>&
627 v,
628 std::size_t n)
629{
630 using LP = std::experimental::layout_stride;
631 using idx_t = typename std::experimental::mdspan<ET, Exts, LP>::size_type;
632 using extents_t = std::experimental::dextents<idx_t, 1>;
633 using vector_t = std::experimental::mdspan<ET, extents_t, LP>;
634 using mapping_t = typename LP::template mapping<extents_t>;
635
636 // constants
637 const std::array<idx_t, 1> stride{v.stride(0)};
638
639 // Check arguments
640 if (n > v.size())
641 throw std::domain_error("New size is larger than current size");
642
643 return std::make_pair(
644 vector_t(v.data(), mapping_t(extents_t(n), stride)),
645 vector_t(v.data() + n, mapping_t(extents_t(v.size() - n), stride)));
646}
647
648#undef isSlice
649
650// -----------------------------------------------------------------------------
651// Deduce matrix and vector type from two provided ones
652
653namespace traits {
654
655 template <typename T>
656 constexpr bool cast_to_mdspan_type =
657 is_mdspan_type<T> || is_stdvector_type<T>
658#ifdef TLAPACK_EIGEN_HH
659 || is_eigen_type<T>
660#endif
661#ifdef TLAPACK_LEGACYARRAY_HH
662 || is_legacy_type<T>
663#endif
664 ;
665
666 // for two types
667 // should be specialized for every new matrix class
// NOTE(review): the embedded numbering skips lines 669, 676, 684, 700, 711,
// 718, 726 and 742; the trait name of each specialization and the alias
// defining the element type T are missing from this rendering.
//
// Two mdspan types with the same layout: the deduced matrix keeps the
// layout policy of matrixA_t, with dynamic rank-2 extents.
668 template <class matrixA_t, class matrixB_t>
670 matrixA_t,
671 matrixB_t,
672 typename std::enable_if<is_mdspan_type<matrixA_t> &&
673 is_mdspan_type<matrixB_t> &&
674 (layout<matrixA_t> == layout<matrixB_t>),
675 int>::type> {
677 using idx_t = size_type<matrixA_t>;
678 using extents_t = std::experimental::dextents<idx_t, 2>;
679
680 using type = std::experimental::
681 mdspan<T, extents_t, typename matrixA_t::layout_type>;
682 };
// Two mdspan types with different layouts, or one mdspan mixed with another
// type accepted by cast_to_mdspan_type: fall back to layout_stride, which
// can represent either storage order.
683 template <class matrixA_t, class matrixB_t>
685 matrixA_t,
686 matrixB_t,
687 typename std::enable_if<
688
689 (is_mdspan_type<matrixA_t> && is_mdspan_type<matrixB_t> &&
690 !(layout<matrixA_t> == layout<matrixB_t>))
691
692 ||
693
694 ((is_mdspan_type<matrixA_t> || is_mdspan_type<matrixB_t>)&&(
695 !is_mdspan_type<matrixA_t> ||
696 !is_mdspan_type<
697 matrixB_t>)&&cast_to_mdspan_type<matrixA_t> &&
698 cast_to_mdspan_type<matrixB_t>),
699 int>::type> {
701 using idx_t = size_type<matrixA_t>;
702 using extents_t = std::experimental::dextents<idx_t, 2>;
703
704 using type = std::experimental::
705 mdspan<T, extents_t, std::experimental::layout_stride>;
706 };
707
708 // for two types
709 // should be specialized for every new vector class
// Two mdspan types with the same layout: the deduced vector keeps the
// layout policy of matrixA_t, with dynamic rank-1 extents.
710 template <class matrixA_t, class matrixB_t>
712 matrixA_t,
713 matrixB_t,
714 typename std::enable_if<is_mdspan_type<matrixA_t> &&
715 is_mdspan_type<matrixB_t> &&
716 (layout<matrixA_t> == layout<matrixB_t>),
717 int>::type> {
719 using idx_t = size_type<matrixA_t>;
720 using extents_t = std::experimental::dextents<idx_t, 1>;
721
722 using type = std::experimental::
723 mdspan<T, extents_t, typename matrixA_t::layout_type>;
724 };
// Mixed layouts, or an mdspan mixed with another castable type: the deduced
// vector uses layout_stride.
725 template <class matrixA_t, class matrixB_t>
727 matrixA_t,
728 matrixB_t,
729 typename std::enable_if<
730
731 (is_mdspan_type<matrixA_t> && is_mdspan_type<matrixB_t> &&
732 !(layout<matrixA_t> == layout<matrixB_t>))
733
734 ||
735
736 ((is_mdspan_type<matrixA_t> || is_mdspan_type<matrixB_t>)&&(
737 !is_mdspan_type<matrixA_t> ||
738 !is_mdspan_type<
739 matrixB_t>)&&cast_to_mdspan_type<matrixA_t> &&
740 cast_to_mdspan_type<matrixB_t>),
741 int>::type> {
743 using idx_t = size_type<matrixA_t>;
744 using extents_t = std::experimental::dextents<idx_t, 1>;
745
746 using type = std::experimental::
747 mdspan<T, extents_t, std::experimental::layout_stride>;
748 };
749
// Deduction for a pair of std::vector types. These specializations are only
// compiled when neither TLAPACK_EIGEN_HH nor TLAPACK_LEGACYARRAY_HH is
// defined; presumably those headers provide their own std::vector
// deductions — confirm.
// NOTE(review): source lines 758 and 772 (the element-type alias T) are
// missing from this rendering.
750#if !defined(TLAPACK_EIGEN_HH) && !defined(TLAPACK_LEGACYARRAY_HH)
// Two std::vector types deduce a column-major (layout_left) dynamic matrix
// indexed by std::size_t.
751 template <class vecA_t, class vecB_t>
752 struct matrix_type_traits<
753 vecA_t,
754 vecB_t,
755 std::enable_if_t<traits::is_stdvector_type<vecA_t> &&
756 traits::is_stdvector_type<vecB_t>,
757 int>> {
759 using extents_t = std::experimental::dextents<std::size_t, 2>;
760
761 using type = std::experimental::
762 mdspan<T, extents_t, std::experimental::layout_left>;
763 };
764
// Two std::vector types deduce a layout_left dynamic vector indexed by
// std::size_t.
765 template <class vecA_t, class vecB_t>
766 struct vector_type_traits<
767 vecA_t,
768 vecB_t,
769 std::enable_if_t<traits::is_stdvector_type<vecA_t> &&
770 traits::is_stdvector_type<vecB_t>,
771 int>> {
773 using extents_t = std::experimental::dextents<std::size_t, 1>;
774
775 using type = std::experimental::
776 mdspan<T, extents_t, std::experimental::layout_left>;
777 };
778#endif
779
780} // namespace traits
781
782// -----------------------------------------------------------------------------
783// Cast to Legacy arrays
784
785template <class ET,
786 class Exts,
787 class LP,
788 class AP,
789 std::enable_if_t<Exts::rank() == 2 &&
790 LP::template mapping<Exts>::is_always_strided(),
791 int> = 0>
792constexpr auto legacy_matrix(
793 const std::experimental::mdspan<ET, Exts, LP, AP>& A) noexcept
794{
795 using idx_t =
796 typename std::experimental::mdspan<ET, Exts, LP, AP>::size_type;
797
798 // Here we do not use layout<std::experimental::mdspan<ET, Exts, LP, AP>>
799 // on purpose. This is because we want to allow legacy_matrix to be used
800 // with mdspan objects where the strides are defined at runtime.
801 const Layout L = (A.stride(0) == 1 && A.stride(1) >= A.extent(0))
802 ? Layout::ColMajor
803 : Layout::RowMajor;
804
805 assert((A.stride(0) == 1 && A.stride(1) >= A.extent(0)) || // col major
806 (A.stride(1) == 1 && A.stride(0) >= A.extent(1))); // row major
807
808 return legacy::Matrix<ET, idx_t>{
809 L, A.extent(0), A.extent(1), A.data(),
810 (L == Layout::ColMajor) ? A.stride(1) : A.stride(0)};
811}
812
813template <class ET,
814 class Exts,
815 class LP,
816 class AP,
817 std::enable_if_t<Exts::rank() == 1 &&
818 LP::template mapping<Exts>::is_always_strided(),
819 int> = 0>
820constexpr auto legacy_matrix(
821 const std::experimental::mdspan<ET, Exts, LP, AP>& A) noexcept
822{
823 using idx_t =
824 typename std::experimental::mdspan<ET, Exts, LP, AP>::size_type;
825 return legacy::Matrix<ET, idx_t>{Layout::ColMajor, 1, A.size(), A.data(),
826 A.stride(0)};
827}
828
829template <class ET,
830 class Exts,
831 class LP,
832 class AP,
833 std::enable_if_t<Exts::rank() == 1 &&
834 LP::template mapping<Exts>::is_always_strided(),
835 int> = 0>
836constexpr auto legacy_vector(
837 const std::experimental::mdspan<ET, Exts, LP, AP>& A) noexcept
838{
839 using idx_t =
840 typename std::experimental::mdspan<ET, Exts, LP, AP>::size_type;
841 return legacy::Vector<ET, idx_t>{A.size(), A.data(), A.stride(0)};
842}
843
844} // namespace tlapack
845
846#endif // TLAPACK_MDSPAN_HH
typename traits::size_type_trait< T, int >::type size_type
Size type of a matrix or vector.
Definition arrayTraits.hpp:228
Layout
Definition types.hpp:24
@ RowMajor
Row-major layout.
constexpr auto diag(T &A, int diagIdx=0) noexcept
Get the Diagonal of an Eigen Matrix.
Definition eigen.hpp:576
constexpr bool is_mdspan_type
True if T is a mdspan array.
Definition mdspan.hpp:38
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
Functor for data creation.
Definition arrayTraits.hpp:89
constexpr auto operator()(std::vector< T > &v, idx_t m, idx_t n=1) const
Creates an m-by-n matrix with entries of type T.
Definition arrayTraits.hpp:105
Functor for data creation with static size.
Definition arrayTraits.hpp:141
constexpr auto operator()(T *v) const
Creates an m-by-n matrix or, if n == -1, a vector of size m.
Definition arrayTraits.hpp:157
Complex type traits for the list of types Types.
Definition scalar_type_traits.hpp:145
Trait to determine the layout of a given data structure.
Definition arrayTraits.hpp:75
Matrix type deduction.
Definition arrayTraits.hpp:176
Real type traits for the list of types Types.
Definition scalar_type_traits.hpp:71
Vector type deduction.
Definition arrayTraits.hpp:203