GTestSuite: Added Tests for [C\Z]TRSM

- Added API tests for [C\Z]TRSM.
  - Added Extreme Value Test cases (EVT) for [C\Z]TRSM.
    - Tests for various combinations of INFs
       and NANs in A and B matrix are added.
  - Added Invalid input test cases (IIT).
  - Added micro kernel testing for ZTRSM
    - Added unit tests for small and native
      path kernels.
 - Added memory testing for ZTRSM
   kernels.

AMD-Internal: [CPUPL-4641]
Change-Id: I0db6b2c75b59821e1cde33532fb13400fab43412
This commit is contained in:
Shubham Sharma
2024-02-27 15:36:09 +05:30
parent 9968821ed9
commit 01b2af0af3
7 changed files with 1048 additions and 74 deletions

View File

@@ -43,7 +43,7 @@
template <typename T>
class TRSM_IIT_ERS_Test : public ::testing::Test {};
typedef ::testing::Types<float, double> TypeParam;
typedef ::testing::Types<float, double, scomplex, dcomplex> TypeParam;
TYPED_TEST_SUITE(TRSM_IIT_ERS_Test, TypeParam);
@@ -52,7 +52,7 @@ TYPED_TEST_SUITE(TRSM_IIT_ERS_Test, TypeParam);
using namespace testinghelpers::IIT;
/**
* @brief Test s/d trsm when side argument is incorrect
* @brief Test TRSM when side argument is incorrect
* when info == 1
*/
TYPED_TEST(TRSM_IIT_ERS_Test, invalid_side)
@@ -67,7 +67,7 @@ TYPED_TEST(TRSM_IIT_ERS_Test, invalid_side)
}
/**
* @brief Test s/d trsm when UPLO argument is incorrect
* @brief Test TRSM when UPLO argument is incorrect
* when info == 2
*
*/
@@ -83,7 +83,7 @@ TYPED_TEST(TRSM_IIT_ERS_Test, invalid_UPLO)
}
/**
* @brief Test s/d trsm when TRANS argument is incorrect
* @brief Test TRSM when TRANS argument is incorrect
* when info == 3
*
*/
@@ -99,7 +99,7 @@ TYPED_TEST(TRSM_IIT_ERS_Test, invalid_TRANS)
}
/**
* @brief Test s/d trsm when DIAG argument is incorrect
* @brief Test TRSM when DIAG argument is incorrect
* when info == 4
*/
TYPED_TEST(TRSM_IIT_ERS_Test, invalid_DIAG)
@@ -114,7 +114,7 @@ TYPED_TEST(TRSM_IIT_ERS_Test, invalid_DIAG)
}
/**
* @brief Test s/d trsm when m is negative
* @brief Test TRSM when m is negative
* when info == 5
*/
TYPED_TEST(TRSM_IIT_ERS_Test, invalid_m)
@@ -129,7 +129,7 @@ TYPED_TEST(TRSM_IIT_ERS_Test, invalid_m)
}
/**
* @brief Test s/d trsm when n is negative
* @brief Test TRSM when n is negative
* when info == 6
*/
TYPED_TEST(TRSM_IIT_ERS_Test, invalid_n)
@@ -144,7 +144,7 @@ TYPED_TEST(TRSM_IIT_ERS_Test, invalid_n)
}
/**
* @brief Test s/d trsm when lda is incorrect
* @brief Test TRSM when lda is incorrect
* when info == 9
*/
TYPED_TEST(TRSM_IIT_ERS_Test, invalid_lda)
@@ -159,7 +159,7 @@ TYPED_TEST(TRSM_IIT_ERS_Test, invalid_lda)
}
/**
* @brief Test s/d trsm when ldb is incorrect
* @brief Test TRSM when ldb is incorrect
* when info == 11
*/
TYPED_TEST(TRSM_IIT_ERS_Test, invalid_ldb)
@@ -185,7 +185,7 @@ TYPED_TEST(TRSM_IIT_ERS_Test, invalid_ldb)
*/
/**
* @brief Test s/d trsm when m is zero
* @brief Test TRSM when M is zero
*/
TYPED_TEST(TRSM_IIT_ERS_Test, m_eq_zero)
{
@@ -199,7 +199,7 @@ TYPED_TEST(TRSM_IIT_ERS_Test, m_eq_zero)
}
/**
* @brief Test s/d trsm when m is zero
* @brief Test TRSM when N is zero
*/
TYPED_TEST(TRSM_IIT_ERS_Test, n_eq_zero)
{

View File

@@ -0,0 +1,203 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <gtest/gtest.h>
#include "test_trsm.h"
class ctrsmEVT :
public ::testing::TestWithParam<std::tuple<char, // storage format
char, // side
char, // uplo
char, // transa
char, // diaga
gtint_t, // m
gtint_t, // n
scomplex, // alpha
gtint_t, // lda_inc
gtint_t, // ldb_inc
EVT_TYPE, // EVT test for A
EVT_TYPE>> {}; // EVT test for B
TEST_P(ctrsmEVT, NaNInfCheck)
{
using T = scomplex;
//----------------------------------------------------------
// Initialize values from the parameters passed through
// test suite instantiation (INSTANTIATE_TEST_SUITE_P).
//----------------------------------------------------------
// matrix storage format(row major, column major)
char storage = std::get<0>(GetParam());
// specifies matrix A appears left or right in
// the matrix multiplication
char side = std::get<1>(GetParam());
// specifies upper or lower triangular part of A is used
char uploa = std::get<2>(GetParam());
// denotes whether matrix a is n,c,t,h
char transa = std::get<3>(GetParam());
// denotes whether matrix a in unit or non-unit diagonal
char diaga = std::get<4>(GetParam());
// matrix size m
gtint_t m = std::get<5>(GetParam());
// matrix size n
gtint_t n = std::get<6>(GetParam());
// specifies alpha value
T alpha = std::get<7>(GetParam());
// lda, ldb, ldc increments.
// If increments are zero, then the array size matches the matrix size.
// If increments are nonnegative, the array size is bigger than the matrix size.
gtint_t lda_inc = std::get<8>(GetParam());
gtint_t ldb_inc = std::get<9>(GetParam());
EVT_TYPE a_init = std::get<10>(GetParam());
EVT_TYPE b_init = std::get<11>(GetParam());
// Set the threshold for the errors:
double thresh = std::max(m, n)*testinghelpers::getEpsilon<T>();
//----------------------------------------------------------
// Call test body using these parameters
//----------------------------------------------------------
test_trsm<T>( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh, a_init, b_init );
}
class ctrsmEVTPrint {
public:
std::string operator()(
testing::TestParamInfo<std::tuple<char, char, char, char, char, gtint_t, gtint_t, scomplex, gtint_t, gtint_t, EVT_TYPE, EVT_TYPE>> str) const {
char sfm = std::get<0>(str.param);
char side = std::get<1>(str.param);
char uploa = std::get<2>(str.param);
char transa = std::get<3>(str.param);
char diaga = std::get<4>(str.param);
gtint_t m = std::get<5>(str.param);
gtint_t n = std::get<6>(str.param);
scomplex alpha = std::get<7>(str.param);
gtint_t lda_inc = std::get<8>(str.param);
gtint_t ldb_inc = std::get<9>(str.param);
EVT_TYPE a_encode = std::get<10>(str.param);
EVT_TYPE b_encode = std::get<11>(str.param);
#ifdef TEST_BLAS
std::string str_name = "blas_";
#elif TEST_CBLAS
std::string str_name = "cblas_";
#else //#elif TEST_BLIS_TYPED
std::string str_name = "blis_";
#endif
str_name = str_name + "_stor_" + sfm;
str_name = str_name + "_side_" + side;
str_name = str_name + "_uploa_" + uploa;
str_name = str_name + "_transa_" + transa;
str_name = str_name + "_diag_" + diaga;
str_name = str_name + "_m_" + std::to_string(m);
str_name = str_name + "_n_" + std::to_string(n);
std::string alpha_str = testinghelpers::get_value_string(alpha);
str_name = str_name + "_alpha_" + alpha_str;
gtint_t mn;
testinghelpers::set_dim_with_side( side, m, n, &mn );
str_name = str_name + "_lda_" +
std::to_string(testinghelpers::get_leading_dimension( sfm, transa, mn, mn, lda_inc ));
str_name = str_name + "_ldb_" +
std::to_string(testinghelpers::get_leading_dimension( sfm, 'n', m, n, ldb_inc ));
str_name = str_name + "_a_evt_" + std::to_string(a_encode);
str_name = str_name + "_b_evt_" + std::to_string(b_encode);
return str_name;
}
};
/**
* @brief Test CTRSM for extreme values
* Code paths taken for:
* TRSV -> 1
* AVX2 Small -> 301, 324
* Native -> 1051, 1176
*/
INSTANTIATE_TEST_SUITE_P(
evt,
ctrsmEVT,
::testing::Combine(
::testing::Values('c'
#ifndef TEST_BLAS
,'r'
#endif
), // storage format
::testing::Values('l','r'), // side l:left, r:right
::testing::Values('u','l'), // uplo u:upper, l:lower
::testing::Values('n','c', 't'), // transa
::testing::Values('n','u'), // diaga , n=nonunit u=unit
::testing::Values(1, 301, 1051), // m
::testing::Values(1, 324, 1176), // n
::testing::Values(scomplex{-2.4, 2.0},
scomplex{-0.0, 2.3},
scomplex{-2.4, 0.0},
scomplex{ 0.0, 0.0}), // alpha
::testing::Values(gtint_t(0)), // increment to the leading dim of a
::testing::Values(gtint_t(0)), // increment to the leading dim of b
::testing::Values(NO_EVT, NaN, INF, NaN_INF, DIAG_NaN, DIAG_INF,
NEG_INF, NEG_NaN), // EVT test for A
::testing::Values(NO_EVT, NaN, INF, NaN_INF, NEG_INF, NEG_NaN) // EVT test for B
),
::ctrsmEVTPrint()
);
/**
* @brief Test CTRSM with differnt values of alpha
* code paths covered:
* TRSV -> 1
* TRSM_AVX2_small -> 3
* TRSM_NATIVE -> 1001
*/
INSTANTIATE_TEST_SUITE_P(
Alpha,
ctrsmEVT,
::testing::Combine(
::testing::Values('c'), // storage format
::testing::Values('l','r'), // side l:left, r:right
::testing::Values('u','l'), // uplo u:upper, l:lower
::testing::Values('n', 'c', 't'), // transa
::testing::Values('n','u'), // diaga , n=nonunit u=unit
::testing::Values(1, 3, 1001), // n
::testing::Values(1, 3, 1001), // m
::testing::Values(scomplex{NAN, -2.0},
scomplex{-2.0, NAN},
scomplex{INFINITY, 3.1f},
scomplex{NAN, -INFINITY}), // alpha
::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of a
::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of b
::testing::Values(NO_EVT), // EVT test for A
::testing::Values(NO_EVT) // EVT test for B
),
::ctrsmEVTPrint()
);

View File

@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -35,19 +35,19 @@
#include <gtest/gtest.h>
#include "test_trsm.h"
class ctrsmTest :
public ::testing::TestWithParam<std::tuple<char,
char,
char,
char,
char,
gtint_t,
gtint_t,
scomplex,
gtint_t,
gtint_t>> {};
class ctrsmAPI :
public ::testing::TestWithParam<std::tuple<char, // storage format
char, // side
char, // uplo
char, // transa
char, // diaga
gtint_t, // m
gtint_t, // n
scomplex, // alpha
gtint_t, // lda_inc
gtint_t>> {}; // ldb_inc
TEST_P(ctrsmTest, RandomData)
TEST_P(ctrsmAPI, FunctionalTest)
{
using T = scomplex;
//----------------------------------------------------------
@@ -78,7 +78,7 @@ TEST_P(ctrsmTest, RandomData)
gtint_t ldb_inc = std::get<9>(GetParam());
// Set the threshold for the errors:
double thresh = (std::max)(m, n)*testinghelpers::getEpsilon<T>();
double thresh = 1.5*(std::max)(m, n)*testinghelpers::getEpsilon<T>();
//----------------------------------------------------------
// Call test body using these parameters
@@ -86,7 +86,7 @@ TEST_P(ctrsmTest, RandomData)
test_trsm<T>( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh );
}
class ctrsmTestPrint {
class ctrsmPrint {
public:
std::string operator()(
testing::TestParamInfo<std::tuple<char, char, char, char, char, gtint_t, gtint_t, scomplex, gtint_t, gtint_t>> str) const {
@@ -101,30 +101,38 @@ public:
gtint_t lda_inc = std::get<8>(str.param);
gtint_t ldb_inc = std::get<9>(str.param);
#ifdef TEST_BLAS
std::string str_name = "ctrsm_";
std::string str_name = "blas_";
#elif TEST_CBLAS
std::string str_name = "cblas_ctrsm";
std::string str_name = "cblas_";
#else //#elif TEST_BLIS_TYPED
std::string str_name = "bli_ctrsm";
std::string str_name = "bli_";
#endif
str_name = str_name + "_" + sfm+sfm+sfm;
str_name = str_name + "_" + side + uploa + transa;
str_name = str_name + "_d" + diaga;
str_name = str_name + "_" + std::to_string(m);
str_name = str_name + "_" + std::to_string(n);
std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real))));
alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag)))));
str_name = str_name + "_a" + alpha_str;
str_name = str_name + "_" + std::to_string(lda_inc);
str_name = str_name + "_" + std::to_string(ldb_inc);
str_name = str_name + "_stor_" + sfm;
str_name = str_name + "_side_" + side;
str_name = str_name + "_uploa_" + uploa;
str_name = str_name + "_transa_" + transa;
str_name = str_name + "_diag_" + diaga;
str_name = str_name + "_m_" + std::to_string(m);
str_name = str_name + "_n_" + std::to_string(n);
std::string alpha_str = testinghelpers::get_value_string(alpha);
str_name = str_name + "_alpha_" + alpha_str;
gtint_t mn;
testinghelpers::set_dim_with_side( side, m, n, &mn );
str_name = str_name + "_lda_" +
std::to_string(testinghelpers::get_leading_dimension( sfm, transa, mn, mn, lda_inc ));
str_name = str_name + "_ldb_" +
std::to_string(testinghelpers::get_leading_dimension( sfm, 'n', m, n, ldb_inc ));
return str_name;
}
};
// Black box testing.
/**
* @brief Test CTRSM native path, which starts from size 1001 for BLAS api
* and starts from size 0 for BLIS api.
*/
INSTANTIATE_TEST_SUITE_P(
Blackbox,
ctrsmTest,
Native,
ctrsmAPI,
::testing::Combine(
::testing::Values('c'
#ifndef TEST_BLAS
@@ -135,11 +143,81 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values('u','l'), // uplo u:upper, l:lower
::testing::Values('n','c','t'), // transa
::testing::Values('n','u'), // diaga , n=nonunit u=unit
::testing::Range(gtint_t(10), gtint_t(31), 10), // m
::testing::Range(gtint_t(10), gtint_t(31), 10), // n
::testing::Values(1, 112, 1200), // m
::testing::Values(1, 154, 1317), // n
::testing::Values(scomplex{2.0,-1.0}), // alpha
::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a
::testing::Values(gtint_t(0), gtint_t(4)) // increment to the leading dim of b
::testing::Values(gtint_t(31)), // increment to the leading dim of a
::testing::Values(gtint_t(45)) // increment to the leading dim of b
),
::ctrsmTestPrint()
::ctrsmPrint()
);
/**
* @brief Test CTRSM small avx2 path all fringe cases
* Kernel size for avx2 small path is 8x3, testing in range of
* 1 to 8 ensures all finge cases are being tested.
*/
INSTANTIATE_TEST_SUITE_P(
Small_AVX2_fringe,
ctrsmAPI,
::testing::Combine(
::testing::Values('c'), // storage format
::testing::Values('l','r'), // side l:left, r:right
::testing::Values('u','l'), // uplo u:upper, l:lower
::testing::Values('n', 'c', 't'), // transa
::testing::Values('n','u'), // diaga , n=nonunit u=unit
::testing::Range(gtint_t(1), gtint_t(9), 1), // m
::testing::Range(gtint_t(1), gtint_t(9), 1), // n
::testing::Values(scomplex{2.0,-3.4}), // alpha
::testing::Values(gtint_t(58)), // increment to the leading dim of a
::testing::Values(gtint_t(32)) // increment to the leading dim of b
),
::ctrsmPrint()
);
/**
* @brief Test CTRSM small avx2 path, this code path is used in range 0 to 1000
*/
INSTANTIATE_TEST_SUITE_P(
Small_AVX2,
ctrsmAPI,
::testing::Combine(
::testing::Values('c'), // storage format
::testing::Values('l','r'), // side l:left, r:right
::testing::Values('u','l'), // uplo u:upper, l:lower
::testing::Values('n', 'c', 't'), // transa
::testing::Values('n','u'), // diaga , n=nonunit u=unit
::testing::Values(17, 1000), // m
::testing::Values(48, 1000), // n
::testing::Values(scomplex{2.0,-3.4}), // alpha
::testing::Values(gtint_t(85)), // increment to the leading dim of a
::testing::Values(gtint_t(33)) // increment to the leading dim of b
),
::ctrsmPrint()
);
/**
* @brief Test CTRSM with differnt values of alpha
* code paths covered:
* TRSV -> 1
* TRSM_AVX2_small -> 3
* TRSM_NATIVE -> 1001
*/
INSTANTIATE_TEST_SUITE_P(
Alpha,
ctrsmAPI,
::testing::Combine(
::testing::Values('c'), // storage format
::testing::Values('l','r'), // side l:left, r:right
::testing::Values('u','l'), // uplo u:upper, l:lower
::testing::Values('n', 'c', 't'), // transa
::testing::Values('n','u'), // diaga , n=nonunit u=unit
::testing::Values(1, 3, 1001), // n
::testing::Values(1, 3, 1001), // m
::testing::Values(scomplex{2.0, 0.0}, scomplex{0.0, -10.0},
scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}), // alpha
::testing::Values(gtint_t(0), gtint_t(45)), // increment to the leading dim of a
::testing::Values(gtint_t(0), gtint_t(93)) // increment to the leading dim of b
),
::ctrsmPrint()
);

