Gtestsuite Update for Pack and Compute Extension APIs

- Pack and compute are now compared against GEMM operation of reference
  library when MKL is not used as a reference.
- For the case where both A and B are unpacked, the reference GEMM is
  invoked with a unit-alpha scalar.
- If MKL is used as reference, then these APIs are compared against pack
  and compute operations of MKL.
- Updated description in ref_gemm_compute.cpp to reflect this behavior.

AMD-Internal: [CPUPL-4084]
Change-Id: Id0521c9cad8743a7ae471a7f3c547ceb67191f86
This commit is contained in:
Arnav Sharma
2023-10-26 14:35:05 +05:30
committed by Arnav Sharma
parent ffa8f584be
commit dd1cf23090

View File

@@ -53,11 +53,20 @@
* Alpha and beta are scalars, and A, B and C are matrices, with A
* an m by k matrix, B a k by n matrix and C an m by n matrix,
* where either A or B or both may be scaled by alpha and reordered.
*
* NOTE:
* - For MKL comparing against pack and compute APIs.
* - For all other reference libraries (except MKL), we compare the result of
* BLIS pack and compute against the GEMM operation of the reference library.
* In case when both A & B are unpacked, we do not invoke xgemm_pack() thus,
* not computing alpha * X operation. So to handle this case, we pass
* unit-alpha to the reference GEMM.
* ==========================================================================
*/
namespace testinghelpers {
#ifdef REF_IS_MKL
template <typename T>
void ref_gemm_compute(char storage, char trnsa, char trnsb, char pcka, char pckb, gtint_t m, gtint_t n, gtint_t k, T alpha,
T* ap, gtint_t lda, T* bp, gtint_t ldb, T beta, T* cp, gtint_t ldc)
@@ -103,10 +112,10 @@ void ref_gemm_compute(char storage, char trnsa, char trnsb, char pcka, char pckb
}
else
{
throw std::runtime_error("Error in ref_gemm.cpp: Invalid typename is passed function template.");
throw std::runtime_error("Error in ref_gemm_compute.cpp: Invalid typename is passed function template.");
}
if( !ref_cblas_gemm_compute ) {
throw std::runtime_error("Error in ref_gemm.cpp: Function pointer == 0 -- symbol not found.");
throw std::runtime_error("Error in ref_gemm_compute.cpp: Function pointer == 0 -- symbol not found.");
}
err_t err = BLIS_SUCCESS;
@@ -161,7 +170,7 @@ void ref_gemm_compute(char storage, char trnsa, char trnsb, char pcka, char pckb
ref_cblas_gemm_compute( cblas_order, cblas_packed, cblas_transb,
m, n, k, aBuffer, lda, bp, ldb, beta, cp, ldc );
bli_free_user( aBuffer );
}
else if ( ( pckb == 'P' || pckb == 'p' ) )
@@ -181,7 +190,7 @@ void ref_gemm_compute(char storage, char trnsa, char trnsb, char pcka, char pckb
ref_cblas_gemm_compute( cblas_order, cblas_transa, cblas_packed,
m, n, k, ap, lda, bBuffer, ldb, beta, cp, ldc );
bli_free_user( bBuffer );
}
else
@@ -190,6 +199,57 @@ void ref_gemm_compute(char storage, char trnsa, char trnsb, char pcka, char pckb
m, n, k, ap, lda, bp, ldb, beta, cp, ldc );
}
}
#else
template <typename T>
void ref_gemm_compute(char storage, char trnsa, char trnsb, char pcka, char pckb, gtint_t m, gtint_t n, gtint_t k, T alpha,
T* ap, gtint_t lda, T* bp, gtint_t ldb, T beta, T* cp, gtint_t ldc)
{
// throw std::runtime_error("Error in ref_gemm_compute.cpp: Reference is only defined for MKL. Please use MKL as reference library.");
enum CBLAS_ORDER cblas_order;
enum CBLAS_TRANSPOSE cblas_transa;
enum CBLAS_TRANSPOSE cblas_transb;
char_to_cblas_order( storage, &cblas_order );
char_to_cblas_trans( trnsa, &cblas_transa );
char_to_cblas_trans( trnsb, &cblas_transb );
using scalar_t = std::conditional_t<testinghelpers::type_info<T>::is_complex, T&, T>;
typedef void (*Fptr_ref_cblas_gemm)( const CBLAS_ORDER, const CBLAS_TRANSPOSE, const CBLAS_TRANSPOSE,
const f77_int, const f77_int, const f77_int, const scalar_t, const T*, f77_int,
const T*, f77_int, const scalar_t, T*, f77_int);
Fptr_ref_cblas_gemm ref_cblas_gemm;
// Call C function
/* Check the typename T passed to this function template and call respective function.*/
if (typeid(T) == typeid(float))
{
ref_cblas_gemm = (Fptr_ref_cblas_gemm)refCBLASModule.loadSymbol("cblas_sgemm");
}
else if (typeid(T) == typeid(double))
{
ref_cblas_gemm = (Fptr_ref_cblas_gemm)refCBLASModule.loadSymbol("cblas_dgemm");
}
else
{
throw std::runtime_error("Error in ref_gemm.cpp: Invalid typename is passed function template.");
}
if( !ref_cblas_gemm ) {
throw std::runtime_error("Error in ref_gemm.cpp: Function pointer == 0 -- symbol not found.");
}
if ( ( pcka == 'U' or pcka == 'u' ) && ( pckb == 'U' or pckb == 'u' ) )
{
T unit_alpha = 1.0;
ref_cblas_gemm( cblas_order, cblas_transa, cblas_transb,
m, n, k, unit_alpha, ap, lda, bp, ldb, beta, cp, ldc );
}
else
{
ref_cblas_gemm( cblas_order, cblas_transa, cblas_transb,
m, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc );
}
}
#endif
// Explicit template instantiations
template void ref_gemm_compute<float>(char, char, char, char, char, gtint_t, gtint_t, gtint_t, float,