GTestSuite : Designing test cases for ZGEMM

- Designed test cases for unit testing of ZGEMM compute
  kernel for handling inputs when k == 1. The design
  uses value-parameterized testing for checking accuracy,
  and verifying the mandate in case of exception values
  on the inputs/output.

- The design uses type-parameterized testing for verifying
  BLAS standard for invalid input cases, and also for early
  return scenarios.

- Added the function template set_ev_mat( ... ) as part of
  testinghelpers. This function is used as a helper for
  inducing exception values onto indices specified as
  arguments to the test_gemm( ... ) interface.

- Abstracted the function definition of getValueString( ... )
  from the NRM2 testing interface to testinghelpers(renamed
  as get_value_string( ... ) for naming consistency), in order
  to use it as a helper function across all APIs in case of
  exception value testing.

AMD-Internal: [CPUPL-3823]
Change-Id: I0fea21f9c8759bbbdc88ba0a016202753e28f2a7
This commit is contained in:
Vignesh Balasubramanian
2023-07-31 22:32:12 +05:30
parent e5e9127a68
commit 32104c400c
13 changed files with 857 additions and 38 deletions

View File

@@ -62,17 +62,17 @@ void randomgenerators(int from, int to, T* alpha, char fp);
* if fp=='f' the elements will have random float values.
*/
template<typename T>
void randomgenerators(int from, int to, gtint_t n, gtint_t incx, T* x, char fp);
void randomgenerators(int from, int to, gtint_t n, gtint_t incx, T* x, char fp = BLIS_ELEMENT_TYPE);
template<typename T>
void randomgenerators(int from, int to, char storage, gtint_t m, gtint_t n, T* a, gtint_t lda, char fp);
void randomgenerators(int from, int to, char storage, gtint_t m, gtint_t n, T* a, gtint_t lda, char fp = BLIS_ELEMENT_TYPE);
template<typename T>
void randomgenerators(int from, int to, char storage, gtint_t m, gtint_t n, T* a, char transa, gtint_t lda, char fp);
void randomgenerators(int from, int to, char storage, gtint_t m, gtint_t n, T* a, char transa, gtint_t lda, char fp = BLIS_ELEMENT_TYPE);
template<typename T>
void randomgenerators(int from, int to, char storage, char uplo, gtint_t m,
T* a, gtint_t lda, char fp );
T* a, gtint_t lda, char fp = BLIS_ELEMENT_TYPE );
} //end of namespace datagenerators
template<typename T>
@@ -92,4 +92,16 @@ std::vector<T> get_vector( gtint_t n, gtint_t incx, T value );
template<typename T>
std::vector<T> get_matrix( char storage, char trans, gtint_t m, gtint_t n, gtint_t lda, T value );
template<typename T>
void set_vector( gtint_t n, gtint_t incx, T* x, T value );
template<typename T>
void set_matrix( char storage, gtint_t m, gtint_t n, T* a, char transa, gtint_t lda, T value );
// Function template to set the exception value exval on matrix m, at indices (i, j)
// In case of transposition, this function internally swaps the indices, and thus they can be
// passed without swapping on the instantiator.
template<typename T>
void set_ev_mat( char storage, char trns, gtint_t ld, gtint_t i, gtint_t j, T exval, T* m );
} //end of namespace testinghelpers

View File

@@ -371,4 +371,13 @@ void print_vector( const char *vec, gtint_t n, T *x, gtint_t incx, const char *s
template<typename T>
void print_matrix( const char *mat, char storage, gtint_t m, gtint_t n, T *a, gtint_t ld, const char *spec );
/**
* @brief returns a string with the correct NaN/Inf for printing
*
* @tparam T float, double, scomplex, dcomplex.
* @param exval exception value for setting the string.
*/
template<typename T>
std::string get_value_string( T exval );
} //end of namespace testinghelpers

View File

