mirror of
https://github.com/amd/blis.git
synced 2026-04-19 23:28:52 +00:00
AVX512 optimizations for CGEMM (rank-1 kernel)
- Implemented an AVX512 rank-1 kernel that is expected to handle column-major storage schemes of A, B and C (without transposition) when k = 1. - This kernel is single-threaded, and is called directly from the BLAS layer for compatible inputs. - Defined custom BLAS and BLIS_IMPL layers for CGEMM (instead of using the macro definition), in order to integrate the call to this kernel at runtime (based on the corresponding architecture and input constraints). - Added unit tests for functional and memory testing of the kernel. - Updated the ZEN5 context to include the AVX512 CGEMM SUP kernels, with their cache-blocking parameters. AMD-Internal: [CPUPL-6498] Change-Id: I42a66c424325bd117ceb38970726a05e2896a46b
This commit is contained in:
@@ -319,12 +319,12 @@ void bli_cntx_init_zen5( cntx_t* cntx )
|
||||
BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64n_avx512, TRUE,
|
||||
BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64n_avx512, TRUE,
|
||||
|
||||
BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE,
|
||||
BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE,
|
||||
BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE,
|
||||
BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE,
|
||||
BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE,
|
||||
BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE,
|
||||
BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_cv_zen4_asm_24x4m, FALSE,
|
||||
BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_cv_zen4_asm_24x4m, FALSE,
|
||||
BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_cv_zen4_asm_24x4m, FALSE,
|
||||
BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_cv_zen4_asm_24x4m, FALSE,
|
||||
BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_cv_zen4_asm_24x4m, FALSE,
|
||||
BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_cv_zen4_asm_24x4m, FALSE,
|
||||
|
||||
BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE,
|
||||
BLIS_RRC, BLIS_DCOMPLEX, bli_zgemmsup_cd_zen4_asm_12x4m, FALSE,
|
||||
@@ -340,11 +340,11 @@ void bli_cntx_init_zen5( cntx_t* cntx )
|
||||
// Initialize level-3 sup blocksize objects with architecture-specific
|
||||
// values.
|
||||
// s d c z
|
||||
bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 24, 3, 12 );
|
||||
bli_blksz_init_easy( &blkszs[ BLIS_NR ], 64, 8, 8, 4 );
|
||||
bli_blksz_init_easy( &blkszs[ BLIS_MC ], 192, 144, 72, 48 );
|
||||
bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 384, 128, 64 );
|
||||
bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8064, 4032, 2040, 1020 );
|
||||
bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 24, 24, 12 );
|
||||
bli_blksz_init_easy( &blkszs[ BLIS_NR ], 64, 8, 4, 4 );
|
||||
bli_blksz_init_easy( &blkszs[ BLIS_MC ], 192, 144, 120, 48 );
|
||||
bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 384, 512, 64 );
|
||||
bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8064, 4032, 4080, 1020 );
|
||||
|
||||
// Update the context with the current architecture's register and cache
|
||||
// blocksizes for small/unpacked level-3 problems.
|
||||
|
||||
@@ -1435,7 +1435,354 @@ void zgemm_
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
INSERT_GENTFUNC_BLAS_SC( gemm, gemm )
|
||||
|
||||
void cgemm_blis_impl
|
||||
(
|
||||
const f77_char* transa,
|
||||
const f77_char* transb,
|
||||
const f77_int* m,
|
||||
const f77_int* n,
|
||||
const f77_int* k,
|
||||
const scomplex* alpha,
|
||||
const scomplex* a, const f77_int* lda,
|
||||
const scomplex* b, const f77_int* ldb,
|
||||
const scomplex* beta,
|
||||
scomplex* c, const f77_int* ldc
|
||||
)
|
||||
{
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
dim_t m0, n0, k0;
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1)
|
||||
AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(c), *transa, *transb, *m, *n, *k,
|
||||
(void*)alpha, *lda, *ldb, (void*)beta, *ldc);
|
||||
|
||||
/* Perform BLAS parameter checking. */
|
||||
PASTEBLACHK(gemm)
|
||||
(
|
||||
MKSTR(c),
|
||||
MKSTR(gemm),
|
||||
transa,
|
||||
transb,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
lda,
|
||||
ldb,
|
||||
ldc
|
||||
);
|
||||
|
||||
/* Quick return if possible. */
|
||||
if ( *m == 0 || *n == 0 || (( PASTEMAC(c,eq0)( *alpha ) || *k == 0)
|
||||
&& PASTEMAC(c,eq1)( *beta ) ))
|
||||
{
|
||||
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(c), *m, *n, *k);
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
|
||||
/* Finalize BLIS. */
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
}
|
||||
|
||||
/* If alpha is zero scale C by beta and return early. */
|
||||
if( PASTEMAC(c,eq0)( *alpha ))
|
||||
{
|
||||
bli_convert_blas_dim1(*m, m0);
|
||||
bli_convert_blas_dim1(*n, n0);
|
||||
const inc_t rs_c = 1;
|
||||
const inc_t cs_c = *ldc;
|
||||
|
||||
PASTEMAC2(c,scalm,_ex)( BLIS_NO_CONJUGATE,
|
||||
0,
|
||||
BLIS_NONUNIT_DIAG,
|
||||
BLIS_DENSE,
|
||||
m0,
|
||||
n0,
|
||||
(scomplex*) beta,
|
||||
(scomplex*) c, rs_c, cs_c,
|
||||
NULL, NULL
|
||||
);
|
||||
|
||||
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(c), *m, *n, *k);
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
|
||||
/* Finalize BLIS. */
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( *transa, &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( *transb, &blis_transb );
|
||||
|
||||
/* Typecast BLAS integers to BLIS integers. */
|
||||
bli_convert_blas_dim1( *m, m0 );
|
||||
bli_convert_blas_dim1( *n, n0 );
|
||||
bli_convert_blas_dim1( *k, k0 );
|
||||
|
||||
/* Set the row and column strides of the matrix operands. */
|
||||
const inc_t rs_a = 1;
|
||||
const inc_t cs_a = *lda;
|
||||
const inc_t rs_b = 1;
|
||||
const inc_t cs_b = *ldb;
|
||||
const inc_t rs_c = 1;
|
||||
const inc_t cs_c = *ldc;
|
||||
|
||||
/* Call GEMV when m == 1 or n == 1 with the context set
|
||||
to a null pointer, i.e. ((void *)0). */
|
||||
if (n0 == 1)
|
||||
{
|
||||
if (bli_is_notrans(blis_transa))
|
||||
{
|
||||
bli_cgemv_unf_var2
|
||||
(
|
||||
blis_transa,
|
||||
bli_extract_conj(blis_transb),
|
||||
m0, k0,
|
||||
(scomplex *)alpha,
|
||||
(scomplex *)a, rs_a, cs_a,
|
||||
(scomplex *)b, bli_is_notrans(blis_transb) ? rs_b : cs_b,
|
||||
(scomplex *)beta,
|
||||
c, rs_c,
|
||||
((void *)0)
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
bli_cgemv_unf_var1
|
||||
(
|
||||
blis_transa,
|
||||
bli_extract_conj(blis_transb),
|
||||
k0, m0,
|
||||
(scomplex *)alpha,
|
||||
(scomplex *)a, rs_a, cs_a,
|
||||
(scomplex *)b, bli_is_notrans(blis_transb) ? rs_b : cs_b,
|
||||
(scomplex *)beta,
|
||||
c, rs_c,
|
||||
((void *)0)
|
||||
);
|
||||
}
|
||||
|
||||
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(c), *m, *n, *k);
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
}
|
||||
else if (m0 == 1)
|
||||
{
|
||||
if (bli_is_notrans(blis_transb))
|
||||
{
|
||||
bli_cgemv_unf_var1
|
||||
(
|
||||
blis_transb,
|
||||
bli_extract_conj(blis_transa),
|
||||
n0, k0,
|
||||
(scomplex *)alpha,
|
||||
(scomplex *)b, cs_b, rs_b,
|
||||
(scomplex *)a, bli_is_notrans(blis_transa) ? cs_a : rs_a,
|
||||
(scomplex *)beta,
|
||||
c, cs_c,
|
||||
((void *)0)
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
bli_cgemv_unf_var2
|
||||
(
|
||||
blis_transb,
|
||||
bli_extract_conj(blis_transa),
|
||||
k0, n0,
|
||||
(scomplex *)alpha,
|
||||
(scomplex *)b, cs_b, rs_b,
|
||||
(scomplex *)a, bli_is_notrans(blis_transa) ? cs_a : rs_a,
|
||||
(scomplex *)beta,
|
||||
c, cs_c,
|
||||
((void *)0)
|
||||
);
|
||||
}
|
||||
|
||||
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(c), *m, *n, *k);
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
}
|
||||
|
||||
// This function is invoked on all architectures including 'generic'.
|
||||
// Non-AVX2+FMA3 platforms will use the kernels derived from the context.
|
||||
if (bli_cpuid_is_avx2fma3_supported() == FALSE)
|
||||
{
|
||||
// This code is duplicated below, however we don't want to move it out of
|
||||
// this IF block as we want to avoid object initialization until required.
|
||||
// Also, this is a temporary fix which will be replaced later.
|
||||
const num_t dt = BLIS_SCOMPLEX;
|
||||
|
||||
obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1;
|
||||
obj_t ao = BLIS_OBJECT_INITIALIZER;
|
||||
obj_t bo = BLIS_OBJECT_INITIALIZER;
|
||||
obj_t betao = BLIS_OBJECT_INITIALIZER_1X1;
|
||||
obj_t co = BLIS_OBJECT_INITIALIZER;
|
||||
|
||||
dim_t m0_a, n0_a;
|
||||
dim_t m0_b, n0_b;
|
||||
|
||||
bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a );
|
||||
bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b );
|
||||
|
||||
bli_obj_init_finish_1x1( dt, (scomplex*)alpha, &alphao );
|
||||
bli_obj_init_finish_1x1( dt, (scomplex*)beta, &betao );
|
||||
|
||||
bli_obj_init_finish( dt, m0_a, n0_a, (scomplex*)a, rs_a, cs_a, &ao );
|
||||
bli_obj_init_finish( dt, m0_b, n0_b, (scomplex*)b, rs_b, cs_b, &bo );
|
||||
bli_obj_init_finish( dt, m0, n0, (scomplex*)c, rs_c, cs_c, &co );
|
||||
|
||||
bli_obj_set_conjtrans( blis_transa, &ao );
|
||||
bli_obj_set_conjtrans( blis_transb, &bo );
|
||||
|
||||
// Will call parallelized cgemm code - sup & native
|
||||
PASTEMAC(gemm, BLIS_OAPI_EX_SUF)
|
||||
(
|
||||
&alphao,
|
||||
&ao,
|
||||
&bo,
|
||||
&betao,
|
||||
&co,
|
||||
NULL,
|
||||
NULL
|
||||
);
|
||||
|
||||
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(c), *m, *n, *k);
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
|
||||
/* Finalize BLIS. */
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
}
|
||||
|
||||
/*
|
||||
Invoking the API for input sizes with k = 1.
|
||||
- The API is single-threaded.
|
||||
- The input constraints are that k should be 1, and transa and transb
|
||||
should be N and N respectively.
|
||||
*/
|
||||
#if defined(BLIS_KERNELS_ZEN4)
|
||||
if( ( k0 == 1 ) && bli_is_notrans( blis_transa ) && bli_is_notrans( blis_transb ) )
|
||||
{
|
||||
arch_t arch_id = bli_arch_query_id();
|
||||
|
||||
if ( ( arch_id == BLIS_ARCH_ZEN4 ) || ( arch_id == BLIS_ARCH_ZEN5 ) )
|
||||
{
|
||||
bli_cgemm_32x4_avx512_k1_nn
|
||||
(
|
||||
m0, n0, k0,
|
||||
(scomplex*)alpha,
|
||||
(scomplex*)a, *lda,
|
||||
(scomplex*)b, *ldb,
|
||||
(scomplex*)beta,
|
||||
c, *ldc
|
||||
);
|
||||
|
||||
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(c), *m, *n, *k);
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
|
||||
/* Finalize BLIS */
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
const num_t dt = BLIS_SCOMPLEX;
|
||||
|
||||
obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1;
|
||||
obj_t ao = BLIS_OBJECT_INITIALIZER;
|
||||
obj_t bo = BLIS_OBJECT_INITIALIZER;
|
||||
obj_t betao = BLIS_OBJECT_INITIALIZER_1X1;
|
||||
obj_t co = BLIS_OBJECT_INITIALIZER;
|
||||
|
||||
dim_t m0_a, n0_a;
|
||||
dim_t m0_b, n0_b;
|
||||
|
||||
bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a );
|
||||
bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b );
|
||||
|
||||
bli_obj_init_finish_1x1( dt, (scomplex*)alpha, &alphao );
|
||||
bli_obj_init_finish_1x1( dt, (scomplex*)beta, &betao );
|
||||
|
||||
bli_obj_init_finish( dt, m0_a, n0_a, (scomplex*)a, rs_a, cs_a, &ao );
|
||||
bli_obj_init_finish( dt, m0_b, n0_b, (scomplex*)b, rs_b, cs_b, &bo );
|
||||
bli_obj_init_finish( dt, m0, n0, (scomplex*)c, rs_c, cs_c, &co );
|
||||
|
||||
bli_obj_set_conjtrans( blis_transa, &ao );
|
||||
bli_obj_set_conjtrans( blis_transb, &bo );
|
||||
|
||||
err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL);
|
||||
if (status == BLIS_SUCCESS)
|
||||
{
|
||||
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(c), *m, *n, *k);
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
|
||||
/* Finalize BLIS. */
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
}
|
||||
|
||||
// fall back on native path when cgemm is not handled in sup path.
|
||||
//bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL);
|
||||
|
||||
/* Default to using native execution. */
|
||||
ind_t im = BLIS_NAT;
|
||||
|
||||
/* As each matrix operand has a complex storage datatype, try to get an
|
||||
induced method (if one is available and enabled). NOTE: Allowing
|
||||
precisions to vary while using 1m, which is what we do here, is unique
|
||||
to gemm; other level-3 operations use 1m only if all storage datatypes
|
||||
are equal (and they ignore the computation precision). */
|
||||
|
||||
/* Find the highest priority induced method that is both enabled and
|
||||
available for the current operation. (If an induced method is
|
||||
available but not enabled, or simply unavailable, BLIS_NAT will
|
||||
be returned here.) */
|
||||
im = bli_gemmind_find_avail( dt );
|
||||
|
||||
/* Obtain a valid context from the gks using the induced
|
||||
method id determined above. */
|
||||
cntx_t* cntx = bli_gks_query_ind_cntx( im, dt );
|
||||
|
||||
rntm_t rntm_l;
|
||||
bli_rntm_init_from_global( &rntm_l );
|
||||
|
||||
/* Invoke the operation's front-end and request the default control tree. */
|
||||
PASTEMAC(gemm,_front)( &alphao, &ao, &bo, &betao, &co, cntx, &rntm_l, NULL );
|
||||
|
||||
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(c), *m, *n, *k);
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
|
||||
/* Finalize BLIS. */
|
||||
bli_finalize_auto();
|
||||
}// end of cgemm_blis_impl
|
||||
#ifdef BLIS_ENABLE_BLAS
|
||||
void cgemm_
|
||||
(
|
||||
const f77_char* transa,
|
||||
const f77_char* transb,
|
||||
const f77_int* m,
|
||||
const f77_int* n,
|
||||
const f77_int* k,
|
||||
const scomplex* alpha,
|
||||
const scomplex* a, const f77_int* lda,
|
||||
const scomplex* b, const f77_int* ldb,
|
||||
const scomplex* beta,
|
||||
scomplex* c, const f77_int* ldc
|
||||
)
|
||||
{
|
||||
cgemm_blis_impl(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
|
||||
#if defined(BLIS_KERNELS_ZEN4)
|
||||
arch_t id = bli_arch_query_id();
|
||||
if (id == BLIS_ARCH_ZEN5 || id == BLIS_ARCH_ZEN4)
|
||||
{
|
||||
bli_zero_zmm();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
INSERT_GENTFUNC_BLAS_S( gemm, gemm )
|
||||
|
||||
void dzgemm_blis_impl
|
||||
(
|
||||
|
||||
@@ -71,6 +71,10 @@ GENTFUNC( dcomplex, z, blasname, blisname )
|
||||
\
|
||||
GENTFUNC( scomplex, c, blasname, blisname )
|
||||
|
||||
#define INSERT_GENTFUNC_BLAS_S( blasname, blisname ) \
|
||||
\
|
||||
GENTFUNC( float, s, blasname, blisname )
|
||||
|
||||
// -- Basic one-operand macro with real domain only --
|
||||
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@
|
||||
|
||||
#ifdef AOCL_DEV
|
||||
|
||||
#define K_bli_cgemm_32x4_avx512_k1_nn 1
|
||||
#define K_bli_cgemmsup_cv_zen4_asm_24x4m 1
|
||||
#define K_bli_cgemmsup_cv_zen4_asm_24x3m 1
|
||||
#define K_bli_cgemmsup_cv_zen4_asm_24x2m 1
|
||||
|
||||
@@ -99,8 +99,10 @@ TEST_P( cgemmGeneric, API )
|
||||
thresh = adj*testinghelpers::getEpsilon<T>();
|
||||
}
|
||||
else
|
||||
thresh = (3*k+1)*testinghelpers::getEpsilon<T>();
|
||||
|
||||
{
|
||||
double adj = 4.0;
|
||||
thresh = adj*(3*k+1)*testinghelpers::getEpsilon<T>();
|
||||
}
|
||||
//----------------------------------------------------------
|
||||
// Call test body using these parameters
|
||||
//----------------------------------------------------------
|
||||
@@ -284,3 +286,30 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
),
|
||||
::gemmGenericPrint<scomplex>()
|
||||
);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
K_1,
|
||||
cgemmGeneric,
|
||||
::testing::Combine(
|
||||
::testing::Values('c'
|
||||
#ifndef TEST_BLAS_LIKE
|
||||
,'r'
|
||||
#endif
|
||||
), // storage format
|
||||
::testing::Values('n'), // transa
|
||||
::testing::Values('n'), // transb
|
||||
::testing::Range(gtint_t(2), gtint_t(63), 1), // m
|
||||
::testing::Range(gtint_t(2), gtint_t(9), 1), // n
|
||||
::testing::Values(gtint_t(1)), // k
|
||||
::testing::Values(scomplex{1.0, 0.0}, scomplex{-1.0, 0.0},
|
||||
scomplex{0.0, 1.0}, scomplex{2.1, -1.9},
|
||||
scomplex{0.0, 0.0}), // alpha
|
||||
::testing::Values(scomplex{1.0, 0.0}, scomplex{-1.0, 0.0},
|
||||
scomplex{0.0, 1.0}, scomplex{2.1, -1.9},
|
||||
scomplex{0.0, 0.0}), // beta
|
||||
::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of a
|
||||
::testing::Values(gtint_t(0), gtint_t(9)), // increment to the leading dim of b
|
||||
::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of c
|
||||
),
|
||||
::gemmGenericPrint<scomplex>()
|
||||
);
|
||||
@@ -1572,3 +1572,94 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Function pointer specific to cgemm kernel that handles
|
||||
// special case where k=1.
|
||||
typedef void (*cgemm_k1_kernel)
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
scomplex* alpha,
|
||||
scomplex* a, const inc_t lda,
|
||||
scomplex* b, const inc_t ldb,
|
||||
scomplex* beta,
|
||||
scomplex* c, const inc_t ldc
|
||||
);
|
||||
|
||||
// AOCL-BLAS has a set of kernels(AVX2 and AVX512) that separately handle
|
||||
// k=1 cases for CGEMM. Thus, we need to define a test-fixture class for testing
|
||||
// these kernels
|
||||
class cgemmUkrk1 :
|
||||
public ::testing::TestWithParam<std::tuple<scomplex, // alpha
|
||||
scomplex, // beta
|
||||
char, // storage
|
||||
gtint_t, // m
|
||||
gtint_t, // n
|
||||
cgemm_k1_kernel, // kernel-pointer type
|
||||
bool>> {}; // is_mem_test
|
||||
|
||||
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(cgemmUkrk1);
|
||||
|
||||
TEST_P(cgemmUkrk1, FunctionalTest)
|
||||
{
|
||||
using T = scomplex;
|
||||
gtint_t k = 1;
|
||||
T alpha = std::get<0>(GetParam()); // alpha
|
||||
T beta = std::get<1>(GetParam()); // beta
|
||||
char storage = std::get<2>(GetParam()); // indicates storage of all matrix operands
|
||||
gtint_t m = std::get<3>(GetParam()); // m
|
||||
gtint_t n = std::get<4>(GetParam()); // n
|
||||
cgemm_k1_kernel kern_ptr = std::get<5>(GetParam()); // kernel address
|
||||
bool memory_test = std::get<6>(GetParam()); // is_mem_test
|
||||
|
||||
// Call to the testing interface(specific to k=1 cases)
|
||||
test_gemmk1_ukr(kern_ptr, m, n, k, storage, alpha, beta, memory_test);
|
||||
}
|
||||
|
||||
class cgemmUkrk1Print {
|
||||
public:
|
||||
std::string operator()(
|
||||
testing::TestParamInfo<std::tuple<scomplex, scomplex, char, gtint_t, gtint_t, cgemm_k1_kernel, bool>> str) const {
|
||||
gtint_t k = 1;
|
||||
scomplex alpha = std::get<0>(str.param);
|
||||
scomplex beta = std::get<1>(str.param);
|
||||
char storage = std::get<2>(str.param);
|
||||
gtint_t m = std::get<3>(str.param);
|
||||
gtint_t n = std::get<4>(str.param);
|
||||
bool memory_test = std::get<6>(str.param);
|
||||
|
||||
std::string str_name;
|
||||
str_name += "_k_" + std::to_string(k);
|
||||
str_name += "_alpha_" + testinghelpers::get_value_string(alpha);
|
||||
str_name += "_beta_" + testinghelpers::get_value_string(beta);
|
||||
str_name += "_m_" + std::to_string(m);
|
||||
str_name += "_n_" + std::to_string(n);
|
||||
str_name = str_name + "_" + storage;
|
||||
str_name += ( memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled";
|
||||
|
||||
return str_name;
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512)
|
||||
#ifdef K_bli_cgemm_32x4_avx512_k1_nn
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
bli_cgemm_32x4_avx512_k1_nn,
|
||||
cgemmUkrk1,
|
||||
::testing::Combine(
|
||||
|
||||
::testing::Values(scomplex{1.0, 0.0}, scomplex{-1.0, 0.0},
|
||||
scomplex{0.0, 0.0}, scomplex{2.1, -1.9}), // alpha value
|
||||
::testing::Values(scomplex{1.0, 0.0}, scomplex{-1.0, 0.0},
|
||||
scomplex{0.0, 0.0}, scomplex{2.1, -1.9}), // beta value
|
||||
::testing::Values('c'), // storage
|
||||
::testing::Range(gtint_t(1), gtint_t(65), 1), // values of m
|
||||
::testing::Range(gtint_t(1), gtint_t(9), 1), // values of n
|
||||
::testing::Values(bli_cgemm_32x4_avx512_k1_nn),
|
||||
::testing::Values(true, false) // memory test
|
||||
),
|
||||
::cgemmUkrk1Print()
|
||||
);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@@ -524,8 +524,10 @@ static void test_gemmk1_ukr( FT ukr_fp, gtint_t m, gtint_t n, gtint_t k, char st
|
||||
beta == testinghelpers::ONE<T>()))
|
||||
thresh = 0.0;
|
||||
else
|
||||
thresh = (7*k+3)*testinghelpers::getEpsilon<T>();
|
||||
|
||||
{
|
||||
double adj = 1.6;
|
||||
thresh = adj*(7*k+3)*testinghelpers::getEpsilon<T>();
|
||||
}
|
||||
// call reference implementation
|
||||
testinghelpers::ref_gemm<T>( storage, 'n', 'n', m, n, k, alpha,
|
||||
buf_a, lda, buf_b, ldb, beta, buf_cref, ldc);
|
||||
|
||||
2019
kernels/zen4/3/bli_cgemm_avx512_k1.c
Normal file
2019
kernels/zen4/3/bli_cgemm_avx512_k1.c
Normal file
File diff suppressed because it is too large
Load Diff
@@ -420,6 +420,19 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
cntx_t* cntx
|
||||
);
|
||||
|
||||
void bli_cgemm_32x4_avx512_k1_nn
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
scomplex* alpha,
|
||||
scomplex* a, const inc_t lda,
|
||||
scomplex* b, const inc_t ldb,
|
||||
scomplex* beta,
|
||||
scomplex* c, const inc_t ldc
|
||||
);
|
||||
|
||||
|
||||
err_t bli_zgemm_16x4_avx512_k1_nn
|
||||
(
|
||||
dim_t m,
|
||||
|
||||
Reference in New Issue
Block a user