View File

@@ -0,0 +1,203 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <gtest/gtest.h>
#include "test_trsm.h"
class ztrsmEVT :
public ::testing::TestWithParam<std::tuple<char, // storage format
char, // side
char, // uplo
char, // transa
char, // diaga
gtint_t, // m
gtint_t, // n
dcomplex, // alpha
gtint_t, // lda_inc
gtint_t, // ldb_inc
EVT_TYPE, // EVT test for A
EVT_TYPE>> {}; // EVT test for B
TEST_P(ztrsmEVT, NaNInfCheck)
{
using T = dcomplex;
//----------------------------------------------------------
// Initialize values from the parameters passed through
// test suite instantiation (INSTANTIATE_TEST_SUITE_P).
//----------------------------------------------------------
// matrix storage format(row major, column major)
char storage = std::get<0>(GetParam());
// specifies matrix A appears left or right in
// the matrix multiplication
char side = std::get<1>(GetParam());
// specifies upper or lower triangular part of A is used
char uploa = std::get<2>(GetParam());
// denotes whether matrix a is n,c,t,h
char transa = std::get<3>(GetParam());
// denotes whether matrix a in unit or non-unit diagonal
char diaga = std::get<4>(GetParam());
// matrix size m
gtint_t m = std::get<5>(GetParam());
// matrix size n
gtint_t n = std::get<6>(GetParam());
// specifies alpha value
T alpha = std::get<7>(GetParam());
// lda, ldb, ldc increments.
// If increments are zero, then the array size matches the matrix size.
// If increments are nonnegative, the array size is bigger than the matrix size.
gtint_t lda_inc = std::get<8>(GetParam());
gtint_t ldb_inc = std::get<9>(GetParam());
EVT_TYPE a_init = std::get<10>(GetParam());
EVT_TYPE b_init = std::get<11>(GetParam());
// Set the threshold for the errors:
double thresh = std::max(m, n)*testinghelpers::getEpsilon<T>();
//----------------------------------------------------------
// Call test body using these parameters
//----------------------------------------------------------
test_trsm<T>( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh, a_init, b_init );
}
class ztrsmEVTPrint {
public:
std::string operator()(
testing::TestParamInfo<std::tuple<char, char, char, char, char, gtint_t, gtint_t, dcomplex, gtint_t, gtint_t, EVT_TYPE, EVT_TYPE>> str) const {
char sfm = std::get<0>(str.param);
char side = std::get<1>(str.param);
char uploa = std::get<2>(str.param);
char transa = std::get<3>(str.param);
char diaga = std::get<4>(str.param);
gtint_t m = std::get<5>(str.param);
gtint_t n = std::get<6>(str.param);
dcomplex alpha = std::get<7>(str.param);
gtint_t lda_inc = std::get<8>(str.param);
gtint_t ldb_inc = std::get<9>(str.param);
EVT_TYPE a_encode = std::get<10>(str.param);
EVT_TYPE b_encode = std::get<11>(str.param);
#ifdef TEST_BLAS
std::string str_name = "blas_";
#elif TEST_CBLAS
std::string str_name = "cblas_";
#else //#elif TEST_BLIS_TYPED
std::string str_name = "blis_";
#endif
str_name = str_name + "_stor_" + sfm;
str_name = str_name + "_side_" + side;
str_name = str_name + "_uploa_" + uploa;
str_name = str_name + "_transa_" + transa;
str_name = str_name + "_diag_" + diaga;
str_name = str_name + "_m_" + std::to_string(m);
str_name = str_name + "_n_" + std::to_string(n);
std::string alpha_str = testinghelpers::get_value_string(alpha);
str_name = str_name + "_alpha_" + alpha_str;
gtint_t mn;
testinghelpers::set_dim_with_side( side, m, n, &mn );
str_name = str_name + "_lda_" +
std::to_string(testinghelpers::get_leading_dimension( sfm, transa, mn, mn, lda_inc ));
str_name = str_name + "_ldb_" +
std::to_string(testinghelpers::get_leading_dimension( sfm, 'n', m, n, ldb_inc ));
str_name = str_name + "_a_evt_" + std::to_string(a_encode);
str_name = str_name + "_b_evt_" + std::to_string(b_encode);
return str_name;
}
};
/**
* @brief Test ZTRSM for extreme values
* Code paths taken for:
* TRSV -> 1
* AVX2 Small -> 151, 82
* Native -> 503, 512
*/
INSTANTIATE_TEST_SUITE_P(
evt,
ztrsmEVT,
::testing::Combine(
::testing::Values('c'
#ifndef TEST_BLAS
,'r'
#endif
), // storage format
::testing::Values('l','r'), // side l:left, r:right
::testing::Values('u','l'), // uplo u:upper, l:lower
::testing::Values('n','c', 't'), // transa
::testing::Values('n','u'), // diaga , n=nonunit u=unit
::testing::Values(1, 151, 503), // m
::testing::Values(1, 82, 512), // n
::testing::Values(dcomplex{-2.4, 2.0},
dcomplex{-0.0, 2.3},
dcomplex{-2.4, 0.0},
dcomplex{ 0.0, 0.0}), // alpha
::testing::Values(gtint_t(0)), // increment to the leading dim of a
::testing::Values(gtint_t(0)), // increment to the leading dim of b
::testing::Values(NO_EVT, NaN, INF, NaN_INF, DIAG_NaN, DIAG_INF,
NEG_INF, NEG_NaN), // EVT test for A
::testing::Values(NO_EVT, NaN, INF, NaN_INF, NEG_INF, NEG_NaN) // EVT test for B
),
::ztrsmEVTPrint()
);
/**
* @brief Test ZTRSM with differnt values of alpha
* code paths covered:
* TRSV -> 1
* TRSM_AVX2_small -> 3
* TRSM_NATIVE -> 501
*/
INSTANTIATE_TEST_SUITE_P(
Alpha,
ztrsmEVT,
::testing::Combine(
::testing::Values('c'), // storage format
::testing::Values('l','r'), // side l:left, r:right
::testing::Values('u','l'), // uplo u:upper, l:lower
::testing::Values('n', 'c', 't'), // transa
::testing::Values('n','u'), // diaga , n=nonunit u=unit
::testing::Values(1, 3, 501), // n
::testing::Values(1, 3, 501), // m
::testing::Values(dcomplex{NAN, -2.0},
dcomplex{-2.0, NAN},
dcomplex{INFINITY, 3.1f},
dcomplex{NAN, -INFINITY}), // alpha
::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of a
::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of b
::testing::Values(NO_EVT), // EVT test for A
::testing::Values(NO_EVT) // EVT test for B
),
::ztrsmEVTPrint()
);

View File

@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -35,7 +35,7 @@
#include <gtest/gtest.h>
#include "test_trsm.h"
class ztrsmTest :
class ztrsmAPI :
public ::testing::TestWithParam<std::tuple<char,
char,
char,
@@ -47,7 +47,7 @@ class ztrsmTest :
gtint_t,
gtint_t>> {};
TEST_P(ztrsmTest, RandomData)
TEST_P(ztrsmAPI, FunctionalTest)
{
using T = dcomplex;
//----------------------------------------------------------
@@ -78,7 +78,7 @@ TEST_P(ztrsmTest, RandomData)
gtint_t ldb_inc = std::get<9>(GetParam());
// Set the threshold for the errors:
double thresh = (std::max)(m, n)*testinghelpers::getEpsilon<T>();
double thresh = 1.5*(std::max)(m, n)*testinghelpers::getEpsilon<T>();
//----------------------------------------------------------
// Call test body using these parameters
@@ -86,7 +86,7 @@ TEST_P(ztrsmTest, RandomData)
test_trsm<T>( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh );
}
class ztrsmTestPrint {
class ztrsmPrint {
public:
std::string operator()(
testing::TestParamInfo<std::tuple<char, char, char, char, char, gtint_t, gtint_t, dcomplex, gtint_t, gtint_t>> str) const {
@@ -101,30 +101,38 @@ public:
gtint_t lda_inc = std::get<8>(str.param);
gtint_t ldb_inc = std::get<9>(str.param);
#ifdef TEST_BLAS
std::string str_name = "ztrsm_";
std::string str_name = "blas_";
#elif TEST_CBLAS
std::string str_name = "cblas_ztrsm";
std::string str_name = "cblas_";
#else //#elif TEST_BLIS_TYPED
std::string str_name = "bli_ztrsm";
std::string str_name = "bli_";
#endif
str_name = str_name + "_" + sfm+sfm+sfm;
str_name = str_name + "_" + side + uploa + transa;
str_name = str_name + "_d" + diaga;
str_name = str_name + "_" + std::to_string(m);
str_name = str_name + "_" + std::to_string(n);
std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real))));
alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag)))));
str_name = str_name + "_a" + alpha_str;
str_name = str_name + "_" + std::to_string(lda_inc);
str_name = str_name + "_" + std::to_string(ldb_inc);
str_name = str_name + "_stor_" + sfm;
str_name = str_name + "_side_" + side;
str_name = str_name + "_uploa_" + uploa;
str_name = str_name + "_transa_" + transa;
str_name = str_name + "_diag_" + diaga;
str_name = str_name + "_m_" + std::to_string(m);
str_name = str_name + "_n_" + std::to_string(n);
std::string alpha_str = testinghelpers::get_value_string(alpha);
str_name = str_name + "_alpha_" + alpha_str;
gtint_t mn;
testinghelpers::set_dim_with_side( side, m, n, &mn );
str_name = str_name + "_lda_" +
std::to_string(testinghelpers::get_leading_dimension( sfm, transa, mn, mn, lda_inc ));
str_name = str_name + "_ldb_" +
std::to_string(testinghelpers::get_leading_dimension( sfm, 'n', m, n, ldb_inc ));
return str_name;
}
};
// Black box testing.
/**
* @brief Test ZTRSM native path, which starts from size 501 for BLAS api
* and starts from size 0 for BLIS api.
*/
INSTANTIATE_TEST_SUITE_P(
Blackbox,
ztrsmTest,
Native,
ztrsmAPI,
::testing::Combine(
::testing::Values('c'
#ifndef TEST_BLAS
@@ -135,11 +143,81 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values('u','l'), // uplo u:upper, l:lower
::testing::Values('n','c','t'), // transa
::testing::Values('n','u'), // diaga , n=nonunit u=unit
::testing::Range(gtint_t(10), gtint_t(11), 10), // m
::testing::Range(gtint_t(10), gtint_t(11), 10), // n
::testing::Values(dcomplex{1.0,2.0}), // alpha
::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a
::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of b
::testing::Values(1, 53, 520), // m
::testing::Values(1, 38, 511), // n
::testing::Values(dcomplex{2.0,-1.0}), // alpha
::testing::Values(gtint_t(20)), // increment to the leading dim of a
::testing::Values(gtint_t(33)) // increment to the leading dim of b
),
::ztrsmTestPrint()
::ztrsmPrint()
);
/**
* @brief Test ZTRSM small avx2 path all fringe cases
* Kernel size for avx2 small path is 4x3, testing in range of
* 1 to 4 ensures all finge cases are being tested.
*/
INSTANTIATE_TEST_SUITE_P(
Small_AVX2_fringe,
ztrsmAPI,
::testing::Combine(
::testing::Values('c'), // storage format
::testing::Values('l','r'), // side l:left, r:right
::testing::Values('u','l'), // uplo u:upper, l:lower
::testing::Values('n', 'c', 't'), // transa
::testing::Values('n','u'), // diaga , n=nonunit u=unit
::testing::Range(gtint_t(1), gtint_t(5), 1), // m
::testing::Range(gtint_t(1), gtint_t(5), 1), // n
::testing::Values(dcomplex{2.0,-3.4}), // alpha
::testing::Values(gtint_t(56)), // increment to the leading dim of a
::testing::Values(gtint_t(33)) // increment to the leading dim of b
),
::ztrsmPrint()
);
/**
* @brief Test ZTRSM small avx2 path, this code path is used in range 0 to 500
*/
INSTANTIATE_TEST_SUITE_P(
Small_AVX2,
ztrsmAPI,
::testing::Combine(
::testing::Values('c'), // storage format
::testing::Values('l','r'), // side l:left, r:right
::testing::Values('u','l'), // uplo u:upper, l:lower
::testing::Values('n', 'c', 't'), // transa
::testing::Values('n','u'), // diaga , n=nonunit u=unit
::testing::Values(17, 500), // m
::testing::Values(48, 500), // n
::testing::Values(dcomplex{2.0,-3.4}), // alpha
::testing::Values(gtint_t(54)), // increment to the leading dim of a
::testing::Values(gtint_t(37)) // increment to the leading dim of b
),
::ztrsmPrint()
);
/**
* @brief Test ZTRSM with differnt values of alpha
* code paths covered:
* TRSV -> 1
* TRSM_AVX2_small -> 3
* TRSM_NATIVE -> 501
*/
INSTANTIATE_TEST_SUITE_P(
Alpha,
ztrsmAPI,
::testing::Combine(
::testing::Values('c'), // storage format
::testing::Values('l','r'), // side l:left, r:right
::testing::Values('u','l'), // uplo u:upper, l:lower
::testing::Values('n', 'c', 't'), // transa
::testing::Values('n','u'), // diaga , n=nonunit u=unit
::testing::Values(1, 3, 501), // n
::testing::Values(1, 3, 501), // m
::testing::Values(dcomplex{2.0, 0.0}, dcomplex{0.0, -10.0},
dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}), // alpha
::testing::Values(gtint_t(0), gtint_t(65)), // increment to the leading dim of a
::testing::Values(gtint_t(0), gtint_t(23)) // increment to the leading dim of b
),
::ztrsmPrint()
);

View File

@@ -0,0 +1,133 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <gtest/gtest.h>
#include "common/testing_helpers.h"
#include "level3/ref_gemm.h"
#include "test_trsm_ukr.h"
#include "level3/trsm/test_trsm.h"
class ctrsmUkrSmall :
public ::testing::TestWithParam<std::tuple< trsm_small_ker_ft, // Function pointer type for CTRSM kernels
char, // side
char, // uploa
char, // diaga
char, // transa
gtint_t, // m
gtint_t, // n
scomplex, // alpha
gtint_t, // lda_inc
gtint_t, // ldb_inc
bool >> {}; // is_memory_test
TEST_P(ctrsmUkrSmall, AccuracyCheck)
{
using T = scomplex;
trsm_small_ker_ft ukr_fp = std::get<0>(GetParam());
char side = std::get<1>(GetParam());
char uploa = std::get<2>(GetParam());
char diaga = std::get<3>(GetParam());
char transa = std::get<4>(GetParam());
gtint_t m = std::get<5>(GetParam());
gtint_t n = std::get<6>(GetParam());
T alpha = std::get<7>(GetParam());
gtint_t lda = std::get<8>(GetParam());
gtint_t ldb = std::get<9>(GetParam());
bool is_memory_test = std::get<10>(GetParam());
double thresh = 2 * std::max(std::max(m, n), 3) * testinghelpers::getEpsilon<T>();
test_trsm_small_ukr<T, trsm_small_ker_ft>( ukr_fp, side, uploa, diaga, transa, m, n, alpha, lda, ldb, thresh, is_memory_test, BLIS_SCOMPLEX);
}
class ctrsmSmallUKRPrint {
public:
std::string operator()(
testing::TestParamInfo<std::tuple<trsm_small_ker_ft, char, char, char, char, gtint_t,
gtint_t, scomplex, gtint_t, gtint_t, bool>> str) const{
char side = std::get<1>(str.param);
char uploa = std::get<2>(str.param);
char diaga = std::get<3>(str.param);
char transa = std::get<4>(str.param);
gtint_t m = std::get<5>(str.param);
gtint_t n = std::get<6>(str.param);
scomplex alpha = std::get<7>(str.param);
gtint_t lda_inc = std::get<8>(str.param);
gtint_t ldb_inc = std::get<9>(str.param);
bool is_memory_test = std::get<10>(str.param);
std::string res =
std::string("_side_") + side
+ "_diag_" + diaga
+ "_uplo_" + uploa
+ "_trana_" + transa
+ "_alpha_" + (alpha.real > 0 ? std::to_string(int(alpha.real)) :
std::string("m") + std::to_string(int(alpha.real*-1)))
+ "pi" + (alpha.imag > 0 ? std::to_string(int(alpha.imag)) :
std::string("m") + std::to_string(int(alpha.imag*-1)));
gtint_t mn;
testinghelpers::set_dim_with_side( side, m, n, &mn );
res += "_lda_" + std::to_string( lda_inc + mn);
res += "_ldb_" + std::to_string( ldb_inc + m)
+ "_m_" + std::to_string(m)
+ "_n_" + std::to_string(n);
res += is_memory_test ? "_mem_test_enabled" : "_mem_test_disabled";
return res;
}
};
#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3)
INSTANTIATE_TEST_SUITE_P (
bli_trsm_small,
ctrsmUkrSmall,
::testing::Combine(
::testing::Values(bli_trsm_small), // ker_ptr
::testing::Values('l', 'r'), // side
::testing::Values('l', 'u'), // uplo
::testing::Values('n', 'u'), // diaga
::testing::Values('n', 'c', 't'), // transa
::testing::Range(gtint_t(1), gtint_t(9), 1), // m
::testing::Range(gtint_t(1), gtint_t(9), 1), // n
::testing::Values(scomplex{-1.4, 3.2},
scomplex{ 2.8, -0.5},
scomplex{-1.4, 0.0},
scomplex{ 0.0, -1.9}), // alpha
::testing::Values(0, 10, 194), // lda_inc
::testing::Values(0, 10, 194), // ldb_inc
::testing::Values(false, true) // is_memory_test
),
::ctrsmSmallUKRPrint()
);
#endif

View File

@@ -0,0 +1,279 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <gtest/gtest.h>
#include "common/testing_helpers.h"
#include "level3/ref_gemm.h"
#include "test_trsm_ukr.h"
#include "level3/trsm/test_trsm.h"
class ztrsmUkrNat :
public ::testing::TestWithParam<std::tuple< zgemmtrsm_ukr_ft, // Function pointer type for ZTRSM kernels
char, // storage
char, // uploa
char, // diaga
gtint_t, // m
gtint_t, // n
gtint_t, // k
dcomplex, // alpha
gtint_t, // ldc_inc
bool >> {}; // is_memory_test
class ztrsmUkrSmall :
public ::testing::TestWithParam<std::tuple< trsm_small_ker_ft, // Function pointer type for ZTRSM kernels
char, // side
char, // uploa
char, // diaga
char, // transa
gtint_t, // m
gtint_t, // n
dcomplex, // alpha
gtint_t, // lda_inc
gtint_t, // ldb_inc
bool >> {}; // is_memory_test
TEST_P(ztrsmUkrNat, AccuracyCheck)
{
using T = dcomplex;
zgemmtrsm_ukr_ft ukr_fp = std::get<0>(GetParam());
char storage = std::get<1>(GetParam());
char uploa = std::get<2>(GetParam());
char diaga = std::get<3>(GetParam());
gtint_t m = std::get<4>(GetParam());
gtint_t n = std::get<5>(GetParam());
gtint_t k = std::get<6>(GetParam());
T alpha = std::get<7>(GetParam());
gtint_t ldc = std::get<8>(GetParam());
bool is_memory_test = std::get<9>(GetParam());
double thresh = 2 * std::max(std::max(m, n), 3) * testinghelpers::getEpsilon<T>();
test_trsm_ukr<T, zgemmtrsm_ukr_ft>( ukr_fp, storage, uploa, diaga, m, n, k, alpha, ldc, thresh, is_memory_test);
}
TEST_P(ztrsmUkrSmall, AccuracyCheck)
{
using T = dcomplex;
trsm_small_ker_ft ukr_fp = std::get<0>(GetParam());
char side = std::get<1>(GetParam());
char uploa = std::get<2>(GetParam());
char diaga = std::get<3>(GetParam());
char transa = std::get<4>(GetParam());
gtint_t m = std::get<5>(GetParam());
gtint_t n = std::get<6>(GetParam());
T alpha = std::get<7>(GetParam());
gtint_t lda = std::get<8>(GetParam());
gtint_t ldb = std::get<9>(GetParam());
bool is_memory_test = std::get<10>(GetParam());
double thresh = 2 * std::max(std::max(m, n), 3) * testinghelpers::getEpsilon<T>();
test_trsm_small_ukr<T, trsm_small_ker_ft>( ukr_fp, side, uploa, diaga, transa, m, n, alpha, lda, ldb, thresh, is_memory_test, BLIS_DCOMPLEX);
}
class ztrsmUkrNatPrint {
public:
std::string operator()(
testing::TestParamInfo<std::tuple<zgemmtrsm_ukr_ft, char, char, char, gtint_t,
gtint_t, gtint_t, dcomplex, gtint_t, bool>> str) const{
char storage = std::get<1>(str.param);
char uploa = std::get<2>(str.param);
char diaga = std::get<3>(str.param);
gtint_t m = std::get<4>(str.param);
gtint_t n = std::get<5>(str.param);
gtint_t k = std::get<6>(str.param);
dcomplex alpha = std::get<7>(str.param);
gtint_t ldc = std::get<8>(str.param);
bool is_memory_test = std::get<9>(str.param);
std::string res =
std::string("stor_") + storage
+ "_diag_" + diaga
+ "_uplo_" + uploa
+ "_k_" + std::to_string(k)
+ "_alpha_" + (alpha.real > 0 ? std::to_string(int(alpha.real)) :
std::string("m") + std::to_string(int(alpha.real*-1)))
+ "pi" + (alpha.imag > 0 ? std::to_string(int(alpha.imag)) :
std::string("m") + std::to_string(int(alpha.imag*-1)));
ldc += (storage == 'r' || storage == 'R') ? n : m;
res += "_ldc_" + std::to_string(ldc);
res += is_memory_test ? "_mem_test_enabled" : "_mem_test_disabled";
return res;
}
};
class ztrsmUkrSmallPrint {
public:
std::string operator()(
testing::TestParamInfo<std::tuple<trsm_small_ker_ft, char, char, char, char, gtint_t,
gtint_t, dcomplex, gtint_t, gtint_t, bool>> str) const{
char side = std::get<1>(str.param);
char uploa = std::get<2>(str.param);
char diaga = std::get<3>(str.param);
char transa = std::get<4>(str.param);
gtint_t m = std::get<5>(str.param);
gtint_t n = std::get<6>(str.param);
dcomplex alpha = std::get<7>(str.param);
gtint_t lda_inc = std::get<8>(str.param);
gtint_t ldb_inc = std::get<9>(str.param);
bool is_memory_test = std::get<10>(str.param);
std::string res =
std::string("side_") + side
+ "_diag_" + diaga
+ "_uplo_" + uploa
+ "_trana_" + transa
+ "_alpha_" + (alpha.real > 0 ? std::to_string(int(alpha.real)) :
std::string("m") + std::to_string(int(alpha.real*-1)))
+ "pi" + (alpha.imag > 0 ? std::to_string(int(alpha.imag)) :
std::string("m") + std::to_string(int(alpha.imag*-1)));
gtint_t mn;
testinghelpers::set_dim_with_side( side, m, n, &mn );
res += "_lda_" + std::to_string( lda_inc + mn);
res += "_ldb_" + std::to_string( ldb_inc + m)
+ "_m_" + std::to_string(m)
+ "_n_" + std::to_string(n);
res += is_memory_test ? "_mem_test_enabled" : "_mem_test_disabled";
return res;
}
};
#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512)
INSTANTIATE_TEST_SUITE_P (
bli_zgemmtrsm_l_zen4_asm_4x12,
ztrsmUkrNat,
::testing::Combine(
::testing::Values(bli_zgemmtrsm_l_zen4_asm_4x12), // ker_ptr
::testing::Values('c', 'r', 'g'), // stor
::testing::Values('l'), // uplo
::testing::Values('u', 'n'), // diaga
::testing::Values(4), // m
::testing::Values(12), // n
::testing::Values(0, 1, 2, 8, 9, 10, 500, 1000), // k
::testing::Values(dcomplex{-1.4, 3.2},
dcomplex{ 2.8, -0.5},
dcomplex{-1.4, 0.0},
dcomplex{ 0.0, -1.9}), // alpha
::testing::Values(0, 9, 53), // ldc
::testing::Values(false, true) // is_memory_test
),
::ztrsmUkrNatPrint()
);
INSTANTIATE_TEST_SUITE_P (
bli_zgemmtrsm_u_zen4_asm_4x12,
ztrsmUkrNat,
::testing::Combine(
::testing::Values(bli_zgemmtrsm_u_zen4_asm_4x12), // ker_ptr
::testing::Values('c', 'r', 'g'), // stor
::testing::Values('u'), // uplo
::testing::Values('u', 'n'), // diaga
::testing::Values(4), // m
::testing::Values(12), // n
::testing::Values(0, 1, 2, 8, 9, 10, 500, 1000), // k
::testing::Values(dcomplex{-1.4, 3.2},
dcomplex{ 2.8, -0.5},
dcomplex{-1.4, 0.0},
dcomplex{ 0.0, -1.9}), // alpha
::testing::Values(0, 9, 53), // ldc
::testing::Values(false, true) // is_memory_test
),
::ztrsmUkrNatPrint()
);
#endif
#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3)
INSTANTIATE_TEST_SUITE_P (
bli_zgemmtrsm_l_zen_asm_2x6,
ztrsmUkrNat,
::testing::Combine(
::testing::Values(bli_zgemmtrsm_l_zen_asm_2x6), // ker_ptr
::testing::Values('c', 'r', 'g'), // stor
::testing::Values('l'), // uplo
::testing::Values('u', 'n'), // diaga
::testing::Values(2), // m
::testing::Values(6), // n
::testing::Values(0, 1, 2, 8, 9, 10, 500, 1000), // k
::testing::Values(dcomplex{-1.4, 3.2},
dcomplex{ 2.8, -0.5},
dcomplex{-1.4, 0.0},
dcomplex{ 0.0, -1.9}), // alpha
::testing::Values(0, 9, 53), // ldc
::testing::Values(false, true) // is_memory_test
),
::ztrsmUkrNatPrint()
);
INSTANTIATE_TEST_SUITE_P (
bli_zgemmtrsm_u_zen_asm_2x6,
ztrsmUkrNat,
::testing::Combine(
::testing::Values(bli_zgemmtrsm_u_zen_asm_2x6), // ker_ptr
::testing::Values('c', 'r', 'g'), // stor
::testing::Values('u'), // uplo
::testing::Values('u', 'n'), // diaga
::testing::Values(2), // m
::testing::Values(6), // n
::testing::Values(0, 1, 2, 8, 9, 10, 500, 1000), // k
::testing::Values(dcomplex{-1.4, 3.2},
dcomplex{ 2.8, -0.5},
dcomplex{-1.4, 0.0},
dcomplex{ 0.0, -1.9}), // alpha
::testing::Values(0, 9, 53), // ldc
::testing::Values(false, true) // is_memory_test
),
::ztrsmUkrNatPrint()
);
INSTANTIATE_TEST_SUITE_P (
bli_trsm_small,
ztrsmUkrSmall,
::testing::Combine(
::testing::Values(bli_trsm_small), // ker_ptr
::testing::Values('l', 'r'), // side
::testing::Values('l', 'u'), // uplo
::testing::Values('n', 'u'), // diaga
::testing::Values('n', 'c', 't'), // transa
::testing::Range(gtint_t(1), gtint_t(5), 1), // m
::testing::Range(gtint_t(1), gtint_t(5), 1), // n
::testing::Values(dcomplex{-1.4, 3.2},
dcomplex{ 2.8, -0.5},
dcomplex{-1.4, 0.0},
dcomplex{ 0.0, -1.9}), // alpha
::testing::Values(0, 10, 194), // lda_inc
::testing::Values(0, 10, 194), // ldb_inc
::testing::Values(false, true) // is_memory_test
),
::ztrsmUkrSmallPrint()
);
#endif