@@ -441,6 +441,26 @@ std::vector<T> get_matrix( char storage, char trans, gtint_t m, gtint_t n, gtint
return a;
}
template<typename T>
void set_ev_mat( char storage, char trns, gtint_t ld, gtint_t i, gtint_t j, T exval, T* m )
{
// Setting the exception values on the indices passed as arguments
if ( storage == 'c' || storage == 'C' )
{
if ( trns == 'n' || trns == 'N' )
m[i + j*ld] = exval;
else
m[j + i*ld] = exval;
}
else
{
if ( trns == 'n' || trns == 'N' )
m[i*ld + j] = exval;
else
m[j*ld + i] = exval;
}
}
} //end of namespace testinghelpers
// Explicit template instantiations
@@ -493,3 +513,18 @@ template std::vector<float> testinghelpers::get_matrix( char, char, gtint_t, gti
template std::vector<double> testinghelpers::get_matrix( char, char, gtint_t, gtint_t, gtint_t, double );
template std::vector<scomplex> testinghelpers::get_matrix( char, char, gtint_t, gtint_t, gtint_t, scomplex );
template std::vector<dcomplex> testinghelpers::get_matrix( char, char, gtint_t, gtint_t, gtint_t, dcomplex );
template void testinghelpers::set_vector<float>( gtint_t, gtint_t, float*, float );
template void testinghelpers::set_vector<double>( gtint_t, gtint_t, double*, double );
template void testinghelpers::set_vector<scomplex>( gtint_t, gtint_t, scomplex*, scomplex );
template void testinghelpers::set_vector<dcomplex>( gtint_t, gtint_t, dcomplex*, dcomplex );
template void testinghelpers::set_matrix<float>( char, gtint_t, gtint_t, float*, char, gtint_t, float );
template void testinghelpers::set_matrix<double>( char, gtint_t, gtint_t, double*, char, gtint_t, double );
template void testinghelpers::set_matrix<scomplex>( char, gtint_t, gtint_t, scomplex*, char, gtint_t, scomplex );
template void testinghelpers::set_matrix<dcomplex>( char, gtint_t, gtint_t, dcomplex*, char, gtint_t, dcomplex );
template void testinghelpers::set_ev_mat<float>( char, char, gtint_t, gtint_t, gtint_t, float, float* );
template void testinghelpers::set_ev_mat<double>( char, char, gtint_t, gtint_t, gtint_t, double, double* );
template void testinghelpers::set_ev_mat<scomplex>( char, char, gtint_t, gtint_t, gtint_t, scomplex, scomplex* );
template void testinghelpers::set_ev_mat<dcomplex>( char, char, gtint_t, gtint_t, gtint_t, dcomplex, dcomplex* );

View File

@@ -614,4 +614,76 @@ template void print_matrix<double>( char, gtint_t, gtint_t, double *, gtint_t, c
template void print_matrix<scomplex>( char, gtint_t, gtint_t, scomplex *, gtint_t, const char * );
template void print_matrix<dcomplex>( char, gtint_t, gtint_t, dcomplex *, gtint_t, const char * );
/*
Helper function that returns a string based on the value that is passed
The return values are as follows :
If datatype is real : "nan", "inf"/"minus_inf", "value", where "value"
is the string version of the value that is passed, if it is not nan/inf/-inf.
If the datatype is complex : The string is concatenated with both the real and
imaginary components values, based on analysis done separately to each of them
(similar to real datatype).
*/
template<typename T>
std::string get_value_string(T exval)
{
std::string exval_str;
if constexpr (testinghelpers::type_info<T>::is_real)
{
if(std::isnan(exval))
exval_str = "nan";
else if(std::isinf(exval))
exval_str = (exval >= 0) ? "inf" : "minus_inf";
else
exval_str = ( exval >= 0) ? std::to_string(int(exval)) : "minus_" + std::to_string(int(std::abs(exval)));
}
else
{
if(std::isnan(exval.real))
{
exval_str = "nan";
if(std::isinf(exval.imag))
exval_str = exval_str + "pi" + ((exval.imag >= 0) ? "inf" : "minus_inf");
else
exval_str = exval_str + "pi" + ((exval.imag >= 0)? std::to_string(int(exval.imag)) : "m" + std::to_string(int(std::abs(exval.imag))));
}
else if(std::isnan(exval.imag))
{
if(std::isinf(exval.real))
exval_str = ((exval.real >= 0) ? "inf" : "minus_inf");
else
exval_str = ((exval.real >= 0)? std::to_string(int(exval.real)) : "m" + std::to_string(int(std::abs(exval.real))));
exval_str = exval_str + "pinan";
}
else if(std::isinf(exval.real))
{
exval_str = ((exval.real >= 0) ? "inf" : "minus_inf");
if(std::isnan(exval.imag))
exval_str = exval_str + "pinan";
else
exval_str = exval_str + "pi" + ((exval.imag >= 0)? std::to_string(int(exval.imag)) : "m" + std::to_string(int(std::abs(exval.imag))));
}
else if(std::isinf(exval.imag))
{
if(std::isnan(exval.real))
exval_str = "nan";
else
exval_str = ((exval.real >= 0)? std::to_string(int(exval.real)) : "m" + std::to_string(int(std::abs(exval.real))));
exval_str = exval_str + ((exval.imag >= 0) ? "inf" : "minus_inf");
}
else
{
exval_str = ((exval.real >= 0)? std::to_string(int(exval.real)) : "m" + std::to_string(int(std::abs(exval.real))));
exval_str = exval_str + "pi" + ((exval.imag >= 0)? std::to_string(int(exval.imag)) : "m" + std::to_string(int(std::abs(exval.imag))));
}
}
return exval_str;
}
template std::string testinghelpers::get_value_string( float );
template std::string testinghelpers::get_value_string( double );
template std::string testinghelpers::get_value_string( scomplex );
template std::string testinghelpers::get_value_string( dcomplex );
} //end of namespace testinghelpers

View File

@@ -0,0 +1,264 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <gtest/gtest.h>
#include "common/testing_helpers.h"
#include "gemm.h"
#include "inc/check_error.h"
#include "common/wrong_inputs_helpers.h"
template <typename T>
class Gemm_IIT_ERS_Test : public ::testing::Test {};
typedef ::testing::Types<float, double, scomplex, dcomplex> TypeParam; // The supported datatypes from BLAS calls for GEMM
TYPED_TEST_SUITE(Gemm_IIT_ERS_Test, TypeParam); // Defining individual testsuites based on the datatype support.
// Adding namespace to get default parameters(valid case) from testinghelpers/common/wrong_input_helpers.h.
using namespace testinghelpers::IIT;
#ifdef TEST_BLAS
/*
Incorrect Input Testing(IIT)
BLAS exceptions get triggered in the following cases(for GEMM):
1. When TRANSA != 'N' || TRANSA != 'T' || TRANSA != 'C' (info = 1)
2. When TRANSB != 'N' || TRANSB != 'T' || TRANSB != 'C' (info = 2)
3. When m < 0 (info = 3)
4. When n < 0 (info = 4)
5. When k < 0 (info = 5)
6. When lda < max(1, thresh) (info = 8), thresh set based on TRANSA value
7. When ldb < max(1, thresh) (info = 10), thresh set based on TRANSB value
8. When ldc < max(1, n) (info = 13)
*/
// When info == 1
TYPED_TEST(Gemm_IIT_ERS_Test, invalid_transa)
{
using T = TypeParam;
// Defining the C matrix with values for debugging purposes
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
// Copy so that we check that the elements of C are not modified.
std::vector<T> c_ref(c);
// Call BLIS Gemm with a invalid value for TRANS value for A.
gemm<T>( STORAGE, 'p', TRANS, M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC );
// Use bitwise comparison (no threshold).
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
}
// When info == 2
TYPED_TEST(Gemm_IIT_ERS_Test, invalid_transb)
{
using T = TypeParam;
// Defining the C matrix with values for debugging purposes
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
// Copy so that we check that the elements of C are not modified.
std::vector<T> c_ref(c);
// Call BLIS Gemm with a invalid value for TRANS value for B.
gemm<T>( STORAGE, TRANS, 'p', M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC );
// Use bitwise comparison (no threshold).
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
}
// When info == 3
TYPED_TEST(Gemm_IIT_ERS_Test, m_lt_zero)
{
using T = TypeParam;
// Defining the C matrix with values for debugging purposes
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
// Copy so that we check that the elements of C are not modified.
std::vector<T> c_ref(c);
// Call BLIS Gemm with a invalid value for m.
gemm<T>( STORAGE, TRANS, TRANS, -1, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC );
// Use bitwise comparison (no threshold).
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
}
// When info == 4
TYPED_TEST(Gemm_IIT_ERS_Test, n_lt_zero)
{
using T = TypeParam;
// Defining the C matrix with values for debugging purposes
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
// Copy so that we check that the elements of C are not modified.
std::vector<T> c_ref(c);
// Call BLIS Gemm with a invalid value for n.
gemm<T>( STORAGE, TRANS, TRANS, M, -1, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC );
// Use bitwise comparison (no threshold).
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
}
// When info == 5
TYPED_TEST(Gemm_IIT_ERS_Test, k_lt_zero)
{
using T = TypeParam;
// Defining the C matrix with values for debugging purposes
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
// Copy so that we check that the elements of C are not modified.
std::vector<T> c_ref(c);
// Call BLIS Gemm with a invalid value for k.
gemm<T>( STORAGE, TRANS, TRANS, M, N, -1, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC );
// Use bitwise comparison (no threshold).
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
}
// When info == 8
TYPED_TEST(Gemm_IIT_ERS_Test, invalid_lda)
{
using T = TypeParam;
// Defining the C matrix with values for debugging purposes
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
// Copy so that we check that the elements of C are not modified.
std::vector<T> c_ref(c);
// Call BLIS Gemm with a invalid value for lda.
gemm<T>( STORAGE, TRANS, TRANS, M, N, K, nullptr, nullptr, LDA - 1, nullptr, LDB, nullptr, nullptr, LDC );
// Use bitwise comparison (no threshold).
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
}
// When info == 10
TYPED_TEST(Gemm_IIT_ERS_Test, invalid_ldb)
{
using T = TypeParam;
// Defining the C matrix with values for debugging purposes
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
// Copy so that we check that the elements of C are not modified.
std::vector<T> c_ref(c);
// Call BLIS Gemm with a invalid value for ldb.
gemm<T>( STORAGE, TRANS, TRANS, M, N, K, nullptr, nullptr, LDA, nullptr, LDB - 1, nullptr, nullptr, LDC );
// Use bitwise comparison (no threshold).
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
}
// When info == 13
TYPED_TEST(Gemm_IIT_ERS_Test, invalid_ldc)
{
using T = TypeParam;
// Defining the C matrix with values for debugging purposes
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
// Copy so that we check that the elements of C are not modified.
std::vector<T> c_ref(c);
// Call BLIS Gemm with a invalid value for ldc.
gemm<T>( STORAGE, TRANS, TRANS, M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC - 1 );
// Use bitwise comparison (no threshold).
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
}
/*
Early Return Scenarios(ERS) :
The GEMM API is expected to return early in the following cases:
1. When m == 0.
2. When n == 0.
3. When (alpha == 0 or k == 0) and beta == 1.
*/
// When m is 0
TYPED_TEST(Gemm_IIT_ERS_Test, m_eq_zero)
{
using T = TypeParam;
// Defining the C matrix with values for debugging purposes
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
// Copy so that we check that the elements of C are not modified.
std::vector<T> c_ref(c);
gemm<T>( STORAGE, TRANS, TRANS, 0, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC );
// Use bitwise comparison (no threshold).
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
}
// When n is 0
TYPED_TEST(Gemm_IIT_ERS_Test, n_eq_zero)
{
using T = TypeParam;
// Defining the C matrix with values for debugging purposes
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
// Copy so that we check that the elements of C are not modified.
std::vector<T> c_ref(c);
gemm<T>( STORAGE, TRANS, TRANS, M, 0, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC );
// Use bitwise comparison (no threshold).
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
}
// When alpha is 0 and beta is 1
TYPED_TEST(Gemm_IIT_ERS_Test, alpha_zero_beta_one)
{
using T = TypeParam;
// Defining the C matrix with values for debugging purposes
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
// Copy so that we check that the elements of C are not modified.
std::vector<T> c_ref(c);
T alpha, beta;
testinghelpers::initzero<T>( alpha );
testinghelpers::initone<T>( beta );
gemm<T>( STORAGE, TRANS, TRANS, M, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC );
// Use bitwise comparison (no threshold).
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
}
// When k is 0 and beta is 1
TYPED_TEST(Gemm_IIT_ERS_Test, k_zero_beta_one)
{
using T = TypeParam;
// Defining the C matrix with values for debugging purposes
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
// Copy so that we check that the elements of C are not modified.
std::vector<T> c_ref(c);
T alpha, beta;
testinghelpers::initone<T>( alpha );
testinghelpers::initone<T>( beta );
gemm<T>( STORAGE, TRANS, TRANS, M, N, 0, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC );
// Use bitwise comparison (no threshold).
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
}
#endif

View File

@@ -76,4 +76,63 @@ void test_gemm( char storage, char trnsa, char trnsb, gtint_t m, gtint_t n,
// check component-wise error.
//----------------------------------------------------------
computediff<T>( storage, m, n, c.data(), c_ref.data(), ldc, thresh );
}
// Test body used for exception value testing, by iducing an exception value
// in the index that is passed for each of the matrices.
/*
(ai, aj) is the index with corresponding exception value aexval in matrix A.
The index is with respect to the assumption that the matrix is column stored,
without any transpose. In case of the row-storage and/or transpose, the index
is translated from its assumption accordingly.
Ex : (2, 3) with storage 'c' and transpose 'n' becomes (3, 2) if storage becomes
'r' or transpose becomes 't'.
*/
// (bi, bj) is the index with corresponding exception value bexval in matrix B.
// (ci, cj) is the index with corresponding exception value cexval in matrix C.
template<typename T>
void test_gemm( char storage, char trnsa, char trnsb, gtint_t m, gtint_t n,
gtint_t k, gtint_t lda_inc, gtint_t ldb_inc, gtint_t ldc_inc, T alpha,
T beta, gtint_t ai, gtint_t aj, T aexval, gtint_t bi, gtint_t bj, T bexval,
gtint_t ci, gtint_t cj, T cexval, double thresh )
{
// Compute the leading dimensions of a, b, and c.
gtint_t lda = testinghelpers::get_leading_dimension( storage, trnsa, m, k, lda_inc );
gtint_t ldb = testinghelpers::get_leading_dimension( storage, trnsb, k, n, ldb_inc );
gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldc_inc );
//----------------------------------------------------------
// Initialize matrics with random numbers
//----------------------------------------------------------
std::vector<T> a = testinghelpers::get_random_matrix<T>( -2, 8, storage, trnsa, m, k, lda );
std::vector<T> b = testinghelpers::get_random_matrix<T>( -5, 2, storage, trnsb, k, n, ldb );
std::vector<T> c = testinghelpers::get_random_matrix<T>( -3, 5, storage, 'n', m, n, ldc );
// Inducing exception values onto the matrices based on the indices passed as arguments.
// Assumption is that the indices are with respect to the matrices in column storage without
// any transpose. In case of difference in storage scheme or transposition, the row and column
// indices are appropriately swapped.
testinghelpers::set_ev_mat( storage, trnsa, lda, ai, aj, aexval, a.data() );
testinghelpers::set_ev_mat( storage, trnsb, ldb, bi, bj, bexval, b.data() );
testinghelpers::set_ev_mat( storage, 'n', ldc, ci, cj, cexval, c.data() );
// Create a copy of c so that we can check reference results.
std::vector<T> c_ref(c);
//----------------------------------------------------------
// Call BLIS function
//----------------------------------------------------------
gemm<T>( storage, trnsa, trnsb, m, n, k, &alpha, a.data(), lda,
b.data(), ldb, &beta, c.data(), ldc );
//----------------------------------------------------------
// Call reference implementation.
//----------------------------------------------------------
testinghelpers::ref_gemm( storage, trnsa, trnsb, m, n, k, alpha,
a.data(), lda, b.data(), ldb, beta, c_ref.data(), ldc );
//----------------------------------------------------------
// check component-wise error.
//----------------------------------------------------------
computediff<T>( storage, m, n, c.data(), c_ref.data(), ldc, thresh, true );
}

View File

@@ -0,0 +1,356 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
/*
The following file contains both the exception value testing(EVT) and the
positive accuracy testing of the bli_zgemm_4x4_avx2_k1_nn( ... ) computational
kernel. This kernel is invoked from the BLAS layer, and inputs are given
in a manner so as to avoid the other code-paths and test only the required
kernel.
*/
#include <gtest/gtest.h>
#include "test_gemm.h"
class ZGemmEVTTest :
public ::testing::TestWithParam<std::tuple<char,
char,
char,
gtint_t,
gtint_t,
gtint_t,
gtint_t,
gtint_t,
dcomplex,
gtint_t,
gtint_t,
dcomplex,
gtint_t,
gtint_t,
dcomplex,
dcomplex,
dcomplex,
gtint_t,
gtint_t,
gtint_t>> {};
TEST_P(ZGemmEVTTest, Unit_Tester)
{
using T = dcomplex;
//----------------------------------------------------------
// Initialize values from the parameters passed through
// test suite instantiation (INSTANTIATE_TEST_SUITE_P).
//----------------------------------------------------------
// matrix storage format(row major, column major)
char storage = std::get<0>(GetParam());
// denotes whether matrix a is n,c,t,h
char transa = std::get<1>(GetParam());
// denotes whether matrix b is n,c,t,h
char transb = std::get<2>(GetParam());
// matrix size m
gtint_t m = std::get<3>(GetParam());
// matrix size n
gtint_t n = std::get<4>(GetParam());
// matrix size k
gtint_t k = std::get<5>(GetParam());
gtint_t ai, aj, bi, bj, ci, cj;
T aex, bex, cex;
ai = std::get<6>(GetParam());
aj = std::get<7>(GetParam());
aex = std::get<8>(GetParam());
bi = std::get<9>(GetParam());
bj = std::get<10>(GetParam());
bex = std::get<11>(GetParam());
ci = std::get<12>(GetParam());
cj = std::get<13>(GetParam());
cex = std::get<14>(GetParam());
// specifies alpha value
T alpha = std::get<15>(GetParam());
// specifies beta value
T beta = std::get<16>(GetParam());
// lda, ldb, ldc increments.
// If increments are zero, then the array size matches the matrix size.
// If increments are nonnegative, the array size is bigger than the matrix size.
gtint_t lda_inc = std::get<17>(GetParam());
gtint_t ldb_inc = std::get<18>(GetParam());
gtint_t ldc_inc = std::get<19>(GetParam());
// Set the threshold for the errors:
double thresh = 10*m*n*testinghelpers::getEpsilon<T>();
//----------------------------------------------------------
// Call test body using these parameters
//----------------------------------------------------------
test_gemm<T>( storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc,
alpha, beta, ai, aj, aex, bi, bj, bex, ci, cj, cex, thresh );
}
// Helper classes for printing the test case parameters based on the instantiator
// These are mainly used to help with debugging, in case of failures
// Utility to print the test-case in case of exception value on matrices
class ZGemmEVMatPrint {
public:
std::string operator()(
testing::TestParamInfo<std::tuple<char, char, char, gtint_t, gtint_t, gtint_t, gtint_t, gtint_t, dcomplex,
gtint_t, gtint_t, dcomplex, gtint_t, gtint_t, dcomplex, dcomplex, dcomplex,
gtint_t, gtint_t, gtint_t>> str) const {
char sfm = std::get<0>(str.param);
char tsa = std::get<1>(str.param);
char tsb = std::get<2>(str.param);
gtint_t m = std::get<3>(str.param);
gtint_t n = std::get<4>(str.param);
gtint_t k = std::get<5>(str.param);
gtint_t ai, aj, bi, bj, ci, cj;
dcomplex aex, bex, cex;
ai = std::get<6>(str.param);
aj = std::get<7>(str.param);
aex = std::get<8>(str.param);
bi = std::get<9>(str.param);
bj = std::get<10>(str.param);
bex = std::get<11>(str.param);
ci = std::get<12>(str.param);
cj = std::get<13>(str.param);
cex = std::get<14>(str.param);
dcomplex alpha = std::get<15>(str.param);
dcomplex beta = std::get<16>(str.param);
gtint_t lda_inc = std::get<17>(str.param);
gtint_t ldb_inc = std::get<18>(str.param);
gtint_t ldc_inc = std::get<19>(str.param);
#ifdef TEST_BLAS
std::string str_name = "zgemm_";
#elif TEST_CBLAS
std::string str_name = "cblas_zgemm";
#else //#elif TEST_BLIS_TYPED
std::string str_name = "blis_zgemm";
#endif
str_name = str_name + "_" + sfm+sfm+sfm;
str_name = str_name + "_" + tsa + tsb;
str_name = str_name + "_" + std::to_string(m);
str_name = str_name + "_" + std::to_string(n);
str_name = str_name + "_" + std::to_string(k);
str_name = str_name + "_A" + std::to_string(ai) + std::to_string(aj);
str_name = str_name + "_" + testinghelpers::get_value_string(aex);
str_name = str_name + "_B" + std::to_string(bi) + std::to_string(bj);
str_name = str_name + "_" + testinghelpers::get_value_string(bex);
str_name = str_name + "_C" + std::to_string(ci) + std::to_string(cj);
str_name = str_name + "_" + testinghelpers::get_value_string(cex);
str_name = str_name + "_a" + testinghelpers::get_value_string(alpha);
str_name = str_name + "_b" + testinghelpers::get_value_string(beta);
str_name = str_name + "_" + std::to_string(lda_inc);
str_name = str_name + "_" + std::to_string(ldb_inc);
str_name = str_name + "_" + std::to_string(ldc_inc);
return str_name;
}
};
// Utility to print the test-case in case of exception value on matrices
class ZGemmEVAlphaBetaPrint {
public:
std::string operator()(
testing::TestParamInfo<std::tuple<char, char, char, gtint_t, gtint_t, gtint_t, gtint_t, gtint_t, dcomplex,
gtint_t, gtint_t, dcomplex, gtint_t, gtint_t, dcomplex, dcomplex, dcomplex,
gtint_t, gtint_t, gtint_t>> str) const {
char sfm = std::get<0>(str.param);
char tsa = std::get<1>(str.param);
char tsb = std::get<2>(str.param);
gtint_t m = std::get<3>(str.param);
gtint_t n = std::get<4>(str.param);
gtint_t k = std::get<5>(str.param);
dcomplex alpha = std::get<15>(str.param);
dcomplex beta = std::get<16>(str.param);
gtint_t lda_inc = std::get<17>(str.param);
gtint_t ldb_inc = std::get<18>(str.param);
gtint_t ldc_inc = std::get<19>(str.param);
#ifdef TEST_BLAS
std::string str_name = "zgemm_";
#elif TEST_CBLAS
std::string str_name = "cblas_zgemm";
#else //#elif TEST_BLIS_TYPED
std::string str_name = "blis_zgemm";
#endif
str_name = str_name + "_" + sfm+sfm+sfm;
str_name = str_name + "_" + tsa + tsb;
str_name = str_name + "_" + std::to_string(m);
str_name = str_name + "_" + std::to_string(n);
str_name = str_name + "_" + std::to_string(k);
str_name = str_name + "_a" + testinghelpers::get_value_string(alpha);
str_name = str_name + "_b" + testinghelpers::get_value_string(beta);
str_name = str_name + "_" + std::to_string(lda_inc);
str_name = str_name + "_" + std::to_string(ldb_inc);
str_name = str_name + "_" + std::to_string(ldc_inc);
return str_name;
}
};
static double NaN = std::numeric_limits<double>::quiet_NaN();
static double Inf = std::numeric_limits<double>::infinity();
// Exception value testing(on matrices)
/*
For the bli_zgemm_4x4_avx2_k1_nn kernel, the main and fringe dimensions are as follows:
For m : Main = { 4 }, fringe = { 2, 1 }
For n : Main = { 4 }, fringe = { 2, 1 }
Without any changes to the BLAS layer in BLIS, the fringe case of 1 cannot be touched
separately, since if m/n is 1, the inputs are redirected to ZGEMV.
*/
// Testing for the main loop case for m and n
// The kernel uses 2 loads and 4 broadcasts. The exception values
// are induced at one index individually for each of the loads.
// They are also induced in the broadcast direction at two places.
INSTANTIATE_TEST_SUITE_P(
bli_zgemm_4x4_avx2_k1_nn_evt_mat_main,
ZGemmEVTTest,
::testing::Combine(
::testing::Values('c'
#ifndef TEST_BLAS
,'r'
#endif
), // storage format
::testing::Values('n'), // transa
::testing::Values('n'), // transb
::testing::Values(gtint_t(4)), // m
::testing::Values(gtint_t(4)), // n
::testing::Values(gtint_t(1)), // k
::testing::Values(gtint_t(1), gtint_t(3)), // ai
::testing::Values(gtint_t(0)), // aj
::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0},
dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // aexval
::testing::Values(gtint_t(0)), // bi
::testing::Values(gtint_t(0), gtint_t(2)), // bj
::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0},
dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // bexval
::testing::Values(gtint_t(0), gtint_t(2)), // ci
::testing::Values(gtint_t(1), gtint_t(3)), // cj
::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0},
dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // cexval
::testing::Values(dcomplex{-2.2, 3.3}), // alpha
::testing::Values(dcomplex{1.2, -2.3}), // beta
::testing::Values(gtint_t(0)), // increment to the leading dim of a
::testing::Values(gtint_t(0)), // increment to the leading dim of b
::testing::Values(gtint_t(0)) // increment to the leading dim of c
),
::ZGemmEVMatPrint()
);
// Testing the fringe cases
// Fringe case minimum size is 2 along both m and n.
// Invloves only one load(AVX2 or (AVX2+SSE)). Thus,
// the exception values are induced at the first and second indices of the
// column vector A and row vector B.
INSTANTIATE_TEST_SUITE_P(
bli_zgemm_4x4_avx2_k1_nn_evt_mat_fringe,
ZGemmEVTTest,
::testing::Combine(
::testing::Values('c'
#ifndef TEST_BLAS
,'r'
#endif
), // storage format
::testing::Values('n'), // transa
::testing::Values('n'), // transb
::testing::Values(gtint_t(2), gtint_t(3)), // m
::testing::Values(gtint_t(2), gtint_t(3)), // n
::testing::Values(gtint_t(1)), // k
::testing::Values(gtint_t(0), gtint_t(1)), // ai
::testing::Values(gtint_t(0)), // aj
::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0},
dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // aexval
::testing::Values(gtint_t(0)), // bi
::testing::Values(gtint_t(0), gtint_t(1)), // bj
::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0},
dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // bexval
::testing::Values(gtint_t(0), gtint_t(1)), // ci
::testing::Values(gtint_t(0), gtint_t(1)), // cj
::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0},
dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // cexval
::testing::Values(dcomplex{-2.2, 3.3}), // alpha
::testing::Values(dcomplex{1.2, -2.3}), // beta
::testing::Values(gtint_t(0)), // increment to the leading dim of a
::testing::Values(gtint_t(0)), // increment to the leading dim of b
::testing::Values(gtint_t(0)) // increment to the leading dim of c
),
::ZGemmEVMatPrint()
);
// Exception value testing(on alpha and beta)
// Alpha and beta are set to exception values
INSTANTIATE_TEST_SUITE_P(
bli_zgemm_4x4_avx2_k1_nn_evt_alphabeta,
ZGemmEVTTest,
::testing::Combine(
::testing::Values('c'
#ifndef TEST_BLAS
,'r'
#endif
), // storage format
::testing::Values('n'), // transa
::testing::Values('n'), // transb
::testing::Values(gtint_t(2), gtint_t(3), gtint_t(4)), // m
::testing::Values(gtint_t(2), gtint_t(3), gtint_t(4)), // n
::testing::Values(gtint_t(1)), // k
::testing::Values(gtint_t(0)), // ai
::testing::Values(gtint_t(0)), // aj
::testing::Values(dcomplex{0.0, 0.0}),
::testing::Values(gtint_t(0)), // bi
::testing::Values(gtint_t(0)), // bj
::testing::Values(dcomplex{0.0, 0.0}),
::testing::Values(gtint_t(0)), // ci
::testing::Values(gtint_t(0)), // cj
::testing::Values(dcomplex{0.0, 0.0}),
::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0},
dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // alpha
::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0},
dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // beta
::testing::Values(gtint_t(0)), // increment to the leading dim of a
::testing::Values(gtint_t(0)), // increment to the leading dim of b
::testing::Values(gtint_t(0)) // increment to the leading dim of c
),
::ZGemmEVAlphaBetaPrint()
);

View File

@@ -35,7 +35,7 @@
#include <gtest/gtest.h>
#include "test_gemm.h"
class ZGemmTest :
class ZGemmAccTest :
public ::testing::TestWithParam<std::tuple<char,
char,
char,
@@ -48,7 +48,7 @@ class ZGemmTest :
gtint_t,
gtint_t>> {};
TEST_P(ZGemmTest, RandomData)
TEST_P(ZGemmAccTest, Unit_Tester)
{
using T = dcomplex;
//----------------------------------------------------------
@@ -87,7 +87,7 @@ TEST_P(ZGemmTest, RandomData)
test_gemm<T>( storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh );
}
class ZGemmTestPrint {
class ZGemmAccPrint {
public:
std::string operator()(
testing::TestParamInfo<std::tuple<char, char, char, gtint_t, gtint_t, gtint_t, dcomplex, dcomplex, gtint_t, gtint_t, gtint_t>> str) const {
@@ -114,12 +114,8 @@ public:
str_name = str_name + "_" + std::to_string(m);
str_name = str_name + "_" + std::to_string(n);
str_name = str_name + "_" + std::to_string(k);
std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real))));
alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag)))));
std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real))));
beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag)))));
str_name = str_name + "_a" + alpha_str;
str_name = str_name + "_b" + beta_str;
str_name = str_name + "_a" + testinghelpers::get_value_string(alpha);;
str_name = str_name + "_b" + testinghelpers::get_value_string(beta);;
str_name = str_name + "_" + std::to_string(lda_inc);
str_name = str_name + "_" + std::to_string(ldb_inc);
str_name = str_name + "_" + std::to_string(ldc_inc);
@@ -127,10 +123,41 @@ public:
}
};
// Unit testing for bli_zgemm_4x4_avx2_k1_nn kernel
/* From the BLAS layer(post parameter checking), the inputs will be redirected to this kernel
if m != 1, n !=1 and k == 1 */
INSTANTIATE_TEST_SUITE_P(
bli_zgemm_4x4_avx2_k1_nn,
ZGemmAccTest,
::testing::Combine(
::testing::Values('c'
#ifndef TEST_BLAS
,'r'
#endif
), // storage format
::testing::Values('n'), // transa
::testing::Values('n'), // transb
::testing::Range(gtint_t(2), gtint_t(8), 1), // m
::testing::Range(gtint_t(2), gtint_t(8), 1), // n
::testing::Values(gtint_t(1)), // k
::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0},
dcomplex{0.0, 1.0}, dcomplex{2.1, -1.9},
dcomplex{0.0, 0.0}), // alpha
::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0},
dcomplex{0.0, 1.0}, dcomplex{2.1, -1.9},
dcomplex{0.0, 0.0}), // beta
::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a
::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of b
::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of c
),
::ZGemmAccPrint()
);
// Black box testing.
INSTANTIATE_TEST_SUITE_P(
Blackbox,
ZGemmTest,
ZGemmAccTest,
::testing::Combine(
::testing::Values('c'
#ifndef TEST_BLAS
@@ -148,5 +175,5 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of b
::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of c
),
::ZGemmTestPrint()
::ZGemmAccPrint()
);

View File

@@ -88,10 +88,10 @@ public:
std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx));
str_name = str_name + "_" + incx_str;
str_name = str_name + "_i" + std::to_string(i);
std::string iexval_str = getValueString(iexval);
std::string iexval_str = testinghelpers::get_value_string(iexval);
str_name = str_name + "_" + iexval_str;
str_name = str_name + "_j" + std::to_string(j);
std::string jexval_str = getValueString(jexval);
std::string jexval_str = testinghelpers::get_value_string(jexval);
str_name = str_name + "_" + jexval_str;
return str_name;
}

View File

@@ -88,10 +88,10 @@ public:
std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx));
str_name = str_name + "_" + incx_str;
str_name = str_name + "_i" + std::to_string(i);
std::string iexval_str = "_Re_" + getValueString(iexval.real) + "_Im_" + getValueString(iexval.imag);
std::string iexval_str = "_Re_" + testinghelpers::get_value_string(iexval.real) + "_Im_" + testinghelpers::get_value_string(iexval.imag);
str_name = str_name + iexval_str;
str_name = str_name + "_j" + std::to_string(j);
std::string jexval_str = "_Re_" + getValueString(jexval.real) + "_Im_" + getValueString(jexval.imag);
std::string jexval_str = "_Re_" + testinghelpers::get_value_string(jexval.real) + "_Im_" + testinghelpers::get_value_string(jexval.imag);
str_name = str_name + jexval_str;
return str_name;
}

View File

@@ -88,10 +88,10 @@ public:
std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx));
str_name = str_name + "_" + incx_str;
str_name = str_name + "_i" + std::to_string(i);
std::string iexval_str = "_Re_" + getValueString(iexval.real) + "_Im_" + getValueString(iexval.imag);
std::string iexval_str = "_Re_" + testinghelpers::get_value_string(iexval.real) + "_Im_" + testinghelpers::get_value_string(iexval.imag);
str_name = str_name + iexval_str;
str_name = str_name + "_j" + std::to_string(j);
std::string jexval_str = "_Re_" + getValueString(jexval.real) + "_Im_" + getValueString(jexval.imag);
std::string jexval_str = "_Re_" + testinghelpers::get_value_string(jexval.real) + "_Im_" + testinghelpers::get_value_string(jexval.imag);
str_name = str_name + jexval_str;
return str_name;
}

View File

@@ -88,10 +88,10 @@ public:
std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx));
str_name = str_name + "_" + incx_str;
str_name = str_name + "_i" + std::to_string(i);
std::string iexval_str = getValueString(iexval);
std::string iexval_str = testinghelpers::get_value_string(iexval);
str_name = str_name + "_" + iexval_str;
str_name = str_name + "_j" + std::to_string(j);
std::string jexval_str = getValueString(jexval);
std::string jexval_str = testinghelpers::get_value_string(jexval);
str_name = str_name + "_" + jexval_str;
return str_name;
}

View File

@@ -98,19 +98,4 @@ void test_nrm2( gtint_t n, gtint_t incx, gtint_t i, T iexval, gtint_t j = 0, T j
//----------------------------------------------------------
// Compare using NaN/Inf checks.
computediff<RT>( norm, norm_ref, true );
}
// Helper function that returns a string with the correct NaN/Inf printing
// so that we can print the test names correctly from using parametrized testing.
template<typename T>
std::string getValueString(T exval)
{
std::string exval_str;
if(std::isnan(exval))
exval_str = "nan";
else if(std::isinf(exval))
exval_str = (exval > 0) ? "inf" : "minus_inf";
else
exval_str = ( exval > 0) ? std::to_string(int(exval)) : "minus_" + std::to_string(int(std::abs(exval)));
return exval_str;
}