mirror of
https://github.com/amd/blis.git
synced 2026-05-13 02:25:39 +00:00
Gtestsuite Update for Pack and Compute Extension APIs
- Pack and compute are now compared against GEMM operation of reference library when MKL is not used as a reference. - For the case where both A and B are unpacked, the reference GEMM is invoked with a unit-alpha scalar. - If MKL is used as reference, then these APIs are compared against pack and compute operations of MKL. - Updated description in ref_gemm_compute.cpp to reflect this behavior. AMD-Internal: [CPUPL-4084] Change-Id: Id0521c9cad8743a7ae471a7f3c547ceb67191f86
This commit is contained in:
committed by
Arnav Sharma
parent
ffa8f584be
commit
dd1cf23090
@@ -53,11 +53,20 @@
|
||||
* Alpha and beta are scalars, and A, B and C are matrices, with A
|
||||
* an m by k matrix, B a k by n matrix and C an m by n matrix,
|
||||
* where either A or B or both may be scaled by alpha and reordered.
|
||||
*
|
||||
* NOTE:
|
||||
* - For MKL comparing against pack and compute APIs.
|
||||
* - For all other reference libraries (except MKL), we compare the result of
|
||||
* BLIS pack and compute against the GEMM operation of the reference library.
|
||||
* In case when both A & B are unpacked, we do not invoke xgemm_pack() thus,
|
||||
* not computing alpha * X operation. So to handle this case, we pass
|
||||
* unit-alpha to the reference GEMM.
|
||||
* ==========================================================================
|
||||
*/
|
||||
|
||||
namespace testinghelpers {
|
||||
|
||||
#ifdef REF_IS_MKL
|
||||
template <typename T>
|
||||
void ref_gemm_compute(char storage, char trnsa, char trnsb, char pcka, char pckb, gtint_t m, gtint_t n, gtint_t k, T alpha,
|
||||
T* ap, gtint_t lda, T* bp, gtint_t ldb, T beta, T* cp, gtint_t ldc)
|
||||
@@ -103,10 +112,10 @@ void ref_gemm_compute(char storage, char trnsa, char trnsb, char pcka, char pckb
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Error in ref_gemm.cpp: Invalid typename is passed function template.");
|
||||
throw std::runtime_error("Error in ref_gemm_compute.cpp: Invalid typename is passed function template.");
|
||||
}
|
||||
if( !ref_cblas_gemm_compute ) {
|
||||
throw std::runtime_error("Error in ref_gemm.cpp: Function pointer == 0 -- symbol not found.");
|
||||
throw std::runtime_error("Error in ref_gemm_compute.cpp: Function pointer == 0 -- symbol not found.");
|
||||
}
|
||||
|
||||
err_t err = BLIS_SUCCESS;
|
||||
@@ -161,7 +170,7 @@ void ref_gemm_compute(char storage, char trnsa, char trnsb, char pcka, char pckb
|
||||
|
||||
ref_cblas_gemm_compute( cblas_order, cblas_packed, cblas_transb,
|
||||
m, n, k, aBuffer, lda, bp, ldb, beta, cp, ldc );
|
||||
|
||||
|
||||
bli_free_user( aBuffer );
|
||||
}
|
||||
else if ( ( pckb == 'P' || pckb == 'p' ) )
|
||||
@@ -181,7 +190,7 @@ void ref_gemm_compute(char storage, char trnsa, char trnsb, char pcka, char pckb
|
||||
|
||||
ref_cblas_gemm_compute( cblas_order, cblas_transa, cblas_packed,
|
||||
m, n, k, ap, lda, bBuffer, ldb, beta, cp, ldc );
|
||||
|
||||
|
||||
bli_free_user( bBuffer );
|
||||
}
|
||||
else
|
||||
@@ -190,6 +199,57 @@ void ref_gemm_compute(char storage, char trnsa, char trnsb, char pcka, char pckb
|
||||
m, n, k, ap, lda, bp, ldb, beta, cp, ldc );
|
||||
}
|
||||
}
|
||||
#else
|
||||
template <typename T>
|
||||
void ref_gemm_compute(char storage, char trnsa, char trnsb, char pcka, char pckb, gtint_t m, gtint_t n, gtint_t k, T alpha,
|
||||
T* ap, gtint_t lda, T* bp, gtint_t ldb, T beta, T* cp, gtint_t ldc)
|
||||
{
|
||||
// throw std::runtime_error("Error in ref_gemm_compute.cpp: Reference is only defined for MKL. Please use MKL as reference library.");
|
||||
enum CBLAS_ORDER cblas_order;
|
||||
enum CBLAS_TRANSPOSE cblas_transa;
|
||||
enum CBLAS_TRANSPOSE cblas_transb;
|
||||
|
||||
char_to_cblas_order( storage, &cblas_order );
|
||||
char_to_cblas_trans( trnsa, &cblas_transa );
|
||||
char_to_cblas_trans( trnsb, &cblas_transb );
|
||||
|
||||
using scalar_t = std::conditional_t<testinghelpers::type_info<T>::is_complex, T&, T>;
|
||||
typedef void (*Fptr_ref_cblas_gemm)( const CBLAS_ORDER, const CBLAS_TRANSPOSE, const CBLAS_TRANSPOSE,
|
||||
const f77_int, const f77_int, const f77_int, const scalar_t, const T*, f77_int,
|
||||
const T*, f77_int, const scalar_t, T*, f77_int);
|
||||
Fptr_ref_cblas_gemm ref_cblas_gemm;
|
||||
|
||||
// Call C function
|
||||
/* Check the typename T passed to this function template and call respective function.*/
|
||||
if (typeid(T) == typeid(float))
|
||||
{
|
||||
ref_cblas_gemm = (Fptr_ref_cblas_gemm)refCBLASModule.loadSymbol("cblas_sgemm");
|
||||
}
|
||||
else if (typeid(T) == typeid(double))
|
||||
{
|
||||
ref_cblas_gemm = (Fptr_ref_cblas_gemm)refCBLASModule.loadSymbol("cblas_dgemm");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Error in ref_gemm.cpp: Invalid typename is passed function template.");
|
||||
}
|
||||
if( !ref_cblas_gemm ) {
|
||||
throw std::runtime_error("Error in ref_gemm.cpp: Function pointer == 0 -- symbol not found.");
|
||||
}
|
||||
|
||||
if ( ( pcka == 'U' or pcka == 'u' ) && ( pckb == 'U' or pckb == 'u' ) )
|
||||
{
|
||||
T unit_alpha = 1.0;
|
||||
ref_cblas_gemm( cblas_order, cblas_transa, cblas_transb,
|
||||
m, n, k, unit_alpha, ap, lda, bp, ldb, beta, cp, ldc );
|
||||
}
|
||||
else
|
||||
{
|
||||
ref_cblas_gemm( cblas_order, cblas_transa, cblas_transb,
|
||||
m, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc );
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// Explicit template instantiations
|
||||
template void ref_gemm_compute<float>(char, char, char, char, char, gtint_t, gtint_t, gtint_t, float,
|
||||
|
||||
Reference in New Issue
Block a user