10#ifndef TLAPACK_MDSPAN_HH
11#define TLAPACK_MDSPAN_HH
14#include <experimental/mdspan>
28 template <
class ET,
class Exts,
class LP,
class AP>
29 std::true_type is_mdspan_type_f(
30 const std::experimental::mdspan<ET, Exts, LP, AP>*);
32 std::false_type is_mdspan_type_f(
const void*);
39 decltype(internal::is_mdspan_type_f(std::declval<T*>()))::value;
47 template <
class ET,
class Exts,
class AP>
49 std::experimental::mdspan<ET, Exts, std::experimental::layout_left, AP>,
50 std::enable_if_t<Exts::rank() == 2, int>> {
51 static constexpr Layout value = Layout::ColMajor;
53 template <
class ET,
class Exts,
class AP>
56 mdspan<ET, Exts, std::experimental::layout_right, AP>,
57 std::enable_if_t<Exts::rank() == 2, int>> {
58 static constexpr Layout value = Layout::RowMajor;
60 template <
class ET,
class Exts,
class LP,
class AP>
62 std::experimental::mdspan<ET, Exts, LP, AP>,
63 std::enable_if_t<(Exts::rank() == 1) &&
64 LP::template mapping<Exts>::is_always_strided(),
66 static constexpr Layout value = Layout::Strided;
69 template <
class ET,
class Exts,
class LP,
class AP>
71 using type = std::experimental::mdspan<real_type<ET>,
Exts,
LP,
AP>;
74 template <
class ET,
class Exts,
class LP,
class AP>
77 using type = std::experimental::mdspan<complex_type<ET>,
Exts,
LP,
AP>;
81 template <
class ET,
class Exts,
class LP,
class AP>
83 std::enable_if_t<Exts::rank() == 1, int>> {
85 typename std::experimental::mdspan<ET, Exts, LP>::size_type;
86 using extents_t = std::experimental::dextents<idx_t, 1>;
89 constexpr auto operator()(std::vector<T>&
v, idx_t n)
const
93 return std::experimental::mdspan<T, extents_t>(
v.data(), n);
98 template <
class ET,
class Exts,
class LP,
class AP,
int n>
102 std::enable_if_t<Exts::rank() == 1, int>> {
103 static_assert(n >= 0);
105 typename std::experimental::mdspan<ET, Exts, LP>::size_type;
106 using extents_t = std::experimental::extents<idx_t, n>;
108 template <
typename T>
111 return std::experimental::mdspan<T, extents_t>(
v);
116 template <
class ET,
class Exts,
class LP,
class AP>
118 std::enable_if_t<Exts::rank() == 2, int>> {
120 typename std::experimental::mdspan<ET, Exts, LP>::size_type;
121 using extents_t = std::experimental::dextents<idx_t, 2>;
124 constexpr auto operator()(std::vector<T>&
v, idx_t m, idx_t n)
const
128 return std::experimental::mdspan<T, extents_t>(
v.data(), m, n);
133 template <
class ET,
class Exts,
class LP,
class AP,
int m,
int n>
137 std::enable_if_t<Exts::rank() == 2, int>> {
138 static_assert(m >= 0 && n >= 0);
140 typename std::experimental::mdspan<ET, Exts, LP>::size_type;
141 using extents_t = std::experimental::extents<idx_t, m, n>;
143 template <
typename T>
146 return std::experimental::mdspan<T, extents_t>(
v);
155template <
class ET,
class Exts,
class LP,
class AP>
156constexpr auto size(
const std::experimental::mdspan<ET, Exts, LP, AP>&
x)
161template <
class ET,
class Exts,
class LP,
class AP>
162constexpr auto nrows(
const std::experimental::mdspan<ET, Exts, LP, AP>& x)
167template <
class ET,
class Exts,
class LP,
class AP>
168constexpr auto ncols(
const std::experimental::mdspan<ET, Exts, LP, AP>& x)
176#define isSlice(SliceSpec) \
177 std::is_convertible<SliceSpec, std::tuple<std::size_t, std::size_t>>::value
187 std::enable_if_t<isSlice(SliceSpecRow) || isSlice(SliceSpecCol),
int> = 0>
188constexpr auto slice(
const std::experimental::mdspan<ET, Exts, LP, AP>& A,
190 SliceSpecCol&& cols)
noexcept
192 return std::experimental::submdspan(A, std::forward<SliceSpecRow>(rows),
193 std::forward<SliceSpecCol>(cols));
202 std::enable_if_t<isSlice(SliceSpec),
int> = 0>
203constexpr auto rows(
const std::experimental::mdspan<ET, Exts, LP, AP>& A,
204 SliceSpec&& rows)
noexcept
206 return std::experimental::submdspan(A, std::forward<SliceSpec>(rows),
207 std::experimental::full_extent);
211template <
class ET,
class Exts,
class LP,
class AP>
212constexpr auto row(
const std::experimental::mdspan<ET, Exts, LP, AP>& A,
213 std::size_t rowIdx)
noexcept
215 return std::experimental::submdspan(A, rowIdx,
216 std::experimental::full_extent);
225 std::enable_if_t<isSlice(SliceSpec),
int> = 0>
226constexpr auto cols(
const std::experimental::mdspan<ET, Exts, LP, AP>& A,
227 SliceSpec&& cols)
noexcept
229 return std::experimental::submdspan(A, std::experimental::full_extent,
230 std::forward<SliceSpec>(cols));
234template <
class ET,
class Exts,
class LP,
class AP>
235constexpr auto col(
const std::experimental::mdspan<ET, Exts, LP, AP>& A,
236 std::size_t colIdx)
noexcept
238 return std::experimental::submdspan(A, std::experimental::full_extent,
248 std::enable_if_t<isSlice(SliceSpec) && (Exts::rank() == 1),
int> = 0>
249constexpr auto slice(
const std::experimental::mdspan<ET, Exts, LP, AP>& v,
250 SliceSpec&& rows)
noexcept
252 return std::experimental::submdspan(v, std::forward<SliceSpec>(rows));
262 LP::template mapping<Exts>::is_always_strided(),
264constexpr auto diag(
const std::experimental::mdspan<ET, Exts, LP, AP>& A,
269 using std::experimental::layout_stride;
272 typename std::experimental::mdspan<ET, Exts, LP, AP>::size_type;
273 using extents_t = std::experimental::dextents<size_type, 1>;
274 using mapping =
typename layout_stride::template mapping<extents_t>;
277 auto ptr = A.accessor().offset(A.data(), (diagIdx >= 0)
278 ? A.mapping()(0, diagIdx)
279 : A.mapping()(-diagIdx, 0));
283 ? min(A.extent(0) + diagIdx, A.extent(1)) - (size_type)diagIdx
284 : min(A.extent(0), A.extent(1) - diagIdx) + (
size_type)diagIdx),
285 array<
size_type, 1>{A.stride(0) + A.stride(1)});
286 auto acc_pol =
typename AP::offset_policy(A.accessor());
289 return std::experimental::mdspan<ET, extents_t, layout_stride,
290 typename AP::offset_policy>(
291 std::move(ptr), std::move(map), std::move(acc_pol));
295template <
class ET,
class Exts,
class AP>
296constexpr auto transpose_view(
297 const std::experimental::
298 mdspan<ET, Exts, std::experimental::layout_left, AP>& A) noexcept
301 std::experimental::mdspan<ET, Exts, std::experimental::layout_left, AP>;
302 using idx_t =
typename matrix_t::size_type;
304 std::experimental::extents<idx_t, matrix_t::static_extent(1),
305 matrix_t::static_extent(0)>;
307 using std::experimental::layout_right;
308 using mapping_t =
typename layout_right::template mapping<extents_t>;
310 mapping_t map(extents_t(A.extent(1), A.extent(0)));
311 return std::experimental::mdspan<ET, extents_t, layout_right, AP>(
312 A.data(), std::move(map));
314template <
class ET,
class Exts,
class AP>
315constexpr auto transpose_view(
316 const std::experimental::
317 mdspan<ET, Exts, std::experimental::layout_right, AP>& A) noexcept
320 std::experimental::mdspan<ET, Exts, std::experimental::layout_right,
322 using idx_t =
typename matrix_t::size_type;
324 std::experimental::extents<idx_t, matrix_t::static_extent(1),
325 matrix_t::static_extent(0)>;
327 using std::experimental::layout_left;
328 using mapping_t =
typename layout_left::template mapping<extents_t>;
330 mapping_t map(extents_t(A.extent(1), A.extent(0)));
331 return std::experimental::mdspan<ET, extents_t, layout_left, AP>(
332 A.data(), std::move(map));
334template <
class ET,
class Exts,
class AP>
335constexpr auto transpose_view(
336 const std::experimental::
337 mdspan<ET, Exts, std::experimental::layout_stride, AP>& A) noexcept
340 std::experimental::mdspan<ET, Exts, std::experimental::layout_stride,
342 using idx_t =
typename matrix_t::size_type;
344 std::experimental::extents<idx_t, matrix_t::static_extent(1),
345 matrix_t::static_extent(0)>;
347 using std::experimental::layout_stride;
348 using mapping_t =
typename layout_stride::template mapping<extents_t>;
350 mapping_t map(extents_t(A.extent(1), A.extent(0)),
351 std::array<idx_t, 2>{A.stride(1), A.stride(0)});
352 return std::experimental::mdspan<ET, extents_t, layout_stride, AP>(
353 A.data(), std::move(map));
362 std::enable_if_t<(std::is_same_v<LP, std::experimental::layout_right> ||
363 std::is_same_v<LP, std::experimental::layout_left>),
365auto reshape(std::experimental::mdspan<ET, Exts, LP, AP>& A,
369 using idx_t =
typename std::experimental::mdspan<ET, Exts, LP>::size_type;
370 using extents1_t = std::experimental::dextents<idx_t, 1>;
371 using extents2_t = std::experimental::dextents<idx_t, 2>;
372 using vector_t = std::experimental::mdspan<ET, extents1_t, LP>;
373 using matrix_t = std::experimental::mdspan<ET, extents2_t, LP>;
374 using mapping1_t =
typename LP::template mapping<extents1_t>;
375 using mapping2_t =
typename LP::template mapping<extents2_t>;
378 const idx_t size = A.size();
379 const idx_t new_size = m * n;
383 throw std::domain_error(
"New size is larger than current size");
385 return std::make_pair(
386 matrix_t(A.data(), mapping2_t(extents2_t(m, n))),
387 vector_t(A.data() + new_size, mapping1_t(extents1_t(size - new_size))));
392 std::enable_if_t<Exts::rank() == 2,
int> = 0>
394 std::experimental::mdspan<ET, Exts, std::experimental::layout_stride, AP>&
// Reshape a strided rank-2 view A into an m-by-n matrix.
// Returns a std::pair: first the m-by-n view over the start of A's storage,
// then a strided rank-2 view over storage not used by the first one.
// Throws std::domain_error when:
//   - the new size m * n is too large (presumably m * n > A.size(); the
//     guard condition is not visible here),
//   - neither of A's strides equals 1, or
//   - A is non-contiguous and both its row count and column count would
//     change.
399 using LP = std::experimental::layout_stride;
400 using idx_t =
typename std::experimental::mdspan<ET, Exts, LP>::size_type;
401 using extents_t = std::experimental::dextents<idx_t, 2>;
402 using matrix_t = std::experimental::mdspan<ET, extents_t, LP>;
403 using mapping_t =
typename LP::template mapping<extents_t>;
406 const idx_t size = A.size();
407 const idx_t new_size = m * n;
408 const bool is_contiguous =
411 (A.stride(1) == A.extent(0) || A.extent(1) <= 1)) ||
412 (A.stride(1) == 1 && (A.stride(0) == A.extent(1) || A.extent(0) <= 1));
416 throw std::domain_error(
"New size is larger than current size");
417 if (A.stride(0) != 1 && A.stride(1) != 1)
418 throw std::domain_error(
419 "Reshaping is not available for matrices with both strides "
420 "different from 1.");
// Contiguous storage: the s = size - new_size leftover entries are exposed
// as an s-by-1 view (when stride(0) == 1) or a 1-by-s view (otherwise).
423 const idx_t s = size - new_size;
424 if (A.stride(0) == 1)
425 return std::make_pair(
426 matrix_t(A.data(), mapping_t(extents_t(m, n),
427 std::array<idx_t, 2>{1, m})),
430 mapping_t(extents_t(s, 1), std::array<idx_t, 2>{1, s})));
432 return std::make_pair(
433 matrix_t(A.data(), mapping_t(extents_t(m, n),
434 std::array<idx_t, 2>{n, 1})),
437 mapping_t(extents_t(1, s), std::array<idx_t, 2>{s, 1})));
// Non-contiguous storage: A's strides are reused unchanged, so A can only
// be split along whole columns (row count kept) or whole rows (column
// count kept).
440 std::array<idx_t, 2> strides{A.stride(0), A.stride(1)};
// NOTE(review): when n == 0 but m != A.extent(0), this branch still runs
// (new_size == 0 passes the size check). The remainder then gets extents
// (m, A.extent(1)); for m > A.extent(0) that view addresses rows outside A.
// extents_t(A.extent(0), A.extent(1) - n) gives the same result when
// m == A.extent(0) and the whole of A when n == 0. It looks intended;
// confirm.
442 if (m == A.extent(0) || n == 0) {
443 return std::make_pair(
444 matrix_t(A.data(), mapping_t(extents_t(m, n), strides)),
445 matrix_t(A.data() + n * A.stride(1),
446 mapping_t(extents_t(m, A.extent(1) - n), strides)));
// NOTE(review): the symmetric case. When m == 0 but n != A.extent(1), the
// remainder gets extents (A.extent(0), n); for n > A.extent(1) it
// addresses columns outside A.
// extents_t(A.extent(0) - m, A.extent(1)) looks intended; confirm.
448 else if (n == A.extent(1) || m == 0) {
449 return std::make_pair(
450 matrix_t(A.data(), mapping_t(extents_t(m, n), strides)),
451 matrix_t(A.data() + m * A.stride(0),
452 mapping_t(extents_t(A.extent(0) - m, n), strides)));
// Neither the row count nor the column count of A is preserved.
455 throw std::domain_error(
456 "Cannot reshape to non-contiguous matrix if the number of rows "
458 "columns are different.");
465 std::enable_if_t<Exts::rank() == 1,
int> = 0>
467 std::experimental::mdspan<ET, Exts, std::experimental::layout_stride, AP>&
472 using LP = std::experimental::layout_stride;
473 using idx_t =
typename std::experimental::mdspan<ET, Exts, LP>::size_type;
474 using extents1_t = std::experimental::dextents<idx_t, 1>;
475 using extents2_t = std::experimental::dextents<idx_t, 2>;
476 using vector_t = std::experimental::mdspan<ET, extents1_t, LP>;
477 using matrix_t = std::experimental::mdspan<ET, extents2_t, LP>;
478 using mapping1_t =
typename LP::template mapping<extents1_t>;
479 using mapping2_t =
typename LP::template mapping<extents2_t>;
482 const idx_t size = v.size();
483 const idx_t new_size = m * n;
484 const idx_t s = size - new_size;
485 const idx_t stride = v.stride(0);
486 const bool is_contiguous = (size <= 1 || stride == 1);
490 throw std::domain_error(
"New size is larger than current size");
491 if (!is_contiguous && m > 1 && n > 1)
492 throw std::domain_error(
493 "New sizes are not compatible with the current vector.");
496 return std::make_pair(
498 mapping2_t(extents2_t(m, n), std::array<idx_t, 2>{1, m})),
499 vector_t(v.data() + new_size,
500 mapping1_t(extents1_t(s), std::array<idx_t, 1>{1})));
503 return std::make_pair(
505 mapping2_t(extents2_t(m, n),
506 (m <= 1) ? std::array<idx_t, 2>{1, stride}
507 : std::array<idx_t, 2>{stride, 1})),
508 vector_t(v.data() + new_size * stride,
509 mapping1_t(extents1_t(s), std::array<idx_t, 1>{stride})));
519 std::enable_if_t<(std::is_same_v<LP, std::experimental::layout_right> ||
520 std::is_same_v<LP, std::experimental::layout_left>),
522auto reshape(std::experimental::mdspan<ET, Exts, LP, AP>& A, std::size_t n)
524 using idx_t =
typename std::experimental::mdspan<ET, Exts, LP>::size_type;
525 using extents_t = std::experimental::dextents<idx_t, 1>;
526 using vector_t = std::experimental::mdspan<ET, extents_t, LP>;
527 using mapping_t =
typename LP::template mapping<extents_t>;
530 const idx_t size = A.size();
534 throw std::domain_error(
"New size is larger than current size");
536 return std::make_pair(
537 vector_t(A.data(), mapping_t(extents_t(n))),
538 vector_t(A.data() + n, mapping_t(extents_t(size - n))));
543 std::enable_if_t<Exts::rank() == 2,
int> = 0>
545 std::experimental::mdspan<ET, Exts, std::experimental::layout_stride, AP>&
549 using LP = std::experimental::layout_stride;
550 using idx_t =
typename std::experimental::mdspan<ET, Exts, LP>::size_type;
551 using extents1_t = std::experimental::dextents<idx_t, 1>;
552 using extents2_t = std::experimental::dextents<idx_t, 2>;
553 using vector_t = std::experimental::mdspan<ET, extents1_t, LP>;
554 using matrix_t = std::experimental::mdspan<ET, extents2_t, LP>;
555 using mapping1_t =
typename LP::template mapping<extents1_t>;
556 using mapping2_t =
typename LP::template mapping<extents2_t>;
559 const idx_t size = A.size();
560 const idx_t s = size - n;
561 const bool is_contiguous =
564 (A.stride(1) == A.extent(0) || A.extent(1) <= 1)) ||
565 (A.stride(1) == 1 && (A.stride(0) == A.extent(1) || A.extent(0) <= 1));
569 throw std::domain_error(
"New size is larger than current size");
570 if (A.stride(0) != 1 && A.stride(1) != 1)
571 throw std::domain_error(
572 "Reshaping is not available for matrices with both strides "
573 "different from 1.");
576 return std::make_pair(
578 mapping1_t(extents1_t(n), std::array<idx_t, 1>{1})),
582 ? mapping2_t(extents2_t(s, 1), std::array<idx_t, 2>{1, s})
583 : mapping2_t(extents2_t(1, s),
584 std::array<idx_t, 2>{s, 1})));
587 std::array<idx_t, 2> strides{A.stride(0), A.stride(1)};
590 return std::make_pair(
592 mapping1_t(extents1_t(0), std::array<idx_t, 1>{1})),
593 matrix_t(A.data(), mapping2_t(A.extents(), strides)));
595 else if (n == A.extent(0)) {
596 return std::make_pair(
598 mapping1_t(extents1_t(n),
599 std::array<idx_t, 1>{A.stride(0)})),
600 matrix_t(A.data() + A.stride(1),
601 mapping2_t(extents2_t(A.extent(0), A.extent(1) - 1),
604 else if (n == A.extent(1)) {
605 return std::make_pair(
607 mapping1_t(extents1_t(n),
608 std::array<idx_t, 1>{A.stride(1)})),
609 matrix_t(A.data() + A.stride(0),
610 mapping2_t(extents2_t(A.extent(0) - 1, A.extent(1)),
614 throw std::domain_error(
615 "Cannot reshape to non-contiguous matrix if the number of rows "
617 "columns are different.");
624 std::enable_if_t<Exts::rank() == 1,
int> = 0>
626 std::experimental::mdspan<ET, Exts, std::experimental::layout_stride, AP>&
// Split a strided rank-1 view v into two vectors: the first n entries and
// the remaining v.size() - n entries. Both views reuse v's stride.
// Throws std::domain_error("New size is larger than current size"),
// presumably when n > v.size(); the guard condition is not visible here.
630 using LP = std::experimental::layout_stride;
631 using idx_t =
typename std::experimental::mdspan<ET, Exts, LP>::size_type;
632 using extents_t = std::experimental::dextents<idx_t, 1>;
633 using vector_t = std::experimental::mdspan<ET, extents_t, LP>;
634 using mapping_t =
typename LP::template mapping<extents_t>;
637 const std::array<idx_t, 1> stride{v.stride(0)};
641 throw std::domain_error(
"New size is larger than current size");
// NOTE(review): the second view starts at v.data() + n, but with stride
// v.stride(0) the n-th entry of v is at v.data() + n * v.stride(0).
// For a stride other than 1 (and n > 0), the "remainder" therefore overlaps
// the first view and omits v's trailing entries. The strided rank-1
// overload that reshapes into a matrix uses v.data() + new_size * stride,
// so the same scaling looks intended here; confirm and fix.
643 return std::make_pair(
644 vector_t(v.data(), mapping_t(extents_t(n), stride)),
645 vector_t(v.data() + n, mapping_t(extents_t(v.size() - n), stride)));
655 template <
typename T>
656 constexpr bool cast_to_mdspan_type =
657 is_mdspan_type<T> || is_stdvector_type<T>
658#ifdef TLAPACK_EIGEN_HH
661#ifdef TLAPACK_LEGACYARRAY_HH
668 template <
class matrixA_t,
class matrixB_t>
673 is_mdspan_type<matrixB_t> &&
674 (layout<matrixA_t> == layout<matrixB_t>),
678 using extents_t = std::experimental::dextents<idx_t, 2>;
680 using type = std::experimental::
681 mdspan<T, extents_t, typename matrixA_t::layout_type>;
683 template <
class matrixA_t,
class matrixB_t>
689 (is_mdspan_type<matrixA_t> && is_mdspan_type<matrixB_t> &&
690 !(layout<matrixA_t> == layout<matrixB_t>))
694 ((is_mdspan_type<matrixA_t> || is_mdspan_type<matrixB_t>)&&(
695 !is_mdspan_type<matrixA_t> ||
697 matrixB_t>)&&cast_to_mdspan_type<matrixA_t> &&
698 cast_to_mdspan_type<matrixB_t>),
702 using extents_t = std::experimental::dextents<idx_t, 2>;
704 using type = std::experimental::
705 mdspan<T, extents_t, std::experimental::layout_stride>;
710 template <
class matrixA_t,
class matrixB_t>
715 is_mdspan_type<matrixB_t> &&
716 (layout<matrixA_t> == layout<matrixB_t>),
720 using extents_t = std::experimental::dextents<idx_t, 1>;
722 using type = std::experimental::
723 mdspan<T, extents_t, typename matrixA_t::layout_type>;
725 template <
class matrixA_t,
class matrixB_t>
731 (is_mdspan_type<matrixA_t> && is_mdspan_type<matrixB_t> &&
732 !(layout<matrixA_t> == layout<matrixB_t>))
736 ((is_mdspan_type<matrixA_t> || is_mdspan_type<matrixB_t>)&&(
737 !is_mdspan_type<matrixA_t> ||
739 matrixB_t>)&&cast_to_mdspan_type<matrixA_t> &&
740 cast_to_mdspan_type<matrixB_t>),
744 using extents_t = std::experimental::dextents<idx_t, 1>;
746 using type = std::experimental::
747 mdspan<T, extents_t, std::experimental::layout_stride>;
750#if !defined(TLAPACK_EIGEN_HH) && !defined(TLAPACK_LEGACYARRAY_HH)
751 template <
class vecA_t,
class vecB_t>
755 std::enable_if_t<traits::is_stdvector_type<vecA_t> &&
756 traits::is_stdvector_type<vecB_t>,
759 using extents_t = std::experimental::dextents<std::size_t, 2>;
761 using type = std::experimental::
762 mdspan<T, extents_t, std::experimental::layout_left>;
765 template <
class vecA_t,
class vecB_t>
766 struct vector_type_traits<
769 std::enable_if_t<traits::is_stdvector_type<vecA_t> &&
770 traits::is_stdvector_type<vecB_t>,
773 using extents_t = std::experimental::dextents<std::size_t, 1>;
775 using type = std::experimental::
776 mdspan<T, extents_t, std::experimental::layout_left>;
789 std::enable_if_t<Exts::rank() == 2 &&
790 LP::template mapping<Exts>::is_always_strided(),
792constexpr auto legacy_matrix(
793 const std::experimental::mdspan<ET, Exts, LP, AP>& A)
noexcept
796 typename std::experimental::mdspan<ET, Exts, LP, AP>::size_type;
801 const Layout L = (A.stride(0) == 1 && A.stride(1) >= A.extent(0))
805 assert((A.stride(0) == 1 && A.stride(1) >= A.extent(0)) ||
806 (A.stride(1) == 1 && A.stride(0) >= A.extent(1)));
808 return legacy::Matrix<ET, idx_t>{
809 L, A.extent(0), A.extent(1), A.data(),
810 (L == Layout::ColMajor) ? A.stride(1) : A.stride(0)};
817 std::enable_if_t<Exts::rank() == 1 &&
818 LP::template mapping<Exts>::is_always_strided(),
820constexpr auto legacy_matrix(
821 const std::experimental::mdspan<ET, Exts, LP, AP>& A)
noexcept
824 typename std::experimental::mdspan<ET, Exts, LP, AP>::size_type;
825 return legacy::Matrix<ET, idx_t>{Layout::ColMajor, 1, A.size(), A.data(),
833 std::enable_if_t<Exts::rank() == 1 &&
834 LP::template mapping<Exts>::is_always_strided(),
836constexpr auto legacy_vector(
837 const std::experimental::mdspan<ET, Exts, LP, AP>& A)
noexcept
840 typename std::experimental::mdspan<ET, Exts, LP, AP>::size_type;
841 return legacy::Vector<ET, idx_t>{A.size(), A.data(), A.stride(0)};
typename traits::size_type_trait< T, int >::type size_type
Size type of a matrix or vector.
Definition arrayTraits.hpp:228
Layout
Definition types.hpp:29
@ RowMajor
Row-major layout.
constexpr auto diag(T &A, int diagIdx=0) noexcept
Get the diagonal of an Eigen matrix.
Definition eigen.hpp:576
constexpr bool is_mdspan_type
True if T is an mdspan array.
Definition mdspan.hpp:38
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
Functor for data creation.
Definition arrayTraits.hpp:89
constexpr auto operator()(std::vector< T > &v, idx_t m, idx_t n=1) const
Creates an m-by-n matrix with entries of type T.
Definition arrayTraits.hpp:105
Functor for data creation with static size.
Definition arrayTraits.hpp:141
constexpr auto operator()(T *v) const
Creates an m-by-n matrix or, if n == -1, a vector of size m.
Definition arrayTraits.hpp:157
Complex type traits for the list of types Types.
Definition scalar_type_traits.hpp:145
Trait to determine the layout of a given data structure.
Definition arrayTraits.hpp:75
Matrix type deduction.
Definition arrayTraits.hpp:176
Real type traits for the list of types Types.
Definition scalar_type_traits.hpp:71
Vector type deduction.
Definition arrayTraits.hpp:203