From dd1cf230905e1bb1f3d99a31925d09db8490b2f3 Mon Sep 17 00:00:00 2001 From: Arnav Sharma Date: Thu, 26 Oct 2023 14:35:05 +0530 Subject: [PATCH] Gtestsuite Update for Pack and Compute Extension APIs - Pack and compute are now compared against GEMM operation of reference library when MKL is not used as a reference. - For the case where both A and B are unpacked, the reference GEMM is invoked with a unit-alpha scalar. - If MKL is used as reference, then these APIs are compared against pack and compute operations of MKL. - Updated description in ref_gemm_compute.cpp to reflect this behavior. AMD-Internal: [CPUPL-4084] Change-Id: Id0521c9cad8743a7ae471a7f3c547ceb67191f86 --- .../src/level3/ref_gemm_compute.cpp | 68 +++++++++++++++++-- 1 file changed, 64 insertions(+), 4 deletions(-) diff --git a/gtestsuite/testinghelpers/src/level3/ref_gemm_compute.cpp b/gtestsuite/testinghelpers/src/level3/ref_gemm_compute.cpp index 2b15ffea2..dd069fcd8 100644 --- a/gtestsuite/testinghelpers/src/level3/ref_gemm_compute.cpp +++ b/gtestsuite/testinghelpers/src/level3/ref_gemm_compute.cpp @@ -53,11 +53,20 @@ * Alpha and beta are scalars, and A, B and C are matrices, with A * an m by k matrix, B a k by n matrix and C an m by n matrix, * where either A or B or both may be scaled by alpha and reordered. + * + * NOTE: + * - For MKL comparing against pack and compute APIs. + * - For all other reference libraries (except MKL), we compare the result of + * BLIS pack and compute against the GEMM operation of the reference library. + * In case when both A & B are unpacked, we do not invoke xgemm_pack() thus, + * not computing alpha * X operation. So to handle this case, we pass + * unit-alpha to the reference GEMM. * ========================================================================== */ namespace testinghelpers { +#ifdef REF_IS_MKL template void ref_gemm_compute(char storage, char trnsa, char trnsb, char pcka, char pckb, gtint_t m, gtint_t n, gtint_t k, T alpha, T* ap, gtint_t lda, T* bp, gtint_t ldb, T beta, T* cp, gtint_t ldc) @@ -103,10 +112,10 @@ void ref_gemm_compute(char storage, char trnsa, char trnsb, char pcka, char pckb } else { - throw std::runtime_error("Error in ref_gemm.cpp: Invalid typename is passed function template."); + throw std::runtime_error("Error in ref_gemm_compute.cpp: Invalid typename is passed function template."); } if( !ref_cblas_gemm_compute ) { - throw std::runtime_error("Error in ref_gemm.cpp: Function pointer == 0 -- symbol not found."); + throw std::runtime_error("Error in ref_gemm_compute.cpp: Function pointer == 0 -- symbol not found."); } err_t err = BLIS_SUCCESS; @@ -161,7 +170,7 @@ void ref_gemm_compute(char storage, char trnsa, char trnsb, char pcka, char pckb ref_cblas_gemm_compute( cblas_order, cblas_packed, cblas_transb, m, n, k, aBuffer, lda, bp, ldb, beta, cp, ldc ); - + bli_free_user( aBuffer ); } else if ( ( pckb == 'P' || pckb == 'p' ) ) @@ -181,7 +190,7 @@ void ref_gemm_compute(char storage, char trnsa, char trnsb, char pcka, char pckb ref_cblas_gemm_compute( cblas_order, cblas_transa, cblas_packed, m, n, k, ap, lda, bBuffer, ldb, beta, cp, ldc ); - + bli_free_user( bBuffer ); } else @@ -190,6 +199,57 @@ void ref_gemm_compute(char storage, char trnsa, char trnsb, char pcka, char pckb m, n, k, ap, lda, bp, ldb, beta, cp, ldc ); } } +#else +template +void ref_gemm_compute(char storage, char trnsa, char trnsb, char pcka, char pckb, gtint_t m, gtint_t n, gtint_t k, T alpha, + T* ap, gtint_t lda, T* bp, gtint_t ldb, T beta, T* cp, gtint_t ldc) +{ + // throw std::runtime_error("Error in ref_gemm_compute.cpp: Reference is only defined for MKL. Please use MKL as reference library."); + enum CBLAS_ORDER cblas_order; + enum CBLAS_TRANSPOSE cblas_transa; + enum CBLAS_TRANSPOSE cblas_transb; + + char_to_cblas_order( storage, &cblas_order ); + char_to_cblas_trans( trnsa, &cblas_transa ); + char_to_cblas_trans( trnsb, &cblas_transb ); + + using scalar_t = std::conditional_t::is_complex, T&, T>; + typedef void (*Fptr_ref_cblas_gemm)( const CBLAS_ORDER, const CBLAS_TRANSPOSE, const CBLAS_TRANSPOSE, + const f77_int, const f77_int, const f77_int, const scalar_t, const T*, f77_int, + const T*, f77_int, const scalar_t, T*, f77_int); + Fptr_ref_cblas_gemm ref_cblas_gemm; + + // Call C function + /* Check the typename T passed to this function template and call respective function.*/ + if (typeid(T) == typeid(float)) + { + ref_cblas_gemm = (Fptr_ref_cblas_gemm)refCBLASModule.loadSymbol("cblas_sgemm"); + } + else if (typeid(T) == typeid(double)) + { + ref_cblas_gemm = (Fptr_ref_cblas_gemm)refCBLASModule.loadSymbol("cblas_dgemm"); + } + else + { + throw std::runtime_error("Error in ref_gemm.cpp: Invalid typename is passed function template."); + } + if( !ref_cblas_gemm ) { + throw std::runtime_error("Error in ref_gemm.cpp: Function pointer == 0 -- symbol not found."); + } + + if ( ( pcka == 'U' or pcka == 'u' ) && ( pckb == 'U' or pckb == 'u' ) ) + { + T unit_alpha = 1.0; + ref_cblas_gemm( cblas_order, cblas_transa, cblas_transb, + m, n, k, unit_alpha, ap, lda, bp, ldb, beta, cp, ldc ); + } + else + { + ref_cblas_gemm( cblas_order, cblas_transa, cblas_transb, + m, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); + } +} +#endif // Explicit template instantiations template void ref_gemm_compute(char, char, char, char, char, gtint_t, gtint_t, gtint_t, float,