mirror of
https://github.com/amd/blis.git
synced 2026-05-12 18:15:37 +00:00
GTestSuite : Designing test cases for ZGEMM
- Designed test cases for unit testing of ZGEMM compute kernel for handling inputs when k == 1. The design uses value-parameterized testing for checking accuracy, and verifying the mandate in case of exception values on the inputs/output. - The design uses type-parameterized testing for verifying BLAS standard for invalid input cases, and also for early return scenarios. - Added the function template set_ev_mat( ... ) as part of testinghelpers. This function is used as a helper for inducing exception values onto indices specified as arguments to the test_gemm( ... ) interface. - Abstracted the function definition of getValueString( ... ) from the NRM2 testing interface to testinghelpers(renamed as get_value_string( ... ) for naming consistency), in order to use it as a helper function across all APIs in case of exception value testing. AMD-Internal: [CPUPL-3823] Change-Id: I0fea21f9c8759bbbdc88ba0a016202753e28f2a7
This commit is contained in:
@@ -62,17 +62,17 @@ void randomgenerators(int from, int to, T* alpha, char fp);
|
||||
* if fp=='f' the elements will have random float values.
|
||||
*/
|
||||
template<typename T>
|
||||
void randomgenerators(int from, int to, gtint_t n, gtint_t incx, T* x, char fp);
|
||||
void randomgenerators(int from, int to, gtint_t n, gtint_t incx, T* x, char fp = BLIS_ELEMENT_TYPE);
|
||||
|
||||
template<typename T>
|
||||
void randomgenerators(int from, int to, char storage, gtint_t m, gtint_t n, T* a, gtint_t lda, char fp);
|
||||
void randomgenerators(int from, int to, char storage, gtint_t m, gtint_t n, T* a, gtint_t lda, char fp = BLIS_ELEMENT_TYPE);
|
||||
|
||||
template<typename T>
|
||||
void randomgenerators(int from, int to, char storage, gtint_t m, gtint_t n, T* a, char transa, gtint_t lda, char fp);
|
||||
void randomgenerators(int from, int to, char storage, gtint_t m, gtint_t n, T* a, char transa, gtint_t lda, char fp = BLIS_ELEMENT_TYPE);
|
||||
|
||||
template<typename T>
|
||||
void randomgenerators(int from, int to, char storage, char uplo, gtint_t m,
|
||||
T* a, gtint_t lda, char fp );
|
||||
T* a, gtint_t lda, char fp = BLIS_ELEMENT_TYPE );
|
||||
} //end of namespace datagenerators
|
||||
|
||||
template<typename T>
|
||||
@@ -92,4 +92,16 @@ std::vector<T> get_vector( gtint_t n, gtint_t incx, T value );
|
||||
template<typename T>
|
||||
std::vector<T> get_matrix( char storage, char trans, gtint_t m, gtint_t n, gtint_t lda, T value );
|
||||
|
||||
template<typename T>
|
||||
void set_vector( gtint_t n, gtint_t incx, T* x, T value );
|
||||
|
||||
template<typename T>
|
||||
void set_matrix( char storage, gtint_t m, gtint_t n, T* a, char transa, gtint_t lda, T value );
|
||||
|
||||
// Function template to set the exception value exval on matrix m, at indices (i, j)
|
||||
// In case of transposition, this function internally swaps the indices, and thus they can be
|
||||
// passed without swapping on the instantiator.
|
||||
template<typename T>
|
||||
void set_ev_mat( char storage, char trns, gtint_t ld, gtint_t i, gtint_t j, T exval, T* m );
|
||||
|
||||
} //end of namespace testinghelpers
|
||||
|
||||
@@ -371,4 +371,13 @@ void print_vector( const char *vec, gtint_t n, T *x, gtint_t incx, const char *s
|
||||
template<typename T>
|
||||
void print_matrix( const char *mat, char storage, gtint_t m, gtint_t n, T *a, gtint_t ld, const char *spec );
|
||||
|
||||
/**
|
||||
* @brief returns a string with the correct NaN/Inf for printing
|
||||
*
|
||||
* @tparam T float, double, scomplex, dcomplex.
|
||||
* @param exval exception value for setting the string.
|
||||
*/
|
||||
template<typename T>
|
||||
std::string get_value_string( T exval );
|
||||
|
||||
} //end of namespace testinghelpers
|
||||
|
||||
@@ -441,6 +441,26 @@ std::vector<T> get_matrix( char storage, char trans, gtint_t m, gtint_t n, gtint
|
||||
return a;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void set_ev_mat( char storage, char trns, gtint_t ld, gtint_t i, gtint_t j, T exval, T* m )
|
||||
{
|
||||
// Setting the exception values on the indices passed as arguments
|
||||
if ( storage == 'c' || storage == 'C' )
|
||||
{
|
||||
if ( trns == 'n' || trns == 'N' )
|
||||
m[i + j*ld] = exval;
|
||||
else
|
||||
m[j + i*ld] = exval;
|
||||
}
|
||||
else
|
||||
{
|
||||
if ( trns == 'n' || trns == 'N' )
|
||||
m[i*ld + j] = exval;
|
||||
else
|
||||
m[j*ld + i] = exval;
|
||||
}
|
||||
}
|
||||
|
||||
} //end of namespace testinghelpers
|
||||
|
||||
// Explicit template instantiations
|
||||
@@ -493,3 +513,18 @@ template std::vector<float> testinghelpers::get_matrix( char, char, gtint_t, gti
|
||||
template std::vector<double> testinghelpers::get_matrix( char, char, gtint_t, gtint_t, gtint_t, double );
|
||||
template std::vector<scomplex> testinghelpers::get_matrix( char, char, gtint_t, gtint_t, gtint_t, scomplex );
|
||||
template std::vector<dcomplex> testinghelpers::get_matrix( char, char, gtint_t, gtint_t, gtint_t, dcomplex );
|
||||
|
||||
template void testinghelpers::set_vector<float>( gtint_t, gtint_t, float*, float );
|
||||
template void testinghelpers::set_vector<double>( gtint_t, gtint_t, double*, double );
|
||||
template void testinghelpers::set_vector<scomplex>( gtint_t, gtint_t, scomplex*, scomplex );
|
||||
template void testinghelpers::set_vector<dcomplex>( gtint_t, gtint_t, dcomplex*, dcomplex );
|
||||
|
||||
template void testinghelpers::set_matrix<float>( char, gtint_t, gtint_t, float*, char, gtint_t, float );
|
||||
template void testinghelpers::set_matrix<double>( char, gtint_t, gtint_t, double*, char, gtint_t, double );
|
||||
template void testinghelpers::set_matrix<scomplex>( char, gtint_t, gtint_t, scomplex*, char, gtint_t, scomplex );
|
||||
template void testinghelpers::set_matrix<dcomplex>( char, gtint_t, gtint_t, dcomplex*, char, gtint_t, dcomplex );
|
||||
|
||||
template void testinghelpers::set_ev_mat<float>( char, char, gtint_t, gtint_t, gtint_t, float, float* );
|
||||
template void testinghelpers::set_ev_mat<double>( char, char, gtint_t, gtint_t, gtint_t, double, double* );
|
||||
template void testinghelpers::set_ev_mat<scomplex>( char, char, gtint_t, gtint_t, gtint_t, scomplex, scomplex* );
|
||||
template void testinghelpers::set_ev_mat<dcomplex>( char, char, gtint_t, gtint_t, gtint_t, dcomplex, dcomplex* );
|
||||
@@ -614,4 +614,76 @@ template void print_matrix<double>( char, gtint_t, gtint_t, double *, gtint_t, c
|
||||
template void print_matrix<scomplex>( char, gtint_t, gtint_t, scomplex *, gtint_t, const char * );
|
||||
template void print_matrix<dcomplex>( char, gtint_t, gtint_t, dcomplex *, gtint_t, const char * );
|
||||
|
||||
|
||||
/*
|
||||
Helper function that returns a string based on the value that is passed
|
||||
The return values are as follows :
|
||||
If datatype is real : "nan", "inf"/"minus_inf", "value", where "value"
|
||||
is the string version of the value that is passed, if it is not nan/inf/-inf.
|
||||
|
||||
If the datatype is complex : The string is concatenated with both the real and
|
||||
imaginary components values, based on analysis done separately to each of them
|
||||
(similar to real datatype).
|
||||
*/
|
||||
template<typename T>
|
||||
std::string get_value_string(T exval)
|
||||
{
|
||||
std::string exval_str;
|
||||
if constexpr (testinghelpers::type_info<T>::is_real)
|
||||
{
|
||||
if(std::isnan(exval))
|
||||
exval_str = "nan";
|
||||
else if(std::isinf(exval))
|
||||
exval_str = (exval >= 0) ? "inf" : "minus_inf";
|
||||
else
|
||||
exval_str = ( exval >= 0) ? std::to_string(int(exval)) : "minus_" + std::to_string(int(std::abs(exval)));
|
||||
}
|
||||
else
|
||||
{
|
||||
if(std::isnan(exval.real))
|
||||
{
|
||||
exval_str = "nan";
|
||||
if(std::isinf(exval.imag))
|
||||
exval_str = exval_str + "pi" + ((exval.imag >= 0) ? "inf" : "minus_inf");
|
||||
else
|
||||
exval_str = exval_str + "pi" + ((exval.imag >= 0)? std::to_string(int(exval.imag)) : "m" + std::to_string(int(std::abs(exval.imag))));
|
||||
}
|
||||
else if(std::isnan(exval.imag))
|
||||
{
|
||||
if(std::isinf(exval.real))
|
||||
exval_str = ((exval.real >= 0) ? "inf" : "minus_inf");
|
||||
else
|
||||
exval_str = ((exval.real >= 0)? std::to_string(int(exval.real)) : "m" + std::to_string(int(std::abs(exval.real))));
|
||||
exval_str = exval_str + "pinan";
|
||||
}
|
||||
else if(std::isinf(exval.real))
|
||||
{
|
||||
exval_str = ((exval.real >= 0) ? "inf" : "minus_inf");
|
||||
if(std::isnan(exval.imag))
|
||||
exval_str = exval_str + "pinan";
|
||||
else
|
||||
exval_str = exval_str + "pi" + ((exval.imag >= 0)? std::to_string(int(exval.imag)) : "m" + std::to_string(int(std::abs(exval.imag))));
|
||||
}
|
||||
else if(std::isinf(exval.imag))
|
||||
{
|
||||
if(std::isnan(exval.real))
|
||||
exval_str = "nan";
|
||||
else
|
||||
exval_str = ((exval.real >= 0)? std::to_string(int(exval.real)) : "m" + std::to_string(int(std::abs(exval.real))));
|
||||
|
||||
exval_str = exval_str + ((exval.imag >= 0) ? "inf" : "minus_inf");
|
||||
}
|
||||
else
|
||||
{
|
||||
exval_str = ((exval.real >= 0)? std::to_string(int(exval.real)) : "m" + std::to_string(int(std::abs(exval.real))));
|
||||
exval_str = exval_str + "pi" + ((exval.imag >= 0)? std::to_string(int(exval.imag)) : "m" + std::to_string(int(std::abs(exval.imag))));
|
||||
}
|
||||
}
|
||||
return exval_str;
|
||||
}
|
||||
template std::string testinghelpers::get_value_string( float );
|
||||
template std::string testinghelpers::get_value_string( double );
|
||||
template std::string testinghelpers::get_value_string( scomplex );
|
||||
template std::string testinghelpers::get_value_string( dcomplex );
|
||||
|
||||
} //end of namespace testinghelpers
|
||||
264
gtestsuite/testsuite/level3/gemm/IIT_ERS_test.cpp
Normal file
264
gtestsuite/testsuite/level3/gemm/IIT_ERS_test.cpp
Normal file
@@ -0,0 +1,264 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "common/testing_helpers.h"
|
||||
#include "gemm.h"
|
||||
#include "inc/check_error.h"
|
||||
#include "common/wrong_inputs_helpers.h"
|
||||
|
||||
template <typename T>
|
||||
class Gemm_IIT_ERS_Test : public ::testing::Test {};
|
||||
typedef ::testing::Types<float, double, scomplex, dcomplex> TypeParam; // The supported datatypes from BLAS calls for GEMM
|
||||
TYPED_TEST_SUITE(Gemm_IIT_ERS_Test, TypeParam); // Defining individual testsuites based on the datatype support.
|
||||
|
||||
// Adding namespace to get default parameters(valid case) from testinghelpers/common/wrong_input_helpers.h.
|
||||
using namespace testinghelpers::IIT;
|
||||
|
||||
#ifdef TEST_BLAS
|
||||
|
||||
/*
|
||||
Incorrect Input Testing(IIT)
|
||||
|
||||
BLAS exceptions get triggered in the following cases(for GEMM):
|
||||
1. When TRANSA != 'N' || TRANSA != 'T' || TRANSA != 'C' (info = 1)
|
||||
2. When TRANSB != 'N' || TRANSB != 'T' || TRANSB != 'C' (info = 2)
|
||||
3. When m < 0 (info = 3)
|
||||
4. When n < 0 (info = 4)
|
||||
5. When k < 0 (info = 5)
|
||||
6. When lda < max(1, thresh) (info = 8), thresh set based on TRANSA value
|
||||
7. When ldb < max(1, thresh) (info = 10), thresh set based on TRANSB value
|
||||
8. When ldc < max(1, n) (info = 13)
|
||||
|
||||
*/
|
||||
|
||||
// When info == 1
|
||||
TYPED_TEST(Gemm_IIT_ERS_Test, invalid_transa)
|
||||
{
|
||||
using T = TypeParam;
|
||||
// Defining the C matrix with values for debugging purposes
|
||||
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
|
||||
|
||||
// Copy so that we check that the elements of C are not modified.
|
||||
std::vector<T> c_ref(c);
|
||||
// Call BLIS Gemm with a invalid value for TRANS value for A.
|
||||
gemm<T>( STORAGE, 'p', TRANS, M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC );
|
||||
// Use bitwise comparison (no threshold).
|
||||
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
|
||||
}
|
||||
|
||||
// When info == 2
|
||||
TYPED_TEST(Gemm_IIT_ERS_Test, invalid_transb)
|
||||
{
|
||||
using T = TypeParam;
|
||||
// Defining the C matrix with values for debugging purposes
|
||||
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
|
||||
|
||||
// Copy so that we check that the elements of C are not modified.
|
||||
std::vector<T> c_ref(c);
|
||||
// Call BLIS Gemm with a invalid value for TRANS value for B.
|
||||
gemm<T>( STORAGE, TRANS, 'p', M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC );
|
||||
// Use bitwise comparison (no threshold).
|
||||
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
|
||||
}
|
||||
|
||||
// When info == 3
|
||||
TYPED_TEST(Gemm_IIT_ERS_Test, m_lt_zero)
|
||||
{
|
||||
using T = TypeParam;
|
||||
// Defining the C matrix with values for debugging purposes
|
||||
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
|
||||
|
||||
// Copy so that we check that the elements of C are not modified.
|
||||
std::vector<T> c_ref(c);
|
||||
// Call BLIS Gemm with a invalid value for m.
|
||||
gemm<T>( STORAGE, TRANS, TRANS, -1, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC );
|
||||
// Use bitwise comparison (no threshold).
|
||||
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
|
||||
}
|
||||
|
||||
// When info == 4
|
||||
TYPED_TEST(Gemm_IIT_ERS_Test, n_lt_zero)
|
||||
{
|
||||
using T = TypeParam;
|
||||
// Defining the C matrix with values for debugging purposes
|
||||
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
|
||||
|
||||
// Copy so that we check that the elements of C are not modified.
|
||||
std::vector<T> c_ref(c);
|
||||
// Call BLIS Gemm with a invalid value for n.
|
||||
gemm<T>( STORAGE, TRANS, TRANS, M, -1, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC );
|
||||
// Use bitwise comparison (no threshold).
|
||||
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
|
||||
}
|
||||
|
||||
// When info == 5
|
||||
TYPED_TEST(Gemm_IIT_ERS_Test, k_lt_zero)
|
||||
{
|
||||
using T = TypeParam;
|
||||
// Defining the C matrix with values for debugging purposes
|
||||
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
|
||||
|
||||
// Copy so that we check that the elements of C are not modified.
|
||||
std::vector<T> c_ref(c);
|
||||
// Call BLIS Gemm with a invalid value for k.
|
||||
gemm<T>( STORAGE, TRANS, TRANS, M, N, -1, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC );
|
||||
// Use bitwise comparison (no threshold).
|
||||
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
|
||||
}
|
||||
|
||||
// When info == 8
|
||||
TYPED_TEST(Gemm_IIT_ERS_Test, invalid_lda)
|
||||
{
|
||||
using T = TypeParam;
|
||||
// Defining the C matrix with values for debugging purposes
|
||||
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
|
||||
|
||||
// Copy so that we check that the elements of C are not modified.
|
||||
std::vector<T> c_ref(c);
|
||||
// Call BLIS Gemm with a invalid value for lda.
|
||||
gemm<T>( STORAGE, TRANS, TRANS, M, N, K, nullptr, nullptr, LDA - 1, nullptr, LDB, nullptr, nullptr, LDC );
|
||||
// Use bitwise comparison (no threshold).
|
||||
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
|
||||
}
|
||||
|
||||
// When info == 10
|
||||
TYPED_TEST(Gemm_IIT_ERS_Test, invalid_ldb)
|
||||
{
|
||||
using T = TypeParam;
|
||||
// Defining the C matrix with values for debugging purposes
|
||||
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
|
||||
|
||||
// Copy so that we check that the elements of C are not modified.
|
||||
std::vector<T> c_ref(c);
|
||||
// Call BLIS Gemm with a invalid value for ldb.
|
||||
gemm<T>( STORAGE, TRANS, TRANS, M, N, K, nullptr, nullptr, LDA, nullptr, LDB - 1, nullptr, nullptr, LDC );
|
||||
// Use bitwise comparison (no threshold).
|
||||
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
|
||||
}
|
||||
|
||||
// When info == 13
|
||||
TYPED_TEST(Gemm_IIT_ERS_Test, invalid_ldc)
|
||||
{
|
||||
using T = TypeParam;
|
||||
// Defining the C matrix with values for debugging purposes
|
||||
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
|
||||
|
||||
// Copy so that we check that the elements of C are not modified.
|
||||
std::vector<T> c_ref(c);
|
||||
// Call BLIS Gemm with a invalid value for ldc.
|
||||
gemm<T>( STORAGE, TRANS, TRANS, M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC - 1 );
|
||||
// Use bitwise comparison (no threshold).
|
||||
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
|
||||
}
|
||||
|
||||
/*
|
||||
Early Return Scenarios(ERS) :
|
||||
|
||||
The GEMM API is expected to return early in the following cases:
|
||||
|
||||
1. When m == 0.
|
||||
2. When n == 0.
|
||||
3. When (alpha == 0 or k == 0) and beta == 1.
|
||||
|
||||
*/
|
||||
|
||||
// When m is 0
|
||||
TYPED_TEST(Gemm_IIT_ERS_Test, m_eq_zero)
|
||||
{
|
||||
using T = TypeParam;
|
||||
// Defining the C matrix with values for debugging purposes
|
||||
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
|
||||
|
||||
// Copy so that we check that the elements of C are not modified.
|
||||
std::vector<T> c_ref(c);
|
||||
gemm<T>( STORAGE, TRANS, TRANS, 0, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC );
|
||||
// Use bitwise comparison (no threshold).
|
||||
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
|
||||
}
|
||||
|
||||
// When n is 0
|
||||
TYPED_TEST(Gemm_IIT_ERS_Test, n_eq_zero)
|
||||
{
|
||||
using T = TypeParam;
|
||||
// Defining the C matrix with values for debugging purposes
|
||||
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
|
||||
|
||||
// Copy so that we check that the elements of C are not modified.
|
||||
std::vector<T> c_ref(c);
|
||||
gemm<T>( STORAGE, TRANS, TRANS, M, 0, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC );
|
||||
// Use bitwise comparison (no threshold).
|
||||
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
|
||||
}
|
||||
|
||||
// When alpha is 0 and beta is 1
|
||||
TYPED_TEST(Gemm_IIT_ERS_Test, alpha_zero_beta_one)
|
||||
{
|
||||
using T = TypeParam;
|
||||
// Defining the C matrix with values for debugging purposes
|
||||
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
|
||||
|
||||
// Copy so that we check that the elements of C are not modified.
|
||||
std::vector<T> c_ref(c);
|
||||
T alpha, beta;
|
||||
|
||||
testinghelpers::initzero<T>( alpha );
|
||||
testinghelpers::initone<T>( beta );
|
||||
|
||||
gemm<T>( STORAGE, TRANS, TRANS, M, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC );
|
||||
// Use bitwise comparison (no threshold).
|
||||
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
|
||||
}
|
||||
|
||||
// When k is 0 and beta is 1
|
||||
TYPED_TEST(Gemm_IIT_ERS_Test, k_zero_beta_one)
|
||||
{
|
||||
using T = TypeParam;
|
||||
// Defining the C matrix with values for debugging purposes
|
||||
std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
|
||||
|
||||
// Copy so that we check that the elements of C are not modified.
|
||||
std::vector<T> c_ref(c);
|
||||
T alpha, beta;
|
||||
|
||||
testinghelpers::initone<T>( alpha );
|
||||
testinghelpers::initone<T>( beta );
|
||||
|
||||
gemm<T>( STORAGE, TRANS, TRANS, M, N, 0, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC );
|
||||
// Use bitwise comparison (no threshold).
|
||||
computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
|
||||
}
|
||||
|
||||
|
||||
#endif
|
||||
@@ -76,4 +76,63 @@ void test_gemm( char storage, char trnsa, char trnsb, gtint_t m, gtint_t n,
|
||||
// check component-wise error.
|
||||
//----------------------------------------------------------
|
||||
computediff<T>( storage, m, n, c.data(), c_ref.data(), ldc, thresh );
|
||||
}
|
||||
|
||||
// Test body used for exception value testing, by iducing an exception value
|
||||
// in the index that is passed for each of the matrices.
|
||||
/*
|
||||
(ai, aj) is the index with corresponding exception value aexval in matrix A.
|
||||
The index is with respect to the assumption that the matrix is column stored,
|
||||
without any transpose. In case of the row-storage and/or transpose, the index
|
||||
is translated from its assumption accordingly.
|
||||
Ex : (2, 3) with storage 'c' and transpose 'n' becomes (3, 2) if storage becomes
|
||||
'r' or transpose becomes 't'.
|
||||
*/
|
||||
// (bi, bj) is the index with corresponding exception value bexval in matrix B.
|
||||
// (ci, cj) is the index with corresponding exception value cexval in matrix C.
|
||||
template<typename T>
|
||||
void test_gemm( char storage, char trnsa, char trnsb, gtint_t m, gtint_t n,
|
||||
gtint_t k, gtint_t lda_inc, gtint_t ldb_inc, gtint_t ldc_inc, T alpha,
|
||||
T beta, gtint_t ai, gtint_t aj, T aexval, gtint_t bi, gtint_t bj, T bexval,
|
||||
gtint_t ci, gtint_t cj, T cexval, double thresh )
|
||||
{
|
||||
// Compute the leading dimensions of a, b, and c.
|
||||
gtint_t lda = testinghelpers::get_leading_dimension( storage, trnsa, m, k, lda_inc );
|
||||
gtint_t ldb = testinghelpers::get_leading_dimension( storage, trnsb, k, n, ldb_inc );
|
||||
gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldc_inc );
|
||||
|
||||
//----------------------------------------------------------
|
||||
// Initialize matrics with random numbers
|
||||
//----------------------------------------------------------
|
||||
std::vector<T> a = testinghelpers::get_random_matrix<T>( -2, 8, storage, trnsa, m, k, lda );
|
||||
std::vector<T> b = testinghelpers::get_random_matrix<T>( -5, 2, storage, trnsb, k, n, ldb );
|
||||
std::vector<T> c = testinghelpers::get_random_matrix<T>( -3, 5, storage, 'n', m, n, ldc );
|
||||
|
||||
// Inducing exception values onto the matrices based on the indices passed as arguments.
|
||||
// Assumption is that the indices are with respect to the matrices in column storage without
|
||||
// any transpose. In case of difference in storage scheme or transposition, the row and column
|
||||
// indices are appropriately swapped.
|
||||
testinghelpers::set_ev_mat( storage, trnsa, lda, ai, aj, aexval, a.data() );
|
||||
testinghelpers::set_ev_mat( storage, trnsb, ldb, bi, bj, bexval, b.data() );
|
||||
testinghelpers::set_ev_mat( storage, 'n', ldc, ci, cj, cexval, c.data() );
|
||||
|
||||
// Create a copy of c so that we can check reference results.
|
||||
std::vector<T> c_ref(c);
|
||||
|
||||
//----------------------------------------------------------
|
||||
// Call BLIS function
|
||||
//----------------------------------------------------------
|
||||
gemm<T>( storage, trnsa, trnsb, m, n, k, &alpha, a.data(), lda,
|
||||
b.data(), ldb, &beta, c.data(), ldc );
|
||||
|
||||
//----------------------------------------------------------
|
||||
// Call reference implementation.
|
||||
//----------------------------------------------------------
|
||||
testinghelpers::ref_gemm( storage, trnsa, trnsb, m, n, k, alpha,
|
||||
a.data(), lda, b.data(), ldb, beta, c_ref.data(), ldc );
|
||||
|
||||
//----------------------------------------------------------
|
||||
// check component-wise error.
|
||||
//----------------------------------------------------------
|
||||
computediff<T>( storage, m, n, c.data(), c_ref.data(), ldc, thresh, true );
|
||||
}
|
||||
356
gtestsuite/testsuite/level3/gemm/zgemm_evt_testing.cpp
Normal file
356
gtestsuite/testsuite/level3/gemm/zgemm_evt_testing.cpp
Normal file
@@ -0,0 +1,356 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
/*
|
||||
The following file contains both the exception value testing(EVT) and the
|
||||
positive accuracy testing of the bli_zgemm_4x4_avx2_k1_nn( ... ) computational
|
||||
kernel. This kernel is invoked from the BLAS layer, and inputs are given
|
||||
in a manner so as to avoid the other code-paths and test only the required
|
||||
kernel.
|
||||
|
||||
*/
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "test_gemm.h"
|
||||
|
||||
class ZGemmEVTTest :
|
||||
public ::testing::TestWithParam<std::tuple<char,
|
||||
char,
|
||||
char,
|
||||
gtint_t,
|
||||
gtint_t,
|
||||
gtint_t,
|
||||
gtint_t,
|
||||
gtint_t,
|
||||
dcomplex,
|
||||
gtint_t,
|
||||
gtint_t,
|
||||
dcomplex,
|
||||
gtint_t,
|
||||
gtint_t,
|
||||
dcomplex,
|
||||
dcomplex,
|
||||
dcomplex,
|
||||
gtint_t,
|
||||
gtint_t,
|
||||
gtint_t>> {};
|
||||
|
||||
TEST_P(ZGemmEVTTest, Unit_Tester)
|
||||
{
|
||||
using T = dcomplex;
|
||||
//----------------------------------------------------------
|
||||
// Initialize values from the parameters passed through
|
||||
// test suite instantiation (INSTANTIATE_TEST_SUITE_P).
|
||||
//----------------------------------------------------------
|
||||
// matrix storage format(row major, column major)
|
||||
char storage = std::get<0>(GetParam());
|
||||
// denotes whether matrix a is n,c,t,h
|
||||
char transa = std::get<1>(GetParam());
|
||||
// denotes whether matrix b is n,c,t,h
|
||||
char transb = std::get<2>(GetParam());
|
||||
// matrix size m
|
||||
gtint_t m = std::get<3>(GetParam());
|
||||
// matrix size n
|
||||
gtint_t n = std::get<4>(GetParam());
|
||||
// matrix size k
|
||||
gtint_t k = std::get<5>(GetParam());
|
||||
|
||||
gtint_t ai, aj, bi, bj, ci, cj;
|
||||
T aex, bex, cex;
|
||||
ai = std::get<6>(GetParam());
|
||||
aj = std::get<7>(GetParam());
|
||||
aex = std::get<8>(GetParam());
|
||||
|
||||
bi = std::get<9>(GetParam());
|
||||
bj = std::get<10>(GetParam());
|
||||
bex = std::get<11>(GetParam());
|
||||
|
||||
ci = std::get<12>(GetParam());
|
||||
cj = std::get<13>(GetParam());
|
||||
cex = std::get<14>(GetParam());
|
||||
|
||||
// specifies alpha value
|
||||
T alpha = std::get<15>(GetParam());
|
||||
// specifies beta value
|
||||
T beta = std::get<16>(GetParam());
|
||||
// lda, ldb, ldc increments.
|
||||
// If increments are zero, then the array size matches the matrix size.
|
||||
// If increments are nonnegative, the array size is bigger than the matrix size.
|
||||
gtint_t lda_inc = std::get<17>(GetParam());
|
||||
gtint_t ldb_inc = std::get<18>(GetParam());
|
||||
gtint_t ldc_inc = std::get<19>(GetParam());
|
||||
|
||||
// Set the threshold for the errors:
|
||||
double thresh = 10*m*n*testinghelpers::getEpsilon<T>();
|
||||
|
||||
//----------------------------------------------------------
|
||||
// Call test body using these parameters
|
||||
//----------------------------------------------------------
|
||||
test_gemm<T>( storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc,
|
||||
alpha, beta, ai, aj, aex, bi, bj, bex, ci, cj, cex, thresh );
|
||||
}
|
||||
|
||||
// Helper classes for printing the test case parameters based on the instantiator
|
||||
// These are mainly used to help with debugging, in case of failures
|
||||
|
||||
// Utility to print the test-case in case of exception value on matrices
|
||||
class ZGemmEVMatPrint {
|
||||
public:
|
||||
std::string operator()(
|
||||
testing::TestParamInfo<std::tuple<char, char, char, gtint_t, gtint_t, gtint_t, gtint_t, gtint_t, dcomplex,
|
||||
gtint_t, gtint_t, dcomplex, gtint_t, gtint_t, dcomplex, dcomplex, dcomplex,
|
||||
gtint_t, gtint_t, gtint_t>> str) const {
|
||||
char sfm = std::get<0>(str.param);
|
||||
char tsa = std::get<1>(str.param);
|
||||
char tsb = std::get<2>(str.param);
|
||||
gtint_t m = std::get<3>(str.param);
|
||||
gtint_t n = std::get<4>(str.param);
|
||||
gtint_t k = std::get<5>(str.param);
|
||||
gtint_t ai, aj, bi, bj, ci, cj;
|
||||
dcomplex aex, bex, cex;
|
||||
ai = std::get<6>(str.param);
|
||||
aj = std::get<7>(str.param);
|
||||
aex = std::get<8>(str.param);
|
||||
|
||||
bi = std::get<9>(str.param);
|
||||
bj = std::get<10>(str.param);
|
||||
bex = std::get<11>(str.param);
|
||||
|
||||
ci = std::get<12>(str.param);
|
||||
cj = std::get<13>(str.param);
|
||||
cex = std::get<14>(str.param);
|
||||
|
||||
dcomplex alpha = std::get<15>(str.param);
|
||||
dcomplex beta = std::get<16>(str.param);
|
||||
gtint_t lda_inc = std::get<17>(str.param);
|
||||
gtint_t ldb_inc = std::get<18>(str.param);
|
||||
gtint_t ldc_inc = std::get<19>(str.param);
|
||||
|
||||
#ifdef TEST_BLAS
|
||||
std::string str_name = "zgemm_";
|
||||
#elif TEST_CBLAS
|
||||
std::string str_name = "cblas_zgemm";
|
||||
#else //#elif TEST_BLIS_TYPED
|
||||
std::string str_name = "blis_zgemm";
|
||||
#endif
|
||||
str_name = str_name + "_" + sfm+sfm+sfm;
|
||||
str_name = str_name + "_" + tsa + tsb;
|
||||
str_name = str_name + "_" + std::to_string(m);
|
||||
str_name = str_name + "_" + std::to_string(n);
|
||||
str_name = str_name + "_" + std::to_string(k);
|
||||
str_name = str_name + "_A" + std::to_string(ai) + std::to_string(aj);
|
||||
str_name = str_name + "_" + testinghelpers::get_value_string(aex);
|
||||
str_name = str_name + "_B" + std::to_string(bi) + std::to_string(bj);
|
||||
str_name = str_name + "_" + testinghelpers::get_value_string(bex);
|
||||
str_name = str_name + "_C" + std::to_string(ci) + std::to_string(cj);
|
||||
str_name = str_name + "_" + testinghelpers::get_value_string(cex);
|
||||
str_name = str_name + "_a" + testinghelpers::get_value_string(alpha);
|
||||
str_name = str_name + "_b" + testinghelpers::get_value_string(beta);
|
||||
str_name = str_name + "_" + std::to_string(lda_inc);
|
||||
str_name = str_name + "_" + std::to_string(ldb_inc);
|
||||
str_name = str_name + "_" + std::to_string(ldc_inc);
|
||||
return str_name;
|
||||
}
|
||||
};
|
||||
|
||||
// Utility to print the test-case in case of exception value on matrices
|
||||
class ZGemmEVAlphaBetaPrint {
|
||||
public:
|
||||
std::string operator()(
|
||||
testing::TestParamInfo<std::tuple<char, char, char, gtint_t, gtint_t, gtint_t, gtint_t, gtint_t, dcomplex,
|
||||
gtint_t, gtint_t, dcomplex, gtint_t, gtint_t, dcomplex, dcomplex, dcomplex,
|
||||
gtint_t, gtint_t, gtint_t>> str) const {
|
||||
char sfm = std::get<0>(str.param);
|
||||
char tsa = std::get<1>(str.param);
|
||||
char tsb = std::get<2>(str.param);
|
||||
gtint_t m = std::get<3>(str.param);
|
||||
gtint_t n = std::get<4>(str.param);
|
||||
gtint_t k = std::get<5>(str.param);
|
||||
|
||||
dcomplex alpha = std::get<15>(str.param);
|
||||
dcomplex beta = std::get<16>(str.param);
|
||||
gtint_t lda_inc = std::get<17>(str.param);
|
||||
gtint_t ldb_inc = std::get<18>(str.param);
|
||||
gtint_t ldc_inc = std::get<19>(str.param);
|
||||
|
||||
#ifdef TEST_BLAS
|
||||
std::string str_name = "zgemm_";
|
||||
#elif TEST_CBLAS
|
||||
std::string str_name = "cblas_zgemm";
|
||||
#else //#elif TEST_BLIS_TYPED
|
||||
std::string str_name = "blis_zgemm";
|
||||
#endif
|
||||
str_name = str_name + "_" + sfm+sfm+sfm;
|
||||
str_name = str_name + "_" + tsa + tsb;
|
||||
str_name = str_name + "_" + std::to_string(m);
|
||||
str_name = str_name + "_" + std::to_string(n);
|
||||
str_name = str_name + "_" + std::to_string(k);
|
||||
str_name = str_name + "_a" + testinghelpers::get_value_string(alpha);
|
||||
str_name = str_name + "_b" + testinghelpers::get_value_string(beta);
|
||||
str_name = str_name + "_" + std::to_string(lda_inc);
|
||||
str_name = str_name + "_" + std::to_string(ldb_inc);
|
||||
str_name = str_name + "_" + std::to_string(ldc_inc);
|
||||
return str_name;
|
||||
}
|
||||
};
|
||||
|
||||
static double NaN = std::numeric_limits<double>::quiet_NaN();
|
||||
static double Inf = std::numeric_limits<double>::infinity();
|
||||
|
||||
// Exception value testing(on matrices)
|
||||
|
||||
/*
|
||||
For the bli_zgemm_4x4_avx2_k1_nn kernel, the main and fringe dimensions are as follows:
|
||||
For m : Main = { 4 }, fringe = { 2, 1 }
|
||||
For n : Main = { 4 }, fringe = { 2, 1 }
|
||||
|
||||
Without any changes to the BLAS layer in BLIS, the fringe case of 1 cannot be touched
|
||||
separately, since if m/n is 1, the inputs are redirected to ZGEMV.
|
||||
|
||||
*/
|
||||
|
||||
// Testing for the main loop case for m and n
|
||||
// The kernel uses 2 loads and 4 broadcasts. The exception values
|
||||
// are induced at one index individually for each of the loads.
|
||||
// They are also induced in the broadcast direction at two places.
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
bli_zgemm_4x4_avx2_k1_nn_evt_mat_main,
|
||||
ZGemmEVTTest,
|
||||
::testing::Combine(
|
||||
::testing::Values('c'
|
||||
#ifndef TEST_BLAS
|
||||
,'r'
|
||||
#endif
|
||||
), // storage format
|
||||
::testing::Values('n'), // transa
|
||||
::testing::Values('n'), // transb
|
||||
::testing::Values(gtint_t(4)), // m
|
||||
::testing::Values(gtint_t(4)), // n
|
||||
::testing::Values(gtint_t(1)), // k
|
||||
::testing::Values(gtint_t(1), gtint_t(3)), // ai
|
||||
::testing::Values(gtint_t(0)), // aj
|
||||
::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0},
|
||||
dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // aexval
|
||||
::testing::Values(gtint_t(0)), // bi
|
||||
::testing::Values(gtint_t(0), gtint_t(2)), // bj
|
||||
::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0},
|
||||
dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // bexval
|
||||
::testing::Values(gtint_t(0), gtint_t(2)), // ci
|
||||
::testing::Values(gtint_t(1), gtint_t(3)), // cj
|
||||
::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0},
|
||||
dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // cexval
|
||||
::testing::Values(dcomplex{-2.2, 3.3}), // alpha
|
||||
::testing::Values(dcomplex{1.2, -2.3}), // beta
|
||||
::testing::Values(gtint_t(0)), // increment to the leading dim of a
|
||||
::testing::Values(gtint_t(0)), // increment to the leading dim of b
|
||||
::testing::Values(gtint_t(0)) // increment to the leading dim of c
|
||||
),
|
||||
::ZGemmEVMatPrint()
|
||||
);
|
||||
|
||||
// Testing the fringe cases
|
||||
// Fringe case minimum size is 2 along both m and n.
|
||||
// Invloves only one load(AVX2 or (AVX2+SSE)). Thus,
|
||||
// the exception values are induced at the first and second indices of the
|
||||
// column vector A and row vector B.
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
bli_zgemm_4x4_avx2_k1_nn_evt_mat_fringe,
|
||||
ZGemmEVTTest,
|
||||
::testing::Combine(
|
||||
::testing::Values('c'
|
||||
#ifndef TEST_BLAS
|
||||
,'r'
|
||||
#endif
|
||||
), // storage format
|
||||
::testing::Values('n'), // transa
|
||||
::testing::Values('n'), // transb
|
||||
::testing::Values(gtint_t(2), gtint_t(3)), // m
|
||||
::testing::Values(gtint_t(2), gtint_t(3)), // n
|
||||
::testing::Values(gtint_t(1)), // k
|
||||
::testing::Values(gtint_t(0), gtint_t(1)), // ai
|
||||
::testing::Values(gtint_t(0)), // aj
|
||||
::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0},
|
||||
dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // aexval
|
||||
::testing::Values(gtint_t(0)), // bi
|
||||
::testing::Values(gtint_t(0), gtint_t(1)), // bj
|
||||
::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0},
|
||||
dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // bexval
|
||||
::testing::Values(gtint_t(0), gtint_t(1)), // ci
|
||||
::testing::Values(gtint_t(0), gtint_t(1)), // cj
|
||||
::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0},
|
||||
dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // cexval
|
||||
::testing::Values(dcomplex{-2.2, 3.3}), // alpha
|
||||
::testing::Values(dcomplex{1.2, -2.3}), // beta
|
||||
::testing::Values(gtint_t(0)), // increment to the leading dim of a
|
||||
::testing::Values(gtint_t(0)), // increment to the leading dim of b
|
||||
::testing::Values(gtint_t(0)) // increment to the leading dim of c
|
||||
),
|
||||
::ZGemmEVMatPrint()
|
||||
);
|
||||
|
||||
// Exception value testing(on alpha and beta)
|
||||
// Alpha and beta are set to exception values
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
bli_zgemm_4x4_avx2_k1_nn_evt_alphabeta,
|
||||
ZGemmEVTTest,
|
||||
::testing::Combine(
|
||||
::testing::Values('c'
|
||||
#ifndef TEST_BLAS
|
||||
,'r'
|
||||
#endif
|
||||
), // storage format
|
||||
::testing::Values('n'), // transa
|
||||
::testing::Values('n'), // transb
|
||||
::testing::Values(gtint_t(2), gtint_t(3), gtint_t(4)), // m
|
||||
::testing::Values(gtint_t(2), gtint_t(3), gtint_t(4)), // n
|
||||
::testing::Values(gtint_t(1)), // k
|
||||
::testing::Values(gtint_t(0)), // ai
|
||||
::testing::Values(gtint_t(0)), // aj
|
||||
::testing::Values(dcomplex{0.0, 0.0}),
|
||||
::testing::Values(gtint_t(0)), // bi
|
||||
::testing::Values(gtint_t(0)), // bj
|
||||
::testing::Values(dcomplex{0.0, 0.0}),
|
||||
::testing::Values(gtint_t(0)), // ci
|
||||
::testing::Values(gtint_t(0)), // cj
|
||||
::testing::Values(dcomplex{0.0, 0.0}),
|
||||
::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0},
|
||||
dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // alpha
|
||||
::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0},
|
||||
dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // beta
|
||||
::testing::Values(gtint_t(0)), // increment to the leading dim of a
|
||||
::testing::Values(gtint_t(0)), // increment to the leading dim of b
|
||||
::testing::Values(gtint_t(0)) // increment to the leading dim of c
|
||||
),
|
||||
::ZGemmEVAlphaBetaPrint()
|
||||
);
|
||||
@@ -35,7 +35,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include "test_gemm.h"
|
||||
|
||||
class ZGemmTest :
|
||||
class ZGemmAccTest :
|
||||
public ::testing::TestWithParam<std::tuple<char,
|
||||
char,
|
||||
char,
|
||||
@@ -48,7 +48,7 @@ class ZGemmTest :
|
||||
gtint_t,
|
||||
gtint_t>> {};
|
||||
|
||||
TEST_P(ZGemmTest, RandomData)
|
||||
TEST_P(ZGemmAccTest, Unit_Tester)
|
||||
{
|
||||
using T = dcomplex;
|
||||
//----------------------------------------------------------
|
||||
@@ -87,7 +87,7 @@ TEST_P(ZGemmTest, RandomData)
|
||||
test_gemm<T>( storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh );
|
||||
}
|
||||
|
||||
class ZGemmTestPrint {
|
||||
class ZGemmAccPrint {
|
||||
public:
|
||||
std::string operator()(
|
||||
testing::TestParamInfo<std::tuple<char, char, char, gtint_t, gtint_t, gtint_t, dcomplex, dcomplex, gtint_t, gtint_t, gtint_t>> str) const {
|
||||
@@ -114,12 +114,8 @@ public:
|
||||
str_name = str_name + "_" + std::to_string(m);
|
||||
str_name = str_name + "_" + std::to_string(n);
|
||||
str_name = str_name + "_" + std::to_string(k);
|
||||
std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real))));
|
||||
alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag)))));
|
||||
std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real))));
|
||||
beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag)))));
|
||||
str_name = str_name + "_a" + alpha_str;
|
||||
str_name = str_name + "_b" + beta_str;
|
||||
str_name = str_name + "_a" + testinghelpers::get_value_string(alpha);;
|
||||
str_name = str_name + "_b" + testinghelpers::get_value_string(beta);;
|
||||
str_name = str_name + "_" + std::to_string(lda_inc);
|
||||
str_name = str_name + "_" + std::to_string(ldb_inc);
|
||||
str_name = str_name + "_" + std::to_string(ldc_inc);
|
||||
@@ -127,10 +123,41 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// Unit testing for bli_zgemm_4x4_avx2_k1_nn kernel
|
||||
/* From the BLAS layer(post parameter checking), the inputs will be redirected to this kernel
|
||||
if m != 1, n !=1 and k == 1 */
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
bli_zgemm_4x4_avx2_k1_nn,
|
||||
ZGemmAccTest,
|
||||
::testing::Combine(
|
||||
::testing::Values('c'
|
||||
#ifndef TEST_BLAS
|
||||
,'r'
|
||||
#endif
|
||||
), // storage format
|
||||
::testing::Values('n'), // transa
|
||||
::testing::Values('n'), // transb
|
||||
::testing::Range(gtint_t(2), gtint_t(8), 1), // m
|
||||
::testing::Range(gtint_t(2), gtint_t(8), 1), // n
|
||||
::testing::Values(gtint_t(1)), // k
|
||||
::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0},
|
||||
dcomplex{0.0, 1.0}, dcomplex{2.1, -1.9},
|
||||
dcomplex{0.0, 0.0}), // alpha
|
||||
::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0},
|
||||
dcomplex{0.0, 1.0}, dcomplex{2.1, -1.9},
|
||||
dcomplex{0.0, 0.0}), // beta
|
||||
::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a
|
||||
::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of b
|
||||
::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of c
|
||||
),
|
||||
::ZGemmAccPrint()
|
||||
);
|
||||
|
||||
// Black box testing.
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
Blackbox,
|
||||
ZGemmTest,
|
||||
ZGemmAccTest,
|
||||
::testing::Combine(
|
||||
::testing::Values('c'
|
||||
#ifndef TEST_BLAS
|
||||
@@ -148,5 +175,5 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of b
|
||||
::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of c
|
||||
),
|
||||
::ZGemmTestPrint()
|
||||
::ZGemmAccPrint()
|
||||
);
|
||||
@@ -88,10 +88,10 @@ public:
|
||||
std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx));
|
||||
str_name = str_name + "_" + incx_str;
|
||||
str_name = str_name + "_i" + std::to_string(i);
|
||||
std::string iexval_str = getValueString(iexval);
|
||||
std::string iexval_str = testinghelpers::get_value_string(iexval);
|
||||
str_name = str_name + "_" + iexval_str;
|
||||
str_name = str_name + "_j" + std::to_string(j);
|
||||
std::string jexval_str = getValueString(jexval);
|
||||
std::string jexval_str = testinghelpers::get_value_string(jexval);
|
||||
str_name = str_name + "_" + jexval_str;
|
||||
return str_name;
|
||||
}
|
||||
|
||||
@@ -88,10 +88,10 @@ public:
|
||||
std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx));
|
||||
str_name = str_name + "_" + incx_str;
|
||||
str_name = str_name + "_i" + std::to_string(i);
|
||||
std::string iexval_str = "_Re_" + getValueString(iexval.real) + "_Im_" + getValueString(iexval.imag);
|
||||
std::string iexval_str = "_Re_" + testinghelpers::get_value_string(iexval.real) + "_Im_" + testinghelpers::get_value_string(iexval.imag);
|
||||
str_name = str_name + iexval_str;
|
||||
str_name = str_name + "_j" + std::to_string(j);
|
||||
std::string jexval_str = "_Re_" + getValueString(jexval.real) + "_Im_" + getValueString(jexval.imag);
|
||||
std::string jexval_str = "_Re_" + testinghelpers::get_value_string(jexval.real) + "_Im_" + testinghelpers::get_value_string(jexval.imag);
|
||||
str_name = str_name + jexval_str;
|
||||
return str_name;
|
||||
}
|
||||
|
||||
@@ -88,10 +88,10 @@ public:
|
||||
std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx));
|
||||
str_name = str_name + "_" + incx_str;
|
||||
str_name = str_name + "_i" + std::to_string(i);
|
||||
std::string iexval_str = "_Re_" + getValueString(iexval.real) + "_Im_" + getValueString(iexval.imag);
|
||||
std::string iexval_str = "_Re_" + testinghelpers::get_value_string(iexval.real) + "_Im_" + testinghelpers::get_value_string(iexval.imag);
|
||||
str_name = str_name + iexval_str;
|
||||
str_name = str_name + "_j" + std::to_string(j);
|
||||
std::string jexval_str = "_Re_" + getValueString(jexval.real) + "_Im_" + getValueString(jexval.imag);
|
||||
std::string jexval_str = "_Re_" + testinghelpers::get_value_string(jexval.real) + "_Im_" + testinghelpers::get_value_string(jexval.imag);
|
||||
str_name = str_name + jexval_str;
|
||||
return str_name;
|
||||
}
|
||||
|
||||
@@ -88,10 +88,10 @@ public:
|
||||
std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx));
|
||||
str_name = str_name + "_" + incx_str;
|
||||
str_name = str_name + "_i" + std::to_string(i);
|
||||
std::string iexval_str = getValueString(iexval);
|
||||
std::string iexval_str = testinghelpers::get_value_string(iexval);
|
||||
str_name = str_name + "_" + iexval_str;
|
||||
str_name = str_name + "_j" + std::to_string(j);
|
||||
std::string jexval_str = getValueString(jexval);
|
||||
std::string jexval_str = testinghelpers::get_value_string(jexval);
|
||||
str_name = str_name + "_" + jexval_str;
|
||||
return str_name;
|
||||
}
|
||||
|
||||
@@ -98,19 +98,4 @@ void test_nrm2( gtint_t n, gtint_t incx, gtint_t i, T iexval, gtint_t j = 0, T j
|
||||
//----------------------------------------------------------
|
||||
// Compare using NaN/Inf checks.
|
||||
computediff<RT>( norm, norm_ref, true );
|
||||
}
|
||||
|
||||
// Helper function that returns a string with the correct NaN/Inf printing
|
||||
// so that we can print the test names correctly from using parametrized testing.
|
||||
template<typename T>
|
||||
std::string getValueString(T exval)
|
||||
{
|
||||
std::string exval_str;
|
||||
if(std::isnan(exval))
|
||||
exval_str = "nan";
|
||||
else if(std::isinf(exval))
|
||||
exval_str = (exval > 0) ? "inf" : "minus_inf";
|
||||
else
|
||||
exval_str = ( exval > 0) ? std::to_string(int(exval)) : "minus_" + std::to_string(int(std::abs(exval)));
|
||||
return exval_str;
|
||||
}
|
||||
Reference in New Issue
Block a user