<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
Matrix.hpp
Go to the documentation of this file.
1
3//
4// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
5//
6// This file is part of <T>LAPACK.
7// <T>LAPACK is free software: you can redistribute it and/or modify it under
8// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
9
10#ifndef TLAPACK_STARPU_MATRIX_HH
11#define TLAPACK_STARPU_MATRIX_HH
12
13#include <starpu.h>
14
15#include <iomanip>
16#include <memory>
17#include <ostream>
18
22
23namespace tlapack {
24namespace starpu {
25
26 namespace internal {
27
38 template <class T, bool TisConstType>
40
41 // EntryAccess for const types
42 template <class T>
43 struct EntryAccess<T, true> {
44 // abstract interface
45 virtual idx_t nrows() const noexcept = 0;
46 virtual idx_t ncols() const noexcept = 0;
47 virtual MatrixEntry<T> map_to_entry(idx_t i, idx_t j) noexcept = 0;
48
57 constexpr T operator()(idx_t i, idx_t j) const noexcept
58 {
59 auto& A = const_cast<EntryAccess<T, true>&>(*this);
60 return T(A.map_to_entry(i, j));
61 }
62
70 constexpr T operator[](idx_t i) const noexcept
71 {
72 assert((nrows() <= 1 || ncols() <= 1) &&
73 "Matrix is not a vector");
74 return (nrows() > 1) ? (*this)(i, 0) : (*this)(0, i);
75 }
76 };
77
78 // EntryAccess for non-const types
79 template <class T>
80 struct EntryAccess<T, false> : public EntryAccess<T, true> {
81 // abstract interface
82 virtual idx_t nrows() const noexcept = 0;
83 virtual idx_t ncols() const noexcept = 0;
84 virtual MatrixEntry<T> map_to_entry(idx_t i, idx_t j) noexcept = 0;
85
87 using EntryAccess<T, true>::operator();
88
90 using EntryAccess<T, true>::operator[];
91
100 constexpr MatrixEntry<T> operator()(idx_t i, idx_t j) noexcept
101 {
102 return map_to_entry(i, j);
103 }
104
112 constexpr MatrixEntry<T> operator[](idx_t i) noexcept
113 {
114 assert((nrows() <= 1 || ncols() <= 1) &&
115 "Matrix is not a vector");
116 return (nrows() > 1) ? (*this)(i, 0) : (*this)(0, i);
117 }
118 };
119
120 } // namespace internal
121
132 template <class T>
133 class Matrix : public internal::EntryAccess<T, std::is_const_v<T>> {
134 public:
135 using internal::EntryAccess<T, std::is_const_v<T>>::operator();
136 using internal::EntryAccess<T, std::is_const_v<T>>::operator[];
137
138 // ---------------------------------------------------------------------
139 // Constructors and destructor
140
150 Matrix(T* ptr, idx_t m, idx_t n, idx_t ld, idx_t mt, idx_t nt) noexcept
151 : pHandle(new starpu_data_handle_t(), [](starpu_data_handle_t* h) {
154 delete h;
155 })
156 {
157 assert(m > 0 && n > 0 && "Invalid matrix size");
158 assert(mt <= m && nt <= n && "Invalid tile size");
159
160 nx = (mt == 0) ? 1 : ((m % mt == 0) ? m / mt : m / mt + 1);
161 ny = (nt == 0) ? 1 : ((n % nt == 0) ? n / nt : n / nt + 1);
162
164 (uintptr_t)ptr, ld, m, n, sizeof(T));
165 create_grid(mt, nt);
166
168 starpu_data_get_child(*pHandle, nx - 1), ny - 1);
169 lastRows = starpu_matrix_get_nx(handleN);
170 lastCols = starpu_matrix_get_ny(handleN);
171 }
172
177 constexpr Matrix(T* ptr, idx_t m, idx_t n, idx_t mt, idx_t nt) noexcept
178 : Matrix(ptr, m, n, m, mt, nt)
179 {}
180
182 constexpr Matrix(const std::shared_ptr<starpu_data_handle_t>& pHandle,
183 idx_t ix,
184 idx_t iy,
185 idx_t nx,
186 idx_t ny,
187 idx_t row0,
188 idx_t col0,
189 idx_t lastRows,
190 idx_t lastCols) noexcept
191 : pHandle(pHandle),
192 ix(ix),
193 iy(iy),
194 nx(nx),
195 ny(ny),
196 row0(row0),
197 col0(col0),
198 lastRows(lastRows),
199 lastCols(lastCols)
200 {
201 assert(ix >= 0 && iy >= 0 && "Invalid tile position");
202 assert(nx > 0 && ny > 0 && "Invalid tile size");
203 assert(row0 >= 0 && col0 >= 0 && "Invalid tile offset");
204 assert(lastRows >= 0 && lastCols >= 0 && "Invalid tile size");
205 }
206
207 // Disable copy assignment operator
208 Matrix& operator=(const Matrix&) = delete;
209
210 // ---------------------------------------------------------------------
211 // Getters
212
214 constexpr idx_t get_nx() const noexcept { return nx; }
215
217 constexpr idx_t get_ny() const noexcept { return ny; }
218
220 constexpr idx_t nblockrows() const noexcept
221 {
222 return starpu_matrix_get_nx(starpu_data_get_child(*pHandle, 0));
223 }
224
226 constexpr idx_t nblockcols() const noexcept
227 {
230 }
231
234 Tile tile(idx_t ix, idx_t iy) noexcept
235 {
236 assert(ix >= 0 && iy >= 0 && ix < nx && iy < ny &&
237 "Invalid tile index");
238
240 *pHandle, 2, ix + this->ix, iy + this->iy);
241
242 // Collect information about the tile
243 idx_t i = 0, j = 0;
246 if (ix == 0) {
247 i = row0;
248 m -= i;
249 }
250 if (iy == 0) {
251 j = col0;
252 n -= j;
253 }
254 if (ix == nx - 1) m = lastRows;
255 if (iy == ny - 1) n = lastCols;
256
257 return Tile(tile_handle, i, j, m, n);
258 }
259
265 constexpr idx_t nrows() const noexcept override
266 {
267 const idx_t mb = nblockrows();
268 if (nx == 1) return lastRows;
269 if (nx == 2) return (mb - row0) + lastRows;
270 return (mb - row0) + (nx - 2) * mb + lastRows;
271 }
272
278 constexpr idx_t ncols() const noexcept override
279 {
280 const idx_t nb = nblockcols();
281 if (ny <= 1) return lastCols;
282 if (ny <= 2) return (nb - col0) + lastCols;
283 return (nb - col0) + (ny - 2) * nb + lastCols;
284 }
285
286 // ---------------------------------------------------------------------
287 // Submatrix creation
288
298 Matrix<T> get_tiles(idx_t ix, idx_t iy, idx_t nx, idx_t ny) noexcept
299 {
300 const auto [row0, col0, lastRows, lastCols] =
301 _get_tiles_info(ix, iy, nx, ny);
302
303 if (nx == 0) nx = 1;
304 if (ny == 0) ny = 1;
305
306 return Matrix<T>(pHandle, this->ix + ix, this->iy + iy, nx, ny,
307 row0, col0, lastRows, lastCols);
308 }
309
319 idx_t rowEnd,
320 idx_t colStart,
321 idx_t colEnd) noexcept
322 {
323 const auto [ix, iy, nx, ny, row0, col0, lastRows, lastCols] =
324 _map_to_tiles(rowStart, rowEnd, colStart, colEnd);
325
326 return Matrix<T>(pHandle, this->ix + ix, this->iy + iy, nx, ny,
327 row0, col0, lastRows, lastCols);
328 }
329
341 idx_t iy,
342 idx_t nx,
343 idx_t ny) const noexcept
344 {
345 const auto [row0, col0, lastRows, lastCols] =
346 _get_tiles_info(ix, iy, nx, ny);
347
348 if (nx == 0) nx = 1;
349 if (ny == 0) ny = 1;
350
351 return Matrix<const T>(pHandle, this->ix + ix, this->iy + iy, nx,
352 ny, row0, col0, lastRows, lastCols);
353 }
354
364 idx_t rowEnd,
365 idx_t colStart,
366 idx_t colEnd) const noexcept
367 {
368 const auto [ix, iy, nx, ny, row0, col0, lastRows, lastCols] =
369 _map_to_tiles(rowStart, rowEnd, colStart, colEnd);
370
371 return Matrix<const T>(pHandle, this->ix + ix, this->iy + iy, nx,
372 ny, row0, col0, lastRows, lastCols);
373 }
374
375 // ---------------------------------------------------------------------
376 // Display matrix in output stream
377
384 friend std::ostream& operator<<(std::ostream& out,
385 const starpu::Matrix<T>& A)
386 {
387 out << "starpu::Matrix<" << typeid(T).name()
388 << ">( nrows = " << A.nrows() << ", ncols = " << A.ncols()
389 << " )";
390 if (A.ncols() <= 10) {
391 out << std::scientific << std::setprecision(2) << "\n";
392 for (idx_t i = 0; i < A.nrows(); ++i) {
393 for (idx_t j = 0; j < A.ncols(); ++j) {
394 const T number = A(i, j);
395 if (abs(number) == -number) out << " ";
396 out << number << " ";
397 }
398 out << "\n";
399 }
400 }
401 return out;
402 }
403
404 private:
405 std::shared_ptr<starpu_data_handle_t> pHandle;
406
407 // Position in the grid
408 idx_t ix = 0;
409 idx_t iy = 0;
410 idx_t nx = 1;
411 idx_t ny = 1;
412
413 // Position in the first and last tiles of the grid
414 idx_t row0 = 0;
415 idx_t col0 = 0;
416 idx_t lastRows = 0;
417 idx_t lastCols = 0;
418
430 void create_grid(idx_t mt, idx_t nt) noexcept
431 {
432 /* Split into blocks of complete rows first */
433 const struct starpu_data_filter row_split = {
434 .filter_func = filter_rows, .nchildren = nx, .filter_arg = mt};
435
436 /* Then split rows into tiles */
437 const struct starpu_data_filter col_split = {
438 .filter_func = filter_cols, .nchildren = ny, .filter_arg = nt};
439
440 starpu_data_map_filters(*pHandle, 2, &row_split, &col_split);
441 }
442
451 MatrixEntry<T> map_to_entry(idx_t i, idx_t j) noexcept override
452 {
453 const idx_t mb = nblockrows();
454 const idx_t nb = nblockcols();
455
456 assert((i >= 0 && i < nrows()) && "Row index out of bounds");
457 assert((j >= 0 && j < ncols()) && "Column index out of bounds");
458
459 const idx_t ix = (i + row0) / mb;
460 const idx_t iy = (j + col0) / nb;
461 const idx_t row = (i + row0) % mb;
462 const idx_t col = (j + col0) % nb;
463
464 const idx_t pos[2] = {row, col};
465
466 return MatrixEntry<T>(
467 starpu_data_get_sub_data(*pHandle, 2, ix + this->ix,
468 iy + this->iy),
469 pos);
470 }
471
484 std::array<idx_t, 4> _get_tiles_info(idx_t ix,
485 idx_t iy,
486 idx_t nx,
487 idx_t ny) const noexcept
488 {
489 assert(ix >= 0 && iy >= 0 && ix + nx <= this->nx &&
490 iy + ny <= this->ny && "Invalid tile indices");
491 assert(nx >= 0 && ny >= 0 && "Invalid number of tiles");
492
493 const idx_t mb = nblockrows();
494 const idx_t nb = nblockcols();
495
496 const idx_t row0 = (ix == 0) ? this->row0 : 0;
497 const idx_t col0 = (iy == 0) ? this->col0 : 0;
498
499 idx_t lastRows;
500 if (nx == 0)
501 lastRows = 0;
502 else if (ix + nx == this->nx)
503 lastRows = this->lastRows;
504 else if (ix + nx == 1)
505 lastRows = mb - row0;
506 else
507 lastRows = mb;
508
509 idx_t lastCols;
510 if (ny == 0)
511 lastCols = 0;
512 else if (iy + ny == this->ny)
513 lastCols = this->lastCols;
514 else if (iy + ny == 1)
515 lastCols = nb - col0;
516 else
517 lastCols = nb;
518
519 return {row0, col0, lastRows, lastCols};
520 }
521
533 std::array<idx_t, 8> _map_to_tiles(idx_t rowStart,
534 idx_t rowEnd,
535 idx_t colStart,
536 idx_t colEnd) const noexcept
537 {
538 const idx_t mb = this->nblockrows();
539 const idx_t nb = this->nblockcols();
540 const idx_t nrows = rowEnd - rowStart;
541 const idx_t ncols = colEnd - colStart;
542
543 assert(rowStart >= 0 && colStart >= 0 &&
544 "Submatrix starts before the beginning of the matrix");
545 assert(rowEnd >= rowStart && colEnd >= colStart &&
546 "Submatrix has negative dimensions");
547 assert(rowEnd <= this->nrows() && colEnd <= this->ncols() &&
548 "Submatrix ends after the end of the matrix");
549
550 const idx_t ix = (rowStart + this->row0) / mb;
551 const idx_t iy = (colStart + this->col0) / nb;
552 const idx_t row0 = (rowStart + this->row0) % mb;
553 const idx_t col0 = (colStart + this->col0) % nb;
554
555 const idx_t nx = (nrows == 0) ? 1 : (row0 + nrows - 1) / mb + 1;
556 const idx_t ny = (ncols == 0) ? 1 : (col0 + ncols - 1) / nb + 1;
557
558 const idx_t lastRows =
559 (nx == 1) ? nrows : (row0 + nrows - 1) % mb + 1;
560 const idx_t lastCols =
561 (ny == 1) ? ncols : (col0 + ncols - 1) % nb + 1;
562
563 return {ix, iy, nx, ny, row0, col0, lastRows, lastCols};
564 }
565 };
566
567} // namespace starpu
568} // namespace tlapack
569
570#endif // TLAPACK_STARPU_MATRIX_HH
Class for representing a matrix in StarPU that is split into tiles.
Definition Matrix.hpp:133
Matrix(T *ptr, idx_t m, idx_t n, idx_t ld, idx_t mt, idx_t nt) noexcept
Create a matrix of size m-by-n from a pointer in main memory.
Definition Matrix.hpp:150
constexpr idx_t get_ny() const noexcept
Get number of tiles in y direction.
Definition Matrix.hpp:217
Matrix< T > map_to_tiles(idx_t rowStart, idx_t rowEnd, idx_t colStart, idx_t colEnd) noexcept
Create a submatrix from starting and ending indices.
Definition Matrix.hpp:318
constexpr Matrix(T *ptr, idx_t m, idx_t n, idx_t mt, idx_t nt) noexcept
Create a matrix of size m-by-n from contiguous data in main memory.
Definition Matrix.hpp:177
constexpr idx_t ncols() const noexcept override
Get the number of columns in the matrix.
Definition Matrix.hpp:278
constexpr idx_t get_nx() const noexcept
Get number of tiles in x direction.
Definition Matrix.hpp:214
Matrix< const T > get_const_tiles(idx_t ix, idx_t iy, idx_t nx, idx_t ny) const noexcept
Create a const submatrix from a list of tiles.
Definition Matrix.hpp:340
Matrix< T > get_tiles(idx_t ix, idx_t iy, idx_t nx, idx_t ny) noexcept
Create a submatrix from a list of tiles.
Definition Matrix.hpp:298
constexpr idx_t nblockcols() const noexcept
Get the maximum number of columns of a tile.
Definition Matrix.hpp:226
constexpr idx_t nrows() const noexcept override
Get the number of rows in the matrix.
Definition Matrix.hpp:265
Matrix< const T > map_to_const_tiles(idx_t rowStart, idx_t rowEnd, idx_t colStart, idx_t colEnd) const noexcept
Create a const submatrix from starting and ending indices.
Definition Matrix.hpp:363
Tile tile(idx_t ix, idx_t iy) noexcept
Get the data handle of a tile in the matrix or the data handle of the matrix if it is not partitioned...
Definition Matrix.hpp:234
friend std::ostream & operator<<(std::ostream &out, const starpu::Matrix< T > &A)
Display matrix in output stream.
Definition Matrix.hpp:384
constexpr idx_t nblockrows() const noexcept
Get the maximum number of rows of a tile.
Definition Matrix.hpp:220
constexpr Matrix(const std::shared_ptr< starpu_data_handle_t > &pHandle, idx_t ix, idx_t iy, idx_t nx, idx_t ny, idx_t row0, idx_t col0, idx_t lastRows, idx_t lastCols) noexcept
Create a submatrix from a handle and a grid.
Definition Matrix.hpp:182
Filters for StarPU data interfaces.
void filter_rows(void *father_interface, void *child_interface, STARPU_ATTRIBUTE_UNUSED struct starpu_data_filter *f, unsigned id, unsigned nparts) noexcept
StarPU filter to partition a matrix along the x (row) dimension.
Definition filters.hpp:105
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
Class for representing a tile of a matrix.
Definition Tile.hpp:25
constexpr MatrixEntry< T > operator()(idx_t i, idx_t j) noexcept
Returns a reference to an element of the matrix.
Definition Matrix.hpp:100
constexpr MatrixEntry< T > operator[](idx_t i) noexcept
Returns a reference to an element of the vector.
Definition Matrix.hpp:112
constexpr T operator()(idx_t i, idx_t j) const noexcept
Returns an element of the matrix.
Definition Matrix.hpp:57
constexpr T operator[](idx_t i) const noexcept
Returns an element of the vector.
Definition Matrix.hpp:70
Class for accessing the elements of a tlapack::starpu::Matrix.
Definition Matrix.hpp:39