<T>LAPACK 0.1.1
C++ Template Linear Algebra PACKage
Loading...
Searching...
No Matches
testutils.hpp
Go to the documentation of this file.
1
6//
7// Copyright (c) 2021-2023, University of Colorado Denver. All rights reserved.
8//
9// This file is part of <T>LAPACK.
10// <T>LAPACK is free software: you can redistribute it and/or modify it under
11// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
12
13#ifndef TLAPACK_TESTUTILS_HH
14#define TLAPACK_TESTUTILS_HH
15
16// Definitions
17#include "testdefinitions.hpp"
18
19// Matrix market
20#include "MatrixMarket.hpp"
21
22// Plugin for debug
24
25// <T>LAPACK
27#include <tlapack/blas/gemm.hpp>
28#include <tlapack/blas/herk.hpp>
33
// Select the test front end: the real Catch2 macros by default, or a minimal
// stdin/stdout emulation when TLAPACK_BUILD_STANDALONE_TESTS is defined (the
// emulated macros are defined further down in this #else branch).
34#ifndef TLAPACK_BUILD_STANDALONE_TESTS
35 #include <catch2/catch_template_test_macros.hpp>
36 #include <catch2/generators/catch_generators.hpp>
37
// Inside a Catch2 test-case body, skipping a test is a plain `return`.
39 #define SKIP_TEST return
40#else
41 #include <iostream>
42 #include <tuple>
43
// In the standalone build the test body is the body of main() (see the
// standalone TEMPLATE_TEST_CASE), so skipping must return an int status.
45 #define SKIP_TEST return 0
46
47 // Get first argument of a variadic macro
48 #define GET_FIRST_ARG(arg1, ...) arg1
49
50 // Below, it is a solution found in
51 // https://stackoverflow.com/a/62984543/5253097
// DEPAREN(X) removes one pair of enclosing parentheses from X, if present:
//   DEPAREN((A, B)) -> ISH (A, B) -> ISH A, B -> VANISH A, B -> A, B
//   DEPAREN(A)      -> ISH A (ISH not invoked)   -> VANISH A    -> A
// ISH is not re-expanded inside its own expansion, so exactly one ISH token
// survives; ESC_ pastes VAN onto it to form VANISH, which expands to nothing.
// ESC exists only to force expansion of ISH X before the paste happens.
52 #define DEPAREN(X) ESC(ISH X)
53 #define ISH(...) ISH __VA_ARGS__
54 #define ESC(...) ESC_(__VA_ARGS__)
55 #define ESC_(...) VAN##__VA_ARGS__
56 #define VANISH
57
58namespace tlapack {
59namespace catch2 {
60
61 std::string return_scanf(const char*);
62 std::string return_scanf(std::string);
63
64 template <class T>
65 T return_scanf(T)
66 {
67 if constexpr (std::is_integral<T>::value) {
68 T arg;
69 std::cin >> arg;
70 return arg;
71 }
72 else if constexpr (std::is_enum<T>::value) {
73 char c;
74 std::cin >> c;
75 return T(c);
76 }
77
78 std::abort();
79 std::cout << "Include more cases here!\n";
80 return {};
81 }
82 template <class T, class U>
83 pair<T, U> return_scanf(pair<T, U>)
84 {
85 const T x = return_scanf(T());
86 const U y = return_scanf(U());
87 return pair<T, U>(x, y);
88 }
89 template <class... Ts>
90 std::tuple<Ts...> return_scanf(std::tuple<Ts...>)
91 {
92 std::tuple<Ts...> t;
93 constexpr size_t N = std::tuple_size<std::tuple<Ts...>>::value;
94 if constexpr (N > 0)
95 std::get<0>(t) = return_scanf(std::get<0>(std::tuple<Ts...>()));
96 if constexpr (N > 1)
97 std::get<1>(t) = return_scanf(std::get<1>(std::tuple<Ts...>()));
98 if constexpr (N > 2)
99 std::get<2>(t) = return_scanf(std::get<2>(std::tuple<Ts...>()));
100 if constexpr (N > 3)
101 std::get<3>(t) = return_scanf(std::get<3>(std::tuple<Ts...>()));
102 if constexpr (N > 4)
103 std::get<4>(t) = return_scanf(std::get<4>(std::tuple<Ts...>()));
104 if constexpr (N > 5)
105 std::get<5>(t) = return_scanf(std::get<5>(std::tuple<Ts...>()));
106 if constexpr (N > 6) {
107 std::abort();
108 std::cout << "Include more cases here!\n";
109 }
110 return t;
111 }
112} // namespace catch2
113} // namespace tlapack
114
// Standalone replacement for Catch2's TEMPLATE_TEST_CASE. Instead of
// registering one test per listed type, it aliases TestType to the first type
// argument (one level of parentheses removed by DEPAREN) and begins the
// definition of main(); the braced test body written after the macro becomes
// main's body. TITLE and TAGS are ignored.
115 #define TEMPLATE_TEST_CASE(TITLE, TAGS, ...) \
116 using TestType = DEPAREN(GET_FIRST_ARG(__VA_ARGS__)); \
117 int main(const int argc, const char* argv[])
118
// Instead of iterating over generator values, read one value of the same type
// as the first generator argument from std::cin.
119 #define GENERATE(...) \
120 tlapack::catch2::return_scanf(GET_FIRST_ARG(__VA_ARGS__))
121
// Prints the section name. The braced block that follows the macro is an
// ordinary compound statement, so it always runs.
122 #define DYNAMIC_SECTION(...) std::cout << __VA_ARGS__ << std::endl;
123
// REQUIRE and CHECK only print the condition's text and whether it held.
// Unlike Catch2, a false REQUIRE does not stop the test.
124 #define REQUIRE(cond) \
125 std::cout << #cond << ": " \
126 << (static_cast<bool>(cond) ? "true" : "false") << std::endl
127
128 #define CHECK(cond) \
129 std::cout << #cond << ": " \
130 << (static_cast<bool>(cond) ? "true" : "false") << std::endl
131
// Messages are printed immediately rather than buffered until a failure.
132 #define INFO(...) std::cout << __VA_ARGS__ << std::endl;
133 #define UNSCOPED_INFO(...) std::cout << __VA_ARGS__ << std::endl;
134#endif
135
136namespace tlapack {
137
/*
 * Orthogonality residual of an m-by-n matrix Q.
 *
 * Forms res = Q^H*Q - I when n <= m, or res = Q*Q^H - I otherwise, and (per
 * the in-code comment) computes ||res||_F. res must be square of order
 * min(m, n); both conditions are checked with tlapack_check.
 *
 * Only the upper triangle of res is written: laset and herk are both called
 * with UPPER_TRIANGLE.
 *
 * NOTE(review): this rendering omits the signature (original line 149), the
 * herk call of the n <= m branch (line 165) and the return statement
 * (line 173). The generated docs list the signature as
 * check_orthogonality(matrix_t& Q, matrix_t& res) returning
 * real_type<type_t<matrix_t>> and reference lanhe; presumably the result is
 * the Hermitian (upper) Frobenius norm of res - confirm against the source.
 */
148template <TLAPACK_MATRIX matrix_t>
150{
151 using idx_t = size_type<matrix_t>;
152 using T = type_t<matrix_t>;
153 using real_t = real_type<T>;
154
155 const idx_t m = nrows(Q);
156 const idx_t n = ncols(Q);
157
158 tlapack_check(nrows(res) == ncols(res));
159 tlapack_check(nrows(res) == min(m, n));
160
161 // res = I (upper triangle: off-diagonal 0, diagonal 1)
162 laset(UPPER_TRIANGLE, (T)0.0, (T)1.0, res);
163 if (n <= m) {
164 // res = Q'Q - I
166 }
167 else {
168 // res = QQ' - I
169 herk(UPPER_TRIANGLE, NO_TRANS, (real_t)1.0, Q, (real_t)-1.0, res);
170 }
171
172 // Compute ||res||_F
174}
175
/*
 * Convenience overload of check_orthogonality: allocates a square workspace
 * res of order min(nrows(Q), ncols(Q)), backed by the local std::vector res_,
 * and forwards to check_orthogonality(Q, res).
 *
 * NOTE(review): the signature (original line 185) and the line under
 * "// Functor" (line 191) are missing from this rendering; line 191
 * presumably declares the `new_matrix` object used below - confirm.
 */
184template <TLAPACK_MATRIX matrix_t>
186{
187 using T = type_t<matrix_t>;
188 using idx_t = size_type<matrix_t>;
189
190 // Functor
192
193 const idx_t m = min(nrows(Q), ncols(Q));
194
195 std::vector<T> res_;
196 auto res = new_matrix(res_, m, m);
197 return check_orthogonality(Q, res);
198}
199
/*
 * Similarity-transform residual: res = Q^H*A*Q - B, returning ||res||_F
 * (computed with lange(FROB_NORM, ...)).
 *
 * A, Q, B, res and work must all be square with the same number of rows;
 * this is checked with tlapack_check. B is copied into res, and work is used
 * to hold the intermediate product Q^H*A. A, Q and B are only read.
 *
 * NOTE(review): the signature lines (original 213-214) are missing from this
 * rendering; the generated docs list it as
 * check_similarity_transform(A, Q, B, res, work).
 */
212template <TLAPACK_MATRIX matrix_t>
215{
216 using T = type_t<matrix_t>;
217 using real_t = real_type<T>;
218
219 tlapack_check(nrows(A) == ncols(A));
220 tlapack_check(nrows(Q) == ncols(Q));
221 tlapack_check(nrows(B) == ncols(B));
222 tlapack_check(nrows(res) == ncols(res));
223 tlapack_check(nrows(work) == ncols(work));
224 tlapack_check(nrows(A) == nrows(Q));
225 tlapack_check(nrows(A) == nrows(B));
226 tlapack_check(nrows(A) == nrows(res));
227 tlapack_check(nrows(A) == nrows(work));
228
229 // res = Q'*A*Q - B
230 lacpy(GENERAL, B, res);
// work = Q'*A (gemm overload without beta; presumably overwrites work)
231 gemm(CONJ_TRANS, NO_TRANS, (real_t)1.0, Q, A, work);
// res = work*Q - res
232 gemm(NO_TRANS, NO_TRANS, (real_t)1.0, work, Q, (real_t)-1.0, res);
233
234 // Compute ||res||_F
235 return lange(FROB_NORM, res);
236}
237
/*
 * Convenience overload of check_similarity_transform(A, Q, B): allocates
 * n-by-n workspaces res and work (n = ncols(A)), backed by local
 * std::vectors.
 *
 * NOTE(review): this rendering omits the first signature line (original 249),
 * the line under "// Functor" (257) and the return statement (266).
 * Presumably it returns check_similarity_transform(A, Q, B, res, work) -
 * confirm against the source.
 */
248template <TLAPACK_MATRIX matrix_t>
250 matrix_t& Q,
251 matrix_t& B)
252{
253 using T = type_t<matrix_t>;
254 using idx_t = size_type<matrix_t>;
255
256 // Functor
258
259 const idx_t n = ncols(A);
260
261 std::vector<T> res_;
262 auto res = new_matrix(res_, n, n);
263 std::vector<T> work_;
264 auto work = new_matrix(work_, n, n);
265
267}
268
/*
 * Generalized similarity-transform residual: res = Q^H*A*Z - B, returning
 * ||res||_F (computed with lange(FROB_NORM, ...)).
 *
 * A, Q, Z, B, res and work must all be square with the same number of rows;
 * this is checked with tlapack_check. B is copied into res, and work is used
 * to hold the intermediate product Q^H*A. A, Q, Z and B are only read.
 *
 * NOTE(review): the first signature line (original 283) is missing from this
 * rendering; the generated docs name the function
 * check_generalized_similarity_transform.
 */
282template <TLAPACK_MATRIX matrix_t>
284 matrix_t& A,
285 matrix_t& Q,
286 matrix_t& Z,
287 matrix_t& B,
288 matrix_t& res,
289 matrix_t& work)
290{
291 using T = type_t<matrix_t>;
292 using real_t = real_type<T>;
293
294 tlapack_check(nrows(A) == ncols(A));
295 tlapack_check(nrows(Q) == ncols(Q));
296 tlapack_check(nrows(Z) == ncols(Z));
297 tlapack_check(nrows(B) == ncols(B));
298 tlapack_check(nrows(res) == ncols(res));
299 tlapack_check(nrows(work) == ncols(work));
300 tlapack_check(nrows(A) == nrows(Q));
301 tlapack_check(nrows(A) == nrows(Z));
302 tlapack_check(nrows(A) == nrows(B));
303 tlapack_check(nrows(A) == nrows(res));
304 tlapack_check(nrows(A) == nrows(work));
305
306 // res = Q'*A*Z - B
307 lacpy(GENERAL, B, res);
// work = Q'*A (gemm overload without beta; presumably overwrites work)
308 gemm(CONJ_TRANS, NO_TRANS, (real_t)1.0, Q, A, work);
// res = work*Z - res
309 gemm(NO_TRANS, NO_TRANS, (real_t)1.0, work, Z, (real_t)-1.0, res);
310
311 // Compute ||res||_F
312 return lange(FROB_NORM, res);
313}
314
/*
 * Convenience overload for the generalized residual Q^H*A*Z - B: allocates
 * n-by-n workspaces res and work (n = ncols(A)), backed by local
 * std::vectors, and forwards to the six-argument routine.
 *
 * NOTE(review): this rendering omits the first signature line (original 327)
 * and the line under "// Functor" (336).
 */
326template <TLAPACK_MATRIX matrix_t>
328 matrix_t& Q,
329 matrix_t& Z,
330 matrix_t& B)
331{
332 using T = type_t<matrix_t>;
333 using idx_t = size_type<matrix_t>;
334
335 // Functor
337
338 const idx_t n = ncols(A);
339
340 std::vector<T> res_;
341 auto res = new_matrix(res_, n, n);
342 std::vector<T> work_;
343 auto work = new_matrix(work_, n, n);
344
// NOTE(review): no visible check_similarity_transform overload takes six
// arguments (A, Q, Z, B, res, work). The six-argument routine with exactly this
// parameter list is check_generalized_similarity_transform, so this call
// likely should use that name. Being a dependent call, it only fails when
// this template is instantiated - confirm and fix.
345 return check_similarity_transform(A, Q, Z, B, res, work);
346}
347
348//
349// GDB doesn't handle templates well, so we explicitly declare some versions of
350// the functions for common template arguments (definitions live elsewhere).
// Suffixes: _r = float, _d = double, _c = std::complex<float>,
// _z = std::complex<double>; "rowmajor" variants take row-major matrices.
// NOTE(review): the parameter lines of the _r/_d row-major overloads
// (original lines 359, 361) are missing from this rendering.
351//
352void print_matrix_r(const LegacyMatrix<float, size_t, Layout::ColMajor>& A);
353void print_matrix_d(const LegacyMatrix<double, size_t, Layout::ColMajor>& A);
354void print_matrix_c(
355 const LegacyMatrix<std::complex<float>, size_t, Layout::ColMajor>& A);
356void print_matrix_z(
357 const LegacyMatrix<std::complex<double>, size_t, Layout::ColMajor>& A);
358void print_rowmajormatrix_r(
360void print_rowmajormatrix_d(
362void print_rowmajormatrix_c(
363 const LegacyMatrix<std::complex<float>, size_t, Layout::RowMajor>& A);
364void print_rowmajormatrix_z(
365 const LegacyMatrix<std::complex<double>, size_t, Layout::RowMajor>& A);
366
367//
368// GDB doesn't handle templates well, so we explicitly declare some versions of
369// the functions for common template arguments (definitions live elsewhere).
// These return a std::string, presumably a text rendering of A for
// inspection in the debugger (TODO confirm). Same suffix convention as above.
// NOTE(review): parameter lines 372, 374, 380 and 382 are missing from this
// rendering.
370//
371std::string visualize_matrix_r(
373std::string visualize_matrix_d(
375std::string visualize_matrix_c(
376 const LegacyMatrix<std::complex<float>, size_t, Layout::ColMajor>& A);
377std::string visualize_matrix_z(
378 const LegacyMatrix<std::complex<double>, size_t, Layout::ColMajor>& A);
379std::string visualize_rowmajormatrix_r(
381std::string visualize_rowmajormatrix_d(
383std::string visualize_rowmajormatrix_c(
384 const LegacyMatrix<std::complex<float>, size_t, Layout::RowMajor>& A);
385std::string visualize_rowmajormatrix_z(
386 const LegacyMatrix<std::complex<double>, size_t, Layout::RowMajor>& A);
387
388} // namespace tlapack
389
390#endif // TLAPACK_TESTUTILS_HH
MatrixMarket class and random generators.
constexpr internal::FrobNorm FROB_NORM
Frobenius norm of matrices.
Definition types.hpp:342
constexpr internal::UpperTriangle UPPER_TRIANGLE
Upper Triangle access.
Definition types.hpp:181
constexpr internal::GeneralAccess GENERAL
General access.
Definition types.hpp:175
constexpr internal::ConjTranspose CONJ_TRANS
conjugate transpose
Definition types.hpp:259
constexpr internal::NoTranspose NO_TRANS
no transpose
Definition types.hpp:255
auto lange(norm_t normType, const matrix_t &A)
Calculates the norm of a matrix.
Definition lange.hpp:38
void laset(uplo_t uplo, const type_t< matrix_t > &alpha, const type_t< matrix_t > &beta, matrix_t &A)
Initializes a matrix to diagonal and off-diagonal values.
Definition laset.hpp:38
real_type< type_t< matrix_t > > check_generalized_similarity_transform(matrix_t &A, matrix_t &Q, matrix_t &Z, matrix_t &B, matrix_t &res, matrix_t &work)
Calculates res = Q'*A*Z - B and the Frobenius norm of res.
Definition testutils.hpp:283
real_type< type_t< matrix_t > > check_orthogonality(matrix_t &Q, matrix_t &res)
Calculates res = Q'*Q - I if m >= n, or res = Q*Q' - I otherwise, where Q is m-by-n. Also computes the Frobenius norm of res.
Definition testutils.hpp:149
void lacpy(uplo_t uplo, const matrixA_t &A, matrixB_t &B)
Copies a matrix from A to B.
Definition lacpy.hpp:38
auto lanhe(norm_t normType, uplo_t uplo, const matrix_t &A)
Calculates the norm of a hermitian matrix.
Definition lanhe.hpp:43
real_type< type_t< matrix_t > > check_similarity_transform(matrix_t &A, matrix_t &Q, matrix_t &B, matrix_t &res, matrix_t &work)
Calculates res = Q'*A*Q - B and the Frobenius norm of res.
Definition testutils.hpp:213
void herk(Uplo uplo, Op trans, const alpha_t &alpha, const matrixA_t &A, const beta_t &beta, matrixC_t &C)
Hermitian rank-k update:
Definition herk.hpp:68
void gemm(Op transA, Op transB, const alpha_t &alpha, const matrixA_t &A, const matrixB_t &B, const beta_t &beta, matrixC_t &C)
General matrix-matrix multiply:
Definition gemm.hpp:61
#define tlapack_check(cond)
Throw an error if cond is false.
Definition exceptionHandling.hpp:98
Concept for matrices that can be converted to a legacy matrix.
typename traits::real_type_traits< Types..., int >::type real_type
The common real type of the list of types.
Definition scalar_type_traits.hpp:113
Definitions for the unit tests.