AVX512 optimizations for CGEMM(rank-1 kernel)

- Implemented an AVX512 rank-1 kernel that handles
  column-major storage of A, B and C (without
  transposition) when k = 1.

- This kernel is single-threaded, and acts as a direct
  call from the BLAS layer for its compatible inputs.

- Defined custom BLAS and BLIS_IMPL layers for CGEMM
  (instead of using the macro definition), in order to
  integrate the call to this kernel at runtime (based on
  the corresponding architecture and input constraints).

- Added unit-tests for functional and memory testing of the
  kernel.

- Updated the ZEN5 context to include the AVX512 CGEMM
  SUP kernels, with its cache-blocking parameters.

AMD-Internal: [CPUPL-6498]
Change-Id: I42a66c424325bd117ceb38970726a05e2896a46b
This commit is contained in:
Vignesh Balasubramanian
2025-03-04 01:02:40 +05:30
parent 07df9f471e
commit c4b84601da
9 changed files with 2522 additions and 16 deletions

View File

@@ -319,12 +319,12 @@ void bli_cntx_init_zen5( cntx_t* cntx )
BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64n_avx512, TRUE,
BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64n_avx512, TRUE,
BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE,
BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE,
BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE,
BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE,
BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE,
BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE,
BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_cv_zen4_asm_24x4m, FALSE,
BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_cv_zen4_asm_24x4m, FALSE,
BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_cv_zen4_asm_24x4m, FALSE,
BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_cv_zen4_asm_24x4m, FALSE,
BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_cv_zen4_asm_24x4m, FALSE,
BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_cv_zen4_asm_24x4m, FALSE,
BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE,
BLIS_RRC, BLIS_DCOMPLEX, bli_zgemmsup_cd_zen4_asm_12x4m, FALSE,
@@ -340,11 +340,11 @@ void bli_cntx_init_zen5( cntx_t* cntx )
// Initialize level-3 sup blocksize objects with architecture-specific
// values.
// s d c z
bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 24, 3, 12 );
bli_blksz_init_easy( &blkszs[ BLIS_NR ], 64, 8, 8, 4 );
bli_blksz_init_easy( &blkszs[ BLIS_MC ], 192, 144, 72, 48 );
bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 384, 128, 64 );
bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8064, 4032, 2040, 1020 );
bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 24, 24, 12 );
bli_blksz_init_easy( &blkszs[ BLIS_NR ], 64, 8, 4, 4 );
bli_blksz_init_easy( &blkszs[ BLIS_MC ], 192, 144, 120, 48 );
bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 384, 512, 64 );
bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8064, 4032, 4080, 1020 );
// Update the context with the current architecture's register and cache
// blocksizes for small/unpacked level-3 problems.

View File

@@ -1435,7 +1435,354 @@ void zgemm_
#endif
}
#endif
INSERT_GENTFUNC_BLAS_SC( gemm, gemm )
/*
   cgemm_blis_impl: single-precision complex GEMM behind the BLAS interface.

   All operands are described with a unit row stride and a column stride
   equal to the corresponding leading dimension (lda/ldb/ldc).

   Dispatch order, as implemented below:
     1. BLAS parameter checking, then a quick return when m == 0, n == 0,
        or ((alpha == 0 or k == 0) and beta == 1).
     2. alpha == 0: C is scaled by beta through scalm.
     3. n == 1 or m == 1: the cgemv unf_var1/unf_var2 kernels are called.
     4. No AVX2+FMA3 support: the expert object-API gemm is called.
     5. BLIS_KERNELS_ZEN4 builds only, arch id ZEN4 or ZEN5, k == 1 and
        neither A nor B transposed: bli_cgemm_32x4_avx512_k1_nn is called.
     6. Otherwise bli_gemmsup is attempted; if it does not report
        BLIS_SUCCESS, gemm_front is invoked.

   Every return path logs the GEMM stats, closes the DTL trace opened at
   entry, and finalizes BLIS.
*/
void cgemm_blis_impl
     (
       const f77_char* transa,
       const f77_char* transb,
       const f77_int*  m,
       const f77_int*  n,
       const f77_int*  k,
       const scomplex* alpha,
       const scomplex* a, const f77_int* lda,
       const scomplex* b, const f77_int* ldb,
       const scomplex* beta,
             scomplex* c, const f77_int* ldc
     )
{
    trans_t blis_transa;
    trans_t blis_transb;
    dim_t   m0, n0, k0;

    /* Initialize BLIS. */
    bli_init_auto();

    AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1)
    AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(c), *transa, *transb, *m, *n, *k,
                             (void*)alpha, *lda, *ldb, (void*)beta, *ldc);

    /* Perform BLAS parameter checking. */
    PASTEBLACHK(gemm)
    (
      MKSTR(c),
      MKSTR(gemm),
      transa,
      transb,
      m,
      n,
      k,
      lda,
      ldb,
      ldc
    );

    /* Quick return if possible. */
    if ( *m == 0 || *n == 0 || (( PASTEMAC(c,eq0)( *alpha ) || *k == 0)
         && PASTEMAC(c,eq1)( *beta ) ))
    {
        AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(c), *m, *n, *k);
        AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
        /* Finalize BLIS. */
        bli_finalize_auto();
        return;
    }

    /* If alpha is zero scale C by beta and return early. */
    if( PASTEMAC(c,eq0)( *alpha ))
    {
        bli_convert_blas_dim1(*m, m0);
        bli_convert_blas_dim1(*n, n0);
        const inc_t rs_c = 1;
        const inc_t cs_c = *ldc;

        PASTEMAC2(c,scalm,_ex)( BLIS_NO_CONJUGATE,
                                0,
                                BLIS_NONUNIT_DIAG,
                                BLIS_DENSE,
                                m0,
                                n0,
                                (scomplex*) beta,
                                (scomplex*) c, rs_c, cs_c,
                                NULL, NULL
                              );

        AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(c), *m, *n, *k);
        AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
        /* Finalize BLIS. */
        bli_finalize_auto();
        return;
    }

    /* Map BLAS chars to their corresponding BLIS enumerated type value. */
    bli_param_map_netlib_to_blis_trans( *transa, &blis_transa );
    bli_param_map_netlib_to_blis_trans( *transb, &blis_transb );

    /* Typecast BLAS integers to BLIS integers. */
    bli_convert_blas_dim1( *m, m0 );
    bli_convert_blas_dim1( *n, n0 );
    bli_convert_blas_dim1( *k, k0 );

    /* Set the row and column strides of the matrix operands. */
    const inc_t rs_a = 1;
    const inc_t cs_a = *lda;
    const inc_t rs_b = 1;
    const inc_t cs_b = *ldb;
    const inc_t rs_c = 1;
    const inc_t cs_c = *ldc;

    /* Call GEMV when m == 1 or n == 1, passing a null pointer
       ((void *)0) as the context. */
    if (n0 == 1)
    {
        if (bli_is_notrans(blis_transa))
        {
            bli_cgemv_unf_var2
            (
              blis_transa,
              bli_extract_conj(blis_transb),
              m0, k0,
              (scomplex *)alpha,
              (scomplex *)a, rs_a, cs_a,
              (scomplex *)b, bli_is_notrans(blis_transb) ? rs_b : cs_b,
              (scomplex *)beta,
              c, rs_c,
              ((void *)0)
            );
        }
        else
        {
            bli_cgemv_unf_var1
            (
              blis_transa,
              bli_extract_conj(blis_transb),
              k0, m0,
              (scomplex *)alpha,
              (scomplex *)a, rs_a, cs_a,
              (scomplex *)b, bli_is_notrans(blis_transb) ? rs_b : cs_b,
              (scomplex *)beta,
              c, rs_c,
              ((void *)0)
            );
        }

        AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(c), *m, *n, *k);
        /* Close the trace opened at entry, as every other return path does. */
        AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
        /* Finalize BLIS. */
        bli_finalize_auto();
        return;
    }
    else if (m0 == 1)
    {
        /* C is a single row here: B is passed as the GEMV matrix with its
           strides swapped, and the row of A acts as the vector operand. */
        if (bli_is_notrans(blis_transb))
        {
            bli_cgemv_unf_var1
            (
              blis_transb,
              bli_extract_conj(blis_transa),
              n0, k0,
              (scomplex *)alpha,
              (scomplex *)b, cs_b, rs_b,
              (scomplex *)a, bli_is_notrans(blis_transa) ? cs_a : rs_a,
              (scomplex *)beta,
              c, cs_c,
              ((void *)0)
            );
        }
        else
        {
            bli_cgemv_unf_var2
            (
              blis_transb,
              bli_extract_conj(blis_transa),
              k0, n0,
              (scomplex *)alpha,
              (scomplex *)b, cs_b, rs_b,
              (scomplex *)a, bli_is_notrans(blis_transa) ? cs_a : rs_a,
              (scomplex *)beta,
              c, cs_c,
              ((void *)0)
            );
        }

        AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(c), *m, *n, *k);
        /* Close the trace opened at entry, as every other return path does. */
        AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
        /* Finalize BLIS. */
        bli_finalize_auto();
        return;
    }

    // This function is invoked on all architectures including 'generic'.
    // Non-AVX2+FMA3 platforms will use the kernels derived from the context.
    if (bli_cpuid_is_avx2fma3_supported() == FALSE)
    {
        // This code is duplicated below, however we don't want to move it out of
        // this IF block as we want to avoid object initialization until required.
        // Also this is temporary fix which will be replaced later.
        const num_t dt     = BLIS_SCOMPLEX;
        obj_t       alphao = BLIS_OBJECT_INITIALIZER_1X1;
        obj_t       ao     = BLIS_OBJECT_INITIALIZER;
        obj_t       bo     = BLIS_OBJECT_INITIALIZER;
        obj_t       betao  = BLIS_OBJECT_INITIALIZER_1X1;
        obj_t       co     = BLIS_OBJECT_INITIALIZER;

        dim_t       m0_a, n0_a;
        dim_t       m0_b, n0_b;

        bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a );
        bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b );

        bli_obj_init_finish_1x1( dt, (scomplex*)alpha, &alphao );
        bli_obj_init_finish_1x1( dt, (scomplex*)beta,  &betao  );

        bli_obj_init_finish( dt, m0_a, n0_a, (scomplex*)a, rs_a, cs_a, &ao );
        bli_obj_init_finish( dt, m0_b, n0_b, (scomplex*)b, rs_b, cs_b, &bo );
        bli_obj_init_finish( dt, m0,   n0,   (scomplex*)c, rs_c, cs_c, &co );

        bli_obj_set_conjtrans( blis_transa, &ao );
        bli_obj_set_conjtrans( blis_transb, &bo );

        // Will call parallelized cgemm code - sup & native
        PASTEMAC(gemm, BLIS_OAPI_EX_SUF)
        (
          &alphao,
          &ao,
          &bo,
          &betao,
          &co,
          NULL,
          NULL
        );

        AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(c), *m, *n, *k);
        AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
        /* Finalize BLIS. */
        bli_finalize_auto();
        return;
    }

    /*
      Invoking the API for input sizes with k = 1.
      - The API is single-threaded.
      - The input constraints are that k should be 1, and transa and transb
        should be N and N respectively.
    */
#if defined(BLIS_KERNELS_ZEN4)
    if( ( k0 == 1 ) && bli_is_notrans( blis_transa ) && bli_is_notrans( blis_transb ) )
    {
        /* The kernel is only dispatched for the ZEN4 and ZEN5 arch ids. */
        arch_t arch_id = bli_arch_query_id();
        if ( ( arch_id == BLIS_ARCH_ZEN4 ) || ( arch_id == BLIS_ARCH_ZEN5 ) )
        {
            bli_cgemm_32x4_avx512_k1_nn
            (
              m0, n0, k0,
              (scomplex*)alpha,
              (scomplex*)a, *lda,
              (scomplex*)b, *ldb,
              (scomplex*)beta,
              c, *ldc
            );

            AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(c), *m, *n, *k);
            AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
            /* Finalize BLIS */
            bli_finalize_auto();
            return;
        }
    }
#endif

    const num_t dt     = BLIS_SCOMPLEX;
    obj_t       alphao = BLIS_OBJECT_INITIALIZER_1X1;
    obj_t       ao     = BLIS_OBJECT_INITIALIZER;
    obj_t       bo     = BLIS_OBJECT_INITIALIZER;
    obj_t       betao  = BLIS_OBJECT_INITIALIZER_1X1;
    obj_t       co     = BLIS_OBJECT_INITIALIZER;

    dim_t       m0_a, n0_a;
    dim_t       m0_b, n0_b;

    bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a );
    bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b );

    bli_obj_init_finish_1x1( dt, (scomplex*)alpha, &alphao );
    bli_obj_init_finish_1x1( dt, (scomplex*)beta,  &betao  );

    bli_obj_init_finish( dt, m0_a, n0_a, (scomplex*)a, rs_a, cs_a, &ao );
    bli_obj_init_finish( dt, m0_b, n0_b, (scomplex*)b, rs_b, cs_b, &bo );
    bli_obj_init_finish( dt, m0,   n0,   (scomplex*)c, rs_c, cs_c, &co );

    bli_obj_set_conjtrans( blis_transa, &ao );
    bli_obj_set_conjtrans( blis_transb, &bo );

    /* Try the sup path first; anything other than BLIS_SUCCESS falls
       through to the front-end call below. */
    err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL);
    if (status == BLIS_SUCCESS)
    {
        AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(c), *m, *n, *k);
        AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
        /* Finalize BLIS. */
        bli_finalize_auto();
        return;
    }

    // fall back on native path when cgemm is not handled in sup path.
    //bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL);

    /* Default to using native execution. */
    ind_t im = BLIS_NAT;

    /* As each matrix operand has a complex storage datatype, try to get an
       induced method (if one is available and enabled). NOTE: Allowing
       precisions to vary while using 1m, which is what we do here, is unique
       to gemm; other level-3 operations use 1m only if all storage datatypes
       are equal (and they ignore the computation precision). */

    /* Find the highest priority induced method that is both enabled and
       available for the current operation. (If an induced method is
       available but not enabled, or simply unavailable, BLIS_NAT will
       be returned here.) */
    im = bli_gemmind_find_avail( dt );

    /* Obtain a valid context from the gks using the induced
       method id determined above. */
    cntx_t* cntx = bli_gks_query_ind_cntx( im, dt );

    rntm_t rntm_l;
    bli_rntm_init_from_global( &rntm_l );

    /* Invoke the operation's front-end and request the default control tree. */
    PASTEMAC(gemm,_front)( &alphao, &ao, &bo, &betao, &co, cntx, &rntm_l, NULL );

    AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(c), *m, *n, *k);
    AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
    /* Finalize BLIS. */
    bli_finalize_auto();
}// end of cgemm_blis_impl
#ifdef BLIS_ENABLE_BLAS
/*
   cgemm_: BLAS symbol for CGEMM. Forwards all arguments unchanged to
   cgemm_blis_impl and, in BLIS_KERNELS_ZEN4 builds, calls bli_zero_zmm()
   afterwards when the queried arch id is ZEN4 or ZEN5.
*/
void cgemm_
     (
       const f77_char* transa,
       const f77_char* transb,
       const f77_int*  m,
       const f77_int*  n,
       const f77_int*  k,
       const scomplex* alpha,
       const scomplex* a, const f77_int* lda,
       const scomplex* b, const f77_int* ldb,
       const scomplex* beta,
             scomplex* c, const f77_int* ldc
     )
{
    cgemm_blis_impl(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);

#if defined(BLIS_KERNELS_ZEN4)
    /* NOTE(review): bli_zero_zmm presumably clears the ZMM registers after
       AVX512 work - its definition is not visible here; confirm. */
    const arch_t arch_id = bli_arch_query_id();
    const bool_t is_zen4_or_zen5 =
        ( arch_id == BLIS_ARCH_ZEN4 ) || ( arch_id == BLIS_ARCH_ZEN5 );

    if ( is_zen4_or_zen5 )
        bli_zero_zmm();
#endif
}
#endif
INSERT_GENTFUNC_BLAS_S( gemm, gemm )
void dzgemm_blis_impl
(

View File

@@ -71,6 +71,10 @@ GENTFUNC( dcomplex, z, blasname, blisname )
\
GENTFUNC( scomplex, c, blasname, blisname )
// -- Variant that instantiates GENTFUNC for single-precision real (float, s) only --
#define INSERT_GENTFUNC_BLAS_S( blasname, blisname ) \
\
GENTFUNC( float, s, blasname, blisname )
// -- Basic one-operand macro with real domain only --

View File

@@ -40,6 +40,7 @@
#ifdef AOCL_DEV
#define K_bli_cgemm_32x4_avx512_k1_nn 1
#define K_bli_cgemmsup_cv_zen4_asm_24x4m 1
#define K_bli_cgemmsup_cv_zen4_asm_24x3m 1
#define K_bli_cgemmsup_cv_zen4_asm_24x2m 1

View File

@@ -99,8 +99,10 @@ TEST_P( cgemmGeneric, API )
thresh = adj*testinghelpers::getEpsilon<T>();
}
else
thresh = (3*k+1)*testinghelpers::getEpsilon<T>();
{
double adj = 4.0;
thresh = adj*(3*k+1)*testinghelpers::getEpsilon<T>();
}
//----------------------------------------------------------
// Call test body using these parameters
//----------------------------------------------------------
@@ -284,3 +286,30 @@ INSTANTIATE_TEST_SUITE_P(
),
::gemmGenericPrint<scomplex>()
);
// Instantiation of cgemmGeneric for the k = 1 case with neither A nor B
// transposed. Note that ::testing::Range excludes its end value, so m spans
// 2..62 and n spans 2..8. Row-major storage ('r') is added only when
// TEST_BLAS_LIKE is not defined.
INSTANTIATE_TEST_SUITE_P(
K_1,
cgemmGeneric,
::testing::Combine(
::testing::Values('c'
#ifndef TEST_BLAS_LIKE
,'r'
#endif
), // storage format
::testing::Values('n'), // transa
::testing::Values('n'), // transb
::testing::Range(gtint_t(2), gtint_t(63), 1), // m
::testing::Range(gtint_t(2), gtint_t(9), 1), // n
::testing::Values(gtint_t(1)), // k
::testing::Values(scomplex{1.0, 0.0}, scomplex{-1.0, 0.0},
scomplex{0.0, 1.0}, scomplex{2.1, -1.9},
scomplex{0.0, 0.0}), // alpha
::testing::Values(scomplex{1.0, 0.0}, scomplex{-1.0, 0.0},
scomplex{0.0, 1.0}, scomplex{2.1, -1.9},
scomplex{0.0, 0.0}), // beta
::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of a
::testing::Values(gtint_t(0), gtint_t(9)), // increment to the leading dim of b
::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of c
),
::gemmGenericPrint<scomplex>()
);

View File

@@ -1572,3 +1572,94 @@ INSTANTIATE_TEST_SUITE_P(
);
#endif
#endif
// Function pointer specific to cgemm kernel that handles
// special case where k=1.
typedef void (*cgemm_k1_kernel)
(
dim_t m,
dim_t n,
dim_t k,
scomplex* alpha,
scomplex* a, const inc_t lda,
scomplex* b, const inc_t ldb,
scomplex* beta,
scomplex* c, const inc_t ldc
);
// AOCL-BLAS has a dedicated kernel that separately handles the k=1 case
// for CGEMM. Thus, we need to define a test-fixture class for testing
// such kernels directly through a function pointer. The tuple carries,
// in order: alpha, beta, storage, m, n, the kernel pointer and a flag
// selecting memory testing.
class cgemmUkrk1 :
public ::testing::TestWithParam<std::tuple<scomplex, // alpha
scomplex, // beta
char, // storage
gtint_t, // m
gtint_t, // n
cgemm_k1_kernel, // kernel-pointer type
bool>> {}; // is_mem_test
// The only visible instantiation is guarded by build macros, so the suite
// is allowed to stay uninstantiated.
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(cgemmUkrk1);
// Unpacks one parameter tuple and forwards it to the k=1 kernel test driver.
TEST_P(cgemmUkrk1, FunctionalTest)
{
    using T = scomplex;

    // This fixture only targets the k = 1 case, so k is fixed here
    // instead of being part of the parameter tuple.
    gtint_t k_dim = 1;

    const auto& params = GetParam();

    T               alpha_val   = std::get<0>(params); // alpha
    T               beta_val    = std::get<1>(params); // beta
    char            stor        = std::get<2>(params); // storage of all matrix operands
    gtint_t         m_dim       = std::get<3>(params); // m
    gtint_t         n_dim       = std::get<4>(params); // n
    cgemm_k1_kernel ukr         = std::get<5>(params); // kernel address
    bool            is_mem_test = std::get<6>(params); // is_mem_test

    // Call to the testing interface(specific to k=1 cases)
    test_gemmk1_ukr(ukr, m_dim, n_dim, k_dim, stor, alpha_val, beta_val, is_mem_test);
}
class cgemmUkrk1Print {
public:
std::string operator()(
testing::TestParamInfo<std::tuple<scomplex, scomplex, char, gtint_t, gtint_t, cgemm_k1_kernel, bool>> str) const {
gtint_t k = 1;
scomplex alpha = std::get<0>(str.param);
scomplex beta = std::get<1>(str.param);
char storage = std::get<2>(str.param);
gtint_t m = std::get<3>(str.param);
gtint_t n = std::get<4>(str.param);
bool memory_test = std::get<6>(str.param);
std::string str_name;
str_name += "_k_" + std::to_string(k);
str_name += "_alpha_" + testinghelpers::get_value_string(alpha);
str_name += "_beta_" + testinghelpers::get_value_string(beta);
str_name += "_m_" + std::to_string(m);
str_name += "_n_" + std::to_string(n);
str_name = str_name + "_" + storage;
str_name += ( memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled";
return str_name;
}
};
#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512)
#ifdef K_bli_cgemm_32x4_avx512_k1_nn
// Instantiates cgemmUkrk1 for the bli_cgemm_32x4_avx512_k1_nn kernel with
// column-major storage. ::testing::Range excludes its end value, so m spans
// 1..64 and n spans 1..8; each combination runs with and without memory
// testing.
INSTANTIATE_TEST_SUITE_P(
bli_cgemm_32x4_avx512_k1_nn,
cgemmUkrk1,
::testing::Combine(
::testing::Values(scomplex{1.0, 0.0}, scomplex{-1.0, 0.0},
scomplex{0.0, 0.0}, scomplex{2.1, -1.9}), // alpha value
::testing::Values(scomplex{1.0, 0.0}, scomplex{-1.0, 0.0},
scomplex{0.0, 0.0}, scomplex{2.1, -1.9}), // beta value
::testing::Values('c'), // storage
::testing::Range(gtint_t(1), gtint_t(65), 1), // values of m
::testing::Range(gtint_t(1), gtint_t(9), 1), // values of n
::testing::Values(bli_cgemm_32x4_avx512_k1_nn),
::testing::Values(true, false) // memory test
),
::cgemmUkrk1Print()
);
#endif
#endif

View File

@@ -524,8 +524,10 @@ static void test_gemmk1_ukr( FT ukr_fp, gtint_t m, gtint_t n, gtint_t k, char st
beta == testinghelpers::ONE<T>()))
thresh = 0.0;
else
thresh = (7*k+3)*testinghelpers::getEpsilon<T>();
{
double adj = 1.6;
thresh = adj*(7*k+3)*testinghelpers::getEpsilon<T>();
}
// call reference implementation
testinghelpers::ref_gemm<T>( storage, 'n', 'n', m, n, k, alpha,
buf_a, lda, buf_b, ldb, beta, buf_cref, ldc);

File diff suppressed because it is too large Load Diff

View File

@@ -420,6 +420,19 @@ void bli_dnorm2fv_unb_var1_avx512
cntx_t* cntx
);
// AVX512 CGEMM kernel for the k = 1 case with neither A nor B transposed.
// The BLAS layer calls it with column-major operands, passing the leading
// dimensions as lda/ldb/ldc. It returns void, so it reports no status to
// the caller.
void bli_cgemm_32x4_avx512_k1_nn
(
dim_t m,
dim_t n,
dim_t k,
scomplex* alpha,
scomplex* a, const inc_t lda,
scomplex* b, const inc_t ldb,
scomplex* beta,
scomplex* c, const inc_t ldc
);
err_t bli_zgemm_16x4_avx512_k1_nn
(
dim_t m,