<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
Matrix.hpp
Go to the documentation of this file.
1
3//
4// Copyright (c) 2025, University of Colorado Denver. All rights reserved.
5//
6// This file is part of <T>LAPACK.
7// <T>LAPACK is free software: you can redistribute it and/or modify it under
8// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
9
10#ifndef TLAPACK_STARPU_MATRIX_HH
11#define TLAPACK_STARPU_MATRIX_HH
12
13#include <starpu.h>
14
15#include <array>
16#include <iomanip>
17#include <memory>
18#include <ostream>
19
23
24namespace tlapack {
25namespace starpu {
26
27 namespace internal {
28
39 template <class T, bool TisConstType>
41
42 // EntryAccess for const types
43 template <class T>
44 struct EntryAccess<T, true> {
45 // abstract interface
46 virtual idx_t nrows() const noexcept = 0;
47 virtual idx_t ncols() const noexcept = 0;
48 virtual MatrixEntry<T> map_to_entry(idx_t i, idx_t j) noexcept = 0;
49
58 constexpr T operator()(idx_t i, idx_t j) const noexcept
59 {
60 auto& A = const_cast<EntryAccess<T, true>&>(*this);
61 return T(A.map_to_entry(i, j));
62 }
63
71 constexpr T operator[](idx_t i) const noexcept
72 {
73 assert((nrows() <= 1 || ncols() <= 1) &&
74 "Matrix is not a vector");
75 return (nrows() > 1) ? (*this)(i, 0) : (*this)(0, i);
76 }
77 };
78
79 // EntryAccess for non-const types
80 template <class T>
81 struct EntryAccess<T, false> : public EntryAccess<T, true> {
82 // abstract interface
83 virtual idx_t nrows() const noexcept = 0;
84 virtual idx_t ncols() const noexcept = 0;
85 virtual MatrixEntry<T> map_to_entry(idx_t i, idx_t j) noexcept = 0;
86
88 using EntryAccess<T, true>::operator();
89
91 using EntryAccess<T, true>::operator[];
92
101 constexpr MatrixEntry<T> operator()(idx_t i, idx_t j) noexcept
102 {
103 return map_to_entry(i, j);
104 }
105
113 constexpr MatrixEntry<T> operator[](idx_t i) noexcept
114 {
115 assert((nrows() <= 1 || ncols() <= 1) &&
116 "Matrix is not a vector");
117 return (nrows() > 1) ? (*this)(i, 0) : (*this)(0, i);
118 }
119 };
120
121 } // namespace internal
122
133 template <class T>
134 class Matrix : public internal::EntryAccess<T, std::is_const_v<T>> {
135 public:
136 using internal::EntryAccess<T, std::is_const_v<T>>::operator();
137 using internal::EntryAccess<T, std::is_const_v<T>>::operator[];
138
139 // ---------------------------------------------------------------------
140 // Constructors and destructor
141
151 Matrix(T* ptr, idx_t m, idx_t n, idx_t ld, idx_t mt, idx_t nt) noexcept
152 : pHandle(new starpu_data_handle_t(), [](starpu_data_handle_t* h) {
155 delete h;
156 })
157 {
158 assert(m > 0 && n > 0 && "Invalid matrix size");
159 assert(mt <= m && nt <= n && "Invalid tile size");
160
161 nx = (mt == 0) ? 1 : ((m % mt == 0) ? m / mt : m / mt + 1);
162 ny = (nt == 0) ? 1 : ((n % nt == 0) ? n / nt : n / nt + 1);
163
165 (uintptr_t)ptr, ld, m, n, sizeof(T));
166 create_grid(mt, nt);
167
169 starpu_data_get_child(*pHandle, nx - 1), ny - 1);
170 lastRows = starpu_matrix_get_nx(handleN);
171 lastCols = starpu_matrix_get_ny(handleN);
172 }
173
178 constexpr Matrix(T* ptr, idx_t m, idx_t n, idx_t mt, idx_t nt) noexcept
179 : Matrix(ptr, m, n, m, mt, nt)
180 {}
181
183 constexpr Matrix(const std::shared_ptr<starpu_data_handle_t>& pHandle,
184 idx_t ix,
185 idx_t iy,
186 idx_t nx,
187 idx_t ny,
188 idx_t row0,
189 idx_t col0,
190 idx_t lastRows,
191 idx_t lastCols) noexcept
192 : pHandle(pHandle),
193 ix(ix),
194 iy(iy),
195 nx(nx),
196 ny(ny),
197 row0(row0),
198 col0(col0),
199 lastRows(lastRows),
200 lastCols(lastCols)
201 {
202 assert(ix >= 0 && iy >= 0 && "Invalid tile position");
203 assert(nx > 0 && ny > 0 && "Invalid tile size");
204 assert(row0 >= 0 && col0 >= 0 && "Invalid tile offset");
205 assert(lastRows >= 0 && lastCols >= 0 && "Invalid tile size");
206 }
207
208 // Disable copy assignment operator
209 Matrix& operator=(const Matrix&) = delete;
210
211 // ---------------------------------------------------------------------
212 // Getters
213
215 constexpr idx_t get_nx() const noexcept { return nx; }
216
218 constexpr idx_t get_ny() const noexcept { return ny; }
219
221 constexpr idx_t nblockrows() const noexcept
222 {
223 return starpu_matrix_get_nx(starpu_data_get_child(*pHandle, 0));
224 }
225
227 constexpr idx_t nblockcols() const noexcept
228 {
231 }
232
235 Tile tile(idx_t ix, idx_t iy) noexcept
236 {
237 assert(ix >= 0 && iy >= 0 && ix < nx && iy < ny &&
238 "Invalid tile index");
239
241 *pHandle, 2, ix + this->ix, iy + this->iy);
242
243 // Collect information about the tile
244 idx_t i = 0, j = 0;
247 if (ix == 0) {
248 i = row0;
249 m -= i;
250 }
251 if (iy == 0) {
252 j = col0;
253 n -= j;
254 }
255 if (ix == nx - 1) m = lastRows;
256 if (iy == ny - 1) n = lastCols;
257
258 return Tile(tile_handle, i, j, m, n);
259 }
260
266 constexpr idx_t nrows() const noexcept override
267 {
268 const idx_t mb = nblockrows();
269 if (nx == 1) return lastRows;
270 if (nx == 2) return (mb - row0) + lastRows;
271 return (mb - row0) + (nx - 2) * mb + lastRows;
272 }
273
279 constexpr idx_t ncols() const noexcept override
280 {
281 const idx_t nb = nblockcols();
282 if (ny <= 1) return lastCols;
283 if (ny <= 2) return (nb - col0) + lastCols;
284 return (nb - col0) + (ny - 2) * nb + lastCols;
285 }
286
287 // ---------------------------------------------------------------------
288 // Submatrix creation
289
299 Matrix<T> get_tiles(idx_t ix, idx_t iy, idx_t nx, idx_t ny) noexcept
300 {
301 const auto [row0, col0, lastRows, lastCols] =
302 _get_tiles_info(ix, iy, nx, ny);
303
304 if (nx == 0) nx = 1;
305 if (ny == 0) ny = 1;
306
307 return Matrix<T>(pHandle, this->ix + ix, this->iy + iy, nx, ny,
308 row0, col0, lastRows, lastCols);
309 }
310
320 idx_t rowEnd,
321 idx_t colStart,
322 idx_t colEnd) noexcept
323 {
324 const auto [ix, iy, nx, ny, row0, col0, lastRows, lastCols] =
325 _map_to_tiles(rowStart, rowEnd, colStart, colEnd);
326
327 return Matrix<T>(pHandle, this->ix + ix, this->iy + iy, nx, ny,
328 row0, col0, lastRows, lastCols);
329 }
330
342 idx_t iy,
343 idx_t nx,
344 idx_t ny) const noexcept
345 {
346 const auto [row0, col0, lastRows, lastCols] =
347 _get_tiles_info(ix, iy, nx, ny);
348
349 if (nx == 0) nx = 1;
350 if (ny == 0) ny = 1;
351
352 return Matrix<const T>(pHandle, this->ix + ix, this->iy + iy, nx,
353 ny, row0, col0, lastRows, lastCols);
354 }
355
365 idx_t rowEnd,
366 idx_t colStart,
367 idx_t colEnd) const noexcept
368 {
369 const auto [ix, iy, nx, ny, row0, col0, lastRows, lastCols] =
370 _map_to_tiles(rowStart, rowEnd, colStart, colEnd);
371
372 return Matrix<const T>(pHandle, this->ix + ix, this->iy + iy, nx,
373 ny, row0, col0, lastRows, lastCols);
374 }
375
376 // ---------------------------------------------------------------------
377 // Display matrix in output stream
378
385 friend std::ostream& operator<<(std::ostream& out,
386 const starpu::Matrix<T>& A)
387 {
388 out << "starpu::Matrix<" << typeid(T).name()
389 << ">( nrows = " << A.nrows() << ", ncols = " << A.ncols()
390 << " )";
391 if (A.ncols() <= 10) {
392 out << std::scientific << std::setprecision(2) << "\n";
393 for (idx_t i = 0; i < A.nrows(); ++i) {
394 for (idx_t j = 0; j < A.ncols(); ++j) {
395 const T number = A(i, j);
396 if (abs(number) == -number) out << " ";
397 out << number << " ";
398 }
399 out << "\n";
400 }
401 }
402 return out;
403 }
404
405 private:
406 std::shared_ptr<starpu_data_handle_t> pHandle;
407
408 // Position in the grid
409 idx_t ix = 0;
410 idx_t iy = 0;
411 idx_t nx = 1;
412 idx_t ny = 1;
413
414 // Position in the first and last tiles of the grid
415 idx_t row0 = 0;
416 idx_t col0 = 0;
417 idx_t lastRows = 0;
418 idx_t lastCols = 0;
419
431 void create_grid(idx_t mt, idx_t nt) noexcept
432 {
433 /* Split into blocks of complete rows first */
434 const struct starpu_data_filter row_split = {
435 .filter_func = filter_rows, .nchildren = nx, .filter_arg = mt};
436
437 /* Then split rows into tiles */
438 const struct starpu_data_filter col_split = {
439 .filter_func = filter_cols, .nchildren = ny, .filter_arg = nt};
440
441 starpu_data_map_filters(*pHandle, 2, &row_split, &col_split);
442 }
443
452 MatrixEntry<T> map_to_entry(idx_t i, idx_t j) noexcept override
453 {
454 const idx_t mb = nblockrows();
455 const idx_t nb = nblockcols();
456
457 assert((i >= 0 && i < nrows()) && "Row index out of bounds");
458 assert((j >= 0 && j < ncols()) && "Column index out of bounds");
459
460 const idx_t ix = (i + row0) / mb;
461 const idx_t iy = (j + col0) / nb;
462 const idx_t row = (i + row0) % mb;
463 const idx_t col = (j + col0) % nb;
464
465 const idx_t pos[2] = {row, col};
466
467 return MatrixEntry<T>(
468 starpu_data_get_sub_data(*pHandle, 2, ix + this->ix,
469 iy + this->iy),
470 pos);
471 }
472
485 std::array<idx_t, 4> _get_tiles_info(idx_t ix,
486 idx_t iy,
487 idx_t nx,
488 idx_t ny) const noexcept
489 {
490 assert(ix >= 0 && iy >= 0 && ix + nx <= this->nx &&
491 iy + ny <= this->ny && "Invalid tile indices");
492 assert(nx >= 0 && ny >= 0 && "Invalid number of tiles");
493
494 const idx_t mb = nblockrows();
495 const idx_t nb = nblockcols();
496
497 const idx_t row0 = (ix == 0) ? this->row0 : 0;
498 const idx_t col0 = (iy == 0) ? this->col0 : 0;
499
500 idx_t lastRows;
501 if (nx == 0)
502 lastRows = 0;
503 else if (ix + nx == this->nx)
504 lastRows = this->lastRows;
505 else if (ix + nx == 1)
506 lastRows = mb - row0;
507 else
508 lastRows = mb;
509
510 idx_t lastCols;
511 if (ny == 0)
512 lastCols = 0;
513 else if (iy + ny == this->ny)
514 lastCols = this->lastCols;
515 else if (iy + ny == 1)
516 lastCols = nb - col0;
517 else
518 lastCols = nb;
519
520 return {row0, col0, lastRows, lastCols};
521 }
522
534 std::array<idx_t, 8> _map_to_tiles(idx_t rowStart,
535 idx_t rowEnd,
536 idx_t colStart,
537 idx_t colEnd) const noexcept
538 {
539 const idx_t mb = this->nblockrows();
540 const idx_t nb = this->nblockcols();
541 const idx_t nrows = rowEnd - rowStart;
542 const idx_t ncols = colEnd - colStart;
543
544 assert(rowStart >= 0 && colStart >= 0 &&
545 "Submatrix starts before the beginning of the matrix");
546 assert(rowEnd >= rowStart && colEnd >= colStart &&
547 "Submatrix has negative dimensions");
548 assert(rowEnd <= this->nrows() && colEnd <= this->ncols() &&
549 "Submatrix ends after the end of the matrix");
550
551 const idx_t ix = (rowStart + this->row0) / mb;
552 const idx_t iy = (colStart + this->col0) / nb;
553 const idx_t row0 = (rowStart + this->row0) % mb;
554 const idx_t col0 = (colStart + this->col0) % nb;
555
556 const idx_t nx = (nrows == 0) ? 1 : (row0 + nrows - 1) / mb + 1;
557 const idx_t ny = (ncols == 0) ? 1 : (col0 + ncols - 1) / nb + 1;
558
559 const idx_t lastRows =
560 (nx == 1) ? nrows : (row0 + nrows - 1) % mb + 1;
561 const idx_t lastCols =
562 (ny == 1) ? ncols : (col0 + ncols - 1) % nb + 1;
563
564 return {ix, iy, nx, ny, row0, col0, lastRows, lastCols};
565 }
566 };
567
568} // namespace starpu
569} // namespace tlapack
570
571#endif // TLAPACK_STARPU_MATRIX_HH
Class for representing a matrix in StarPU that is split into tiles.
Definition Matrix.hpp:134
Matrix(T *ptr, idx_t m, idx_t n, idx_t ld, idx_t mt, idx_t nt) noexcept
Create a matrix of size m-by-n from a pointer in main memory.
Definition Matrix.hpp:151
constexpr idx_t get_ny() const noexcept
Get number of tiles in y direction.
Definition Matrix.hpp:218
Matrix< T > map_to_tiles(idx_t rowStart, idx_t rowEnd, idx_t colStart, idx_t colEnd) noexcept
Create a submatrix from starting and ending indices.
Definition Matrix.hpp:319
constexpr Matrix(T *ptr, idx_t m, idx_t n, idx_t mt, idx_t nt) noexcept
Create a matrix of size m-by-n from contiguous data in main memory.
Definition Matrix.hpp:178
constexpr idx_t ncols() const noexcept override
Get the number of columns in the matrix.
Definition Matrix.hpp:279
constexpr idx_t get_nx() const noexcept
Get number of tiles in x direction.
Definition Matrix.hpp:215
Matrix< const T > get_const_tiles(idx_t ix, idx_t iy, idx_t nx, idx_t ny) const noexcept
Create a const submatrix from a list of tiles.
Definition Matrix.hpp:341
Matrix< T > get_tiles(idx_t ix, idx_t iy, idx_t nx, idx_t ny) noexcept
Create a submatrix from a list of tiles.
Definition Matrix.hpp:299
constexpr idx_t nblockcols() const noexcept
Get the maximum number of columns of a tile.
Definition Matrix.hpp:227
constexpr idx_t nrows() const noexcept override
Get the number of rows in the matrix.
Definition Matrix.hpp:266
Matrix< const T > map_to_const_tiles(idx_t rowStart, idx_t rowEnd, idx_t colStart, idx_t colEnd) const noexcept
Create a const submatrix from starting and ending indices.
Definition Matrix.hpp:364
Tile tile(idx_t ix, idx_t iy) noexcept
Get the data handle of a tile in the matrix or the data handle of the matrix if it is not partitioned...
Definition Matrix.hpp:235
friend std::ostream & operator<<(std::ostream &out, const starpu::Matrix< T > &A)
Display matrix in output stream.
Definition Matrix.hpp:385
constexpr idx_t nblockrows() const noexcept
Get the maximum number of rows of a tile.
Definition Matrix.hpp:221
constexpr Matrix(const std::shared_ptr< starpu_data_handle_t > &pHandle, idx_t ix, idx_t iy, idx_t nx, idx_t ny, idx_t row0, idx_t col0, idx_t lastRows, idx_t lastCols) noexcept
Create a submatrix from a handle and a grid.
Definition Matrix.hpp:183
Filters for StarPU data interfaces.
void filter_rows(void *father_interface, void *child_interface, STARPU_ATTRIBUTE_UNUSED struct starpu_data_filter *f, unsigned id, unsigned nparts) noexcept
StarPU filter to partition a matrix along the x (row) dimension.
Definition filters.hpp:105
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
Class for representing a tile of a matrix.
Definition Tile.hpp:25
constexpr MatrixEntry< T > operator()(idx_t i, idx_t j) noexcept
Returns a reference to an element of the matrix.
Definition Matrix.hpp:101
constexpr MatrixEntry< T > operator[](idx_t i) noexcept
Returns a reference to an element of the vector.
Definition Matrix.hpp:113
constexpr T operator()(idx_t i, idx_t j) const noexcept
Returns an element of the matrix.
Definition Matrix.hpp:58
constexpr T operator[](idx_t i) const noexcept
Returns an element of the vector.
Definition Matrix.hpp:71
Class for accessing the elements of a tlapack::starpu::Matrix.
Definition Matrix.hpp:40