Additional optimizations to ZGEMM SUP and Tiny code paths (ZEN4 and ZEN5)

- Added a set of AVX512 fringe kernels (using masked loads and
  stores) in order to avoid rerouting to the GEMV typed API
  interface (when m = 1). This ensures uniform performance
  across the main and fringe cases when the calls are multithreaded.

- Further tuned the thresholds used to decide between the ZGEMM Tiny,
  Small, SUP and Native paths for the ZEN4 and ZEN5 architectures (in
  the case of parallel execution). This accounts for additional
  combinations of the input dimensions.

- Moved the call to Tiny-ZGEMM before the BLIS object creation,
  since this code-path operates on raw buffers.

- Added the necessary test-cases for functional and memory testing
  of the newly added kernels.

AMD-Internal: [CPUPL-6378][CPUPL-6661]
Change-Id: I9af73d1b6ef82b26503d4fc373111132aee3afd6
This commit is contained in:
Vignesh Balasubramanian
2025-04-07 14:26:01 +05:30
committed by Vignesh Balasubramanian
parent 87c9230cac
commit b4b0887ca4
9 changed files with 1910 additions and 81 deletions

View File

@@ -46,7 +46,12 @@
( ( bli_is_notrans( transa ) && ( m < 60 ) && ( n >= 4 ) && ( n < 200 ) && ( k < 68 ) ) || \
( bli_is_trans( transa ) && ( m < 200 ) && ( n < 200 ) && ( k < 200 ) && ( k >= 16 ) ) ) ) || \
/* In case of multi-threaded request */ \
( ( is_parallel ) && ( ( m * n * k ) < 12500 ) )
( ( is_parallel ) && \
/* Separate thresholds based on transpose value of A */ \
( ( bli_is_notrans( transa ) && \
( ( ( m <= 6 ) && ( n <= 80 ) && ( k <= 64 ) ) ) ) || \
( bli_is_trans( transa ) && \
( ( ( m <= 6 ) && ( n <= 40 ) && ( k <= 72 ) ) || ( ( m <= 12 ) && ( n <= 24 ) && ( k <= 44 ) ) ) ) ) )
#define zgemm_tiny_zen4_thresh_avx512( transa, transb, m, n, k, is_parallel ) \
/* In case of single-threaded request */ \
@@ -56,7 +61,12 @@
( ( m * k ) < 1500 ) && ( ( n * k ) < 1500 ) && ( ( m * n ) < 1500 ) ) || \
( bli_is_trans( transa ) && ( m < 200 ) && ( n < 200 ) && ( k < 200 ) && ( k >= 8 ) ) ) ) || \
/* In case of multi-threaded request */ \
( ( is_parallel ) && ( ( m * n * k ) < 15000 ) )
( ( is_parallel ) && \
/* Separate thresholds based on transpose value of A */ \
( ( bli_is_notrans( transa ) && \
( ( n <= 8 ) && ( ( ( m <= 32 ) && ( k <= 80 ) ) || ( ( m <= 80 ) && ( k <= 20 ) ) ) ) ) || \
( bli_is_notrans( transa ) && \
( ( n <= 8 ) && ( ( ( m <= 40 ) && ( k <= 40 ) ) || ( ( m <= 96 ) && ( k <= 12 ) ) || ( ( m <= 16 ) && ( k <= 96 ) ) ) ) ) ) )
/* Defining the macro to be used for selecting the kernel at runtime */
#define ZEN4_UKR_SELECTOR( ch, transa, transb, m, n, k, stor_id, ukr_support, gemmtiny_ukr_info, is_parallel ) \

View File

@@ -46,7 +46,12 @@
( ( bli_is_notrans( transa ) && ( m < 8 ) && ( n < 200 ) && ( k < 200 ) ) || \
( bli_is_trans( transa ) && ( m < 8 ) && ( n < 200 ) && ( k < 200 ) && ( k >= 8 ) ) ) ) || \
/* In case of multi-threaded request */ \
( ( is_parallel ) && ( ( m * n * k ) < 5000 ) && ( k >= 16 ) )
( ( is_parallel ) && \
/* Separate thresholds based on transpose value of A */ \
( ( bli_is_notrans( transa ) && \
( ( m <= 4 ) && ( n <= 200 ) && ( k <= 8 ) ) ) || \
( bli_is_trans( transa ) && \
( ( m <= 4 ) && ( n >= 12 ) && ( n <= 200 ) && ( k <= 4 ) ) ) ) )
#define zgemm_tiny_zen5_thresh_avx512( transa, transb, m, n, k, is_parallel ) \
/* In case of single-threaded request */ \
@@ -55,7 +60,14 @@
( ( bli_is_notrans( transa ) && ( m < 200 ) && ( n < 200 ) && ( k < 200 ) ) || \
( bli_is_trans( transa ) && ( m < 200 ) && ( n < 200 ) && ( k < 200 ) && ( k >= 8 ) ) ) ) || \
/* In case of multi-threaded request */ \
( ( is_parallel ) && ( ( m * n * k ) < 10000 ) && ( k >= 16 ) )
( ( is_parallel ) && \
/* Separate thresholds based on transpose value of A */ \
( ( bli_is_notrans( transa ) && \
( ( m <= 200 ) && ( k <= 200 ) && ( ( ( n <= 16 ) && ( ( m * k ) <= 16000 ) ) || \
( ( n <= 16 ) && ( ( m * k ) <= 13000 ) ) ) ) ) || \
( bli_is_notrans( transa ) && \
( ( m <= 200 ) && ( k <= 200 ) && ( ( ( n <= 16 ) && ( ( m * k ) <= 7000 ) ) || \
( ( n <= 16 ) && ( ( m * k ) <= 6000 ) ) ) ) ) ) )
/* Defining the macro to be used for selecting the kernel at runtime */
#define ZEN5_UKR_SELECTOR( ch, transa, transb, m, n, k, stor_id, ukr_support, gemmtiny_ukr_info, is_parallel ) \

View File

@@ -1218,30 +1218,6 @@ void zgemm_blis_impl
}
}
const num_t dt = BLIS_DCOMPLEX;
obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1;
obj_t ao = BLIS_OBJECT_INITIALIZER;
obj_t bo = BLIS_OBJECT_INITIALIZER;
obj_t betao = BLIS_OBJECT_INITIALIZER_1X1;
obj_t co = BLIS_OBJECT_INITIALIZER;
dim_t m0_a, n0_a;
dim_t m0_b, n0_b;
bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a );
bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b );
bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao );
bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao );
bli_obj_init_finish( dt, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao );
bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo );
bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co );
bli_obj_set_conjtrans( blis_transa, &ao );
bli_obj_set_conjtrans( blis_transb, &bo );
bool is_parallel = bli_thread_get_is_parallel(); // Check if parallel zgemm is invoked.
// Tiny gemm dispatch
@@ -1271,6 +1247,30 @@ void zgemm_blis_impl
}
#endif
const num_t dt = BLIS_DCOMPLEX;
obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1;
obj_t ao = BLIS_OBJECT_INITIALIZER;
obj_t bo = BLIS_OBJECT_INITIALIZER;
obj_t betao = BLIS_OBJECT_INITIALIZER_1X1;
obj_t co = BLIS_OBJECT_INITIALIZER;
dim_t m0_a, n0_a;
dim_t m0_b, n0_b;
bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a );
bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b );
bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao );
bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao );
bli_obj_init_finish( dt, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao );
bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo );
bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co );
bli_obj_set_conjtrans( blis_transa, &ao );
bli_obj_set_conjtrans( blis_transb, &bo );
#ifdef BLIS_ENABLE_SMALL_MATRIX
/* Querying the architecture ID at runtime to choose the code-path based on the micro-arch */
@@ -1301,9 +1301,8 @@ void zgemm_blis_impl
double overall_thresh = (double)m0 * (double)n0 * (double)k0;
bool mat_based_thresh = (( a_thresh < 500 ) || ( b_thresh < 500 ) || ( c_thresh < 500 ));
bool entry_to_small_st = (( !is_parallel ) && mat_based_thresh && ( overall_thresh < 7500 ));
bool entry_to_small_mt = (( is_parallel ) && mat_based_thresh && ( overall_thresh < 5000 ));
entry_to_small = entry_to_small_st || entry_to_small_mt;
entry_to_small = entry_to_small_st;
break;
}
case BLIS_ARCH_ZEN4:
@@ -1313,7 +1312,7 @@ void zgemm_blis_impl
double overall_thresh = (double)m0 * (double)n0 * (double)k0;
bool mat_based_thresh = (( a_thresh < 600 ) || ( b_thresh < 600 ) || ( c_thresh < 600 ));
bool entry_to_small_st = (( !is_parallel ) && mat_based_thresh && ( overall_thresh < 20000 ));
bool entry_to_small_mt = (( is_parallel ) && mat_based_thresh && ( overall_thresh < 12500 ));
bool entry_to_small_mt = (( is_parallel ) && bli_is_trans( blis_transa ) && ( k0 <= 24 ) && ((( m0 <= 8 ) && ( n0 <= 60 )) || (( m0 <= 40 ) && ( n0 <= 12 ))));
entry_to_small = entry_to_small_st || entry_to_small_mt;
break;

View File

@@ -40,6 +40,10 @@
#ifdef AOCL_DEV
#define K_bli_zgemmsup_cv_zen4_asm_fx1 1
#define K_bli_zgemmsup_cv_zen4_asm_fx2 1
#define K_bli_zgemmsup_cv_zen4_asm_fx3 1
#define K_bli_zgemmsup_cv_zen4_asm_fx4 1
#define K_bli_cgemm_32x4_avx512_k1_nn 1
#define K_bli_cgemmsup_cv_zen4_asm_24x4m 1
#define K_bli_cgemmsup_cv_zen4_asm_24x3m 1

View File

@@ -975,6 +975,166 @@ INSTANTIATE_TEST_SUITE_P(
);
#endif
#ifdef K_bli_zgemmsup_cv_zen4_asm_fx4
// Kernel-level tests for bli_zgemmsup_cv_zen4_asm_fx4, built only when
// K_bli_zgemmsup_cv_zen4_asm_fx4 is defined. Both suites reuse the
// zgemmGenericSUP fixture and the zgemmGenericSUPPrint name generator.
//
// Suite 1: column-stored C with transa = 'n'. Fixed m = 3 and n = 4, with
// k swept over [0, 9) in steps of 1 (so k = 0 is included), crossed with
// 6 alpha values, 6 beta values, transb in {'n', 't'} and the memory test
// both disabled and enabled.
// NOTE(review): m (3) is smaller than the MR value passed (4); presumably
// this is what exercises the fringe handling of the kernel - confirm
// against the kernel source.
INSTANTIATE_TEST_SUITE_P(
bli_zgemmsup_cv_zen4_asm_fx4_col_stored_c,
zgemmGenericSUP,
::testing::Combine(
::testing::Values(gtint_t(3)), // values of m
::testing::Values(gtint_t(4)), // values of n
::testing::Range(gtint_t(0), gtint_t(9), 1), // values of k
::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 3}, dcomplex{3.5, 4.5}), // alpha value
::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 9}, dcomplex{-7.3, 6.7}), // beta value
::testing::Values('c'), // storage of c
::testing::Values(bli_zgemmsup_cv_zen4_asm_fx4), // zgemm_sup kernel
::testing::Values(gtint_t(4)), // Micro kernel block MR
::testing::Values('n'), // transa
::testing::Values('n', 't'), // transb
::testing::Values(false, true) // is_memory_test
),
::zgemmGenericSUPPrint()
);
// Suite 2: row-stored C with transa = 't'. Same m = 3 and n = 4, but k is
// swept over the wider range [0, 19) in steps of 1; the purely imaginary
// alpha and beta entries also differ from the column-stored suite.
INSTANTIATE_TEST_SUITE_P(
bli_zgemmsup_cv_zen4_asm_fx4_row_stored_c,
zgemmGenericSUP,
::testing::Combine(
::testing::Values(gtint_t(3)), // values of m
::testing::Values(gtint_t(4)), // values of n
::testing::Range(gtint_t(0), gtint_t(19), 1), // values of k
::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -1.9}, dcomplex{3.5, 4.5}), // alpha value
::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 1.5}, dcomplex{-7.3, 6.7}), // beta value
::testing::Values('r'), // storage of c
::testing::Values(bli_zgemmsup_cv_zen4_asm_fx4), // zgemm_sup kernel
::testing::Values(gtint_t(4)), // Micro kernel block MR
::testing::Values('t'), // transa
::testing::Values('n', 't'), // transb
::testing::Values(false, true) // is_memory_test
),
::zgemmGenericSUPPrint()
);
#endif
#ifdef K_bli_zgemmsup_cv_zen4_asm_fx3
// Kernel-level tests for bli_zgemmsup_cv_zen4_asm_fx3, built only when
// K_bli_zgemmsup_cv_zen4_asm_fx3 is defined. Both suites reuse the
// zgemmGenericSUP fixture and the zgemmGenericSUPPrint name generator.
//
// Suite 1: column-stored C with transa = 'n'. Fixed m = 3 and n = 3, with
// k swept over [0, 19) in steps of 1 (so k = 0 is included), crossed with
// 6 alpha values, 6 beta values, transb in {'n', 't'} and the memory test
// both disabled and enabled.
// NOTE(review): m (3) is smaller than the MR value passed (4); presumably
// this is what exercises the fringe handling of the kernel - confirm
// against the kernel source.
INSTANTIATE_TEST_SUITE_P(
bli_zgemmsup_cv_zen4_asm_fx3_col_stored_c,
zgemmGenericSUP,
::testing::Combine(
::testing::Values(gtint_t(3)), // values of m
::testing::Values(gtint_t(3)), // values of n
::testing::Range(gtint_t(0), gtint_t(19), 1), // values of k
::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -1.9}, dcomplex{3.5, 4.5}), // alpha value
::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 1.5}, dcomplex{-7.3, 6.7}), // beta value
::testing::Values('c'), // storage of c
::testing::Values(bli_zgemmsup_cv_zen4_asm_fx3), // zgemm_sup kernel
::testing::Values(gtint_t(4)), // Micro kernel block MR
::testing::Values('n'), // transa
::testing::Values('n', 't'), // transb
::testing::Values(false, true) // is_memory_test
),
::zgemmGenericSUPPrint()
);
// Suite 2: row-stored C with transa = 't'. The m, n, k, alpha and beta
// generators are the same as in the column-stored suite above.
INSTANTIATE_TEST_SUITE_P(
bli_zgemmsup_cv_zen4_asm_fx3_row_stored_c,
zgemmGenericSUP,
::testing::Combine(
::testing::Values(gtint_t(3)), // values of m
::testing::Values(gtint_t(3)), // values of n
::testing::Range(gtint_t(0), gtint_t(19), 1), // values of k
::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -1.9}, dcomplex{3.5, 4.5}), // alpha value
::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 1.5}, dcomplex{-7.3, 6.7}), // beta value
::testing::Values('r'), // storage of c
::testing::Values(bli_zgemmsup_cv_zen4_asm_fx3), // zgemm_sup kernel
::testing::Values(gtint_t(4)), // Micro kernel block MR
::testing::Values('t'), // transa
::testing::Values('n', 't'), // transb
::testing::Values(false, true) // is_memory_test
),
::zgemmGenericSUPPrint()
);
#endif
#ifdef K_bli_zgemmsup_cv_zen4_asm_fx2
// Kernel-level tests for bli_zgemmsup_cv_zen4_asm_fx2, built only when
// K_bli_zgemmsup_cv_zen4_asm_fx2 is defined. Both suites reuse the
// zgemmGenericSUP fixture and the zgemmGenericSUPPrint name generator.
//
// Suite 1: column-stored C with transa = 'n'. Fixed m = 3 and n = 2, with
// k swept over [0, 14) in steps of 1 (so k = 0 is included), crossed with
// 6 alpha values, 6 beta values, transb in {'n', 't'} and the memory test
// both disabled and enabled.
// NOTE(review): m (3) is smaller than the MR value passed (4); presumably
// this is what exercises the fringe handling of the kernel - confirm
// against the kernel source.
INSTANTIATE_TEST_SUITE_P(
bli_zgemmsup_cv_zen4_asm_fx2_col_stored_c,
zgemmGenericSUP,
::testing::Combine(
::testing::Values(gtint_t(3)), // values of m
::testing::Values(gtint_t(2)), // values of n
::testing::Range(gtint_t(0), gtint_t(14), 1), // values of k
::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -19}, dcomplex{3.5, 4.5}), // alpha value
::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, -9}, dcomplex{-7.3, 6.7}), // beta value
::testing::Values('c'), // storage of c
::testing::Values(bli_zgemmsup_cv_zen4_asm_fx2), // zgemm_sup kernel
::testing::Values(gtint_t(4)), // Micro kernel block MR
::testing::Values('n'), // transa
::testing::Values('n', 't'), // transb
::testing::Values(false, true) // is_memory_test
),
::zgemmGenericSUPPrint()
);
// Suite 2: row-stored C with transa = 't'. Same m = 3 and n = 2, but k is
// swept over the wider range [0, 19) in steps of 1; the purely imaginary
// alpha and beta entries also differ from the column-stored suite.
INSTANTIATE_TEST_SUITE_P(
bli_zgemmsup_cv_zen4_asm_fx2_row_stored_c,
zgemmGenericSUP,
::testing::Combine(
::testing::Values(gtint_t(3)), // values of m
::testing::Values(gtint_t(2)), // values of n
::testing::Range(gtint_t(0), gtint_t(19), 1), // values of k
::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -1.9}, dcomplex{3.5, 4.5}), // alpha value
::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 1.5}, dcomplex{-7.3, 6.7}), // beta value
::testing::Values('r'), // storage of c
::testing::Values(bli_zgemmsup_cv_zen4_asm_fx2), // zgemm_sup kernel
::testing::Values(gtint_t(4)), // Micro kernel block MR
::testing::Values('t'), // transa
::testing::Values('n', 't'), // transb
::testing::Values(false, true) // is_memory_test
),
::zgemmGenericSUPPrint()
);
#endif
#ifdef K_bli_zgemmsup_cv_zen4_asm_fx1
// Kernel-level tests for bli_zgemmsup_cv_zen4_asm_fx1, built only when
// K_bli_zgemmsup_cv_zen4_asm_fx1 is defined. Both suites reuse the
// zgemmGenericSUP fixture and the zgemmGenericSUPPrint name generator.
//
// Suite 1: column-stored C with transa = 'n'. Fixed m = 3 and n = 1, with
// k swept over [0, 12) in steps of 1 (so k = 0 is included), crossed with
// 6 alpha values, 6 beta values, transb in {'n', 't'} and the memory test
// both disabled and enabled.
// NOTE(review): m (3) is smaller than the MR value passed (4); presumably
// this is what exercises the fringe handling of the kernel - confirm
// against the kernel source.
INSTANTIATE_TEST_SUITE_P(
bli_zgemmsup_cv_zen4_asm_fx1_col_stored_c,
zgemmGenericSUP,
::testing::Combine(
::testing::Values(gtint_t(3)), // values of m
::testing::Values(gtint_t(1)), // values of n
::testing::Range(gtint_t(0), gtint_t(12), 1), // values of k
::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -19}, dcomplex{3.5, 4.5}), // alpha value
::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 1}, dcomplex{-7.3, 6.7}), // beta value
::testing::Values('c'), // storage of c
::testing::Values(bli_zgemmsup_cv_zen4_asm_fx1), // zgemm_sup kernel
::testing::Values(gtint_t(4)), // Micro kernel block MR
::testing::Values('n'), // transa
::testing::Values('n', 't'), // transb
::testing::Values(false, true) // is_memory_test
),
::zgemmGenericSUPPrint()
);
// Suite 2: row-stored C with transa = 't'. Same m = 3 and n = 1, but k is
// swept over the wider range [0, 19) in steps of 1; the purely imaginary
// alpha and beta entries also differ from the column-stored suite.
INSTANTIATE_TEST_SUITE_P(
bli_zgemmsup_cv_zen4_asm_fx1_row_stored_c,
zgemmGenericSUP,
::testing::Combine(
::testing::Values(gtint_t(3)), // values of m
::testing::Values(gtint_t(1)), // values of n
::testing::Range(gtint_t(0), gtint_t(19), 1), // values of k
::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -1.9}, dcomplex{3.5, 4.5}), // alpha value
::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 1.5}, dcomplex{-7.3, 6.7}), // beta value
::testing::Values('r'), // storage of c
::testing::Values(bli_zgemmsup_cv_zen4_asm_fx1), // zgemm_sup kernel
::testing::Values(gtint_t(4)), // Micro kernel block MR
::testing::Values('t'), // transa
::testing::Values('n', 't'), // transb
::testing::Values(false, true) // is_memory_test
),
::zgemmGenericSUPPrint()
);
#endif
#ifdef K_bli_zgemmsup_cv_zen4_asm_2x4
INSTANTIATE_TEST_SUITE_P(
bli_zgemmsup_cv_zen4_asm_2x4_col_stored_c,

File diff suppressed because it is too large Load Diff

View File

@@ -86,8 +86,10 @@ bool bli_cntx_gemmsup_thresh_is_met_zen4( obj_t* a, obj_t* b, obj_t* c, cntx_t*
// For skinny sizes where m and/or n is small
// The threshold for m is a single value, but for n, it is
// also based on the packing size of A, since the kernels are
// column preferential
if( ( m <= 84 ) || ( ( n <= 84 ) && ( m < 4000 ) ) ) return TRUE;
// column preferential
if( ( ( m <= 120 ) && ( n <= 7515 ) && ( k <= 128 ) ) ||
// ( ( m <= 96 ) && ( n <= 7515 ) && ( k <= 128 ) ) ||
( ( m <= 1200 ) && ( n <= 1200 ) && ( k <= 64 ) ) ) return TRUE;
// For all combinations in small sizes
if( ( m <= 216 ) && ( n <= 216 ) && ( k <= 216 ) ) return TRUE;

View File

@@ -361,6 +361,11 @@ GEMMSUP_KER_PROT( dcomplex, z, gemmsup_cv_zen4_asm_8x3 )
GEMMSUP_KER_PROT( dcomplex, z, gemmsup_cv_zen4_asm_8x2 )
GEMMSUP_KER_PROT( dcomplex, z, gemmsup_cv_zen4_asm_8x1 )
GEMMSUP_KER_PROT( dcomplex, z, gemmsup_cv_zen4_asm_fx4 )
GEMMSUP_KER_PROT( dcomplex, z, gemmsup_cv_zen4_asm_fx3 )
GEMMSUP_KER_PROT( dcomplex, z, gemmsup_cv_zen4_asm_fx2 )
GEMMSUP_KER_PROT( dcomplex, z, gemmsup_cv_zen4_asm_fx1 )
GEMMSUP_KER_PROT( dcomplex, z, gemmsup_cv_zen4_asm_4x4 )
GEMMSUP_KER_PROT( dcomplex, z, gemmsup_cv_zen4_asm_4x3 )
GEMMSUP_KER_PROT( dcomplex, z, gemmsup_cv_zen4_asm_4x2 )

View File

@@ -87,7 +87,7 @@ bool bli_cntx_gemmsup_thresh_is_met_zen5( obj_t* a, obj_t* b, obj_t* c, cntx_t*
// The threshold for m is a single value, but for n, it is
// also based on the packing size of A, since the kernels are
// column preferential
if( ( m <= 84 ) || ( ( n <= 84 ) && ( ( m * k ) <= 983040 ) ) ) return TRUE;
if( ( m <= 60 ) || ( ( n <= 60 ) && ( m <= 960 ) && ( k <= 16384 ) ) || ( k <= 8 ) ) return TRUE;
// For all combinations in small sizes
if( ( m <= 216 ) && ( n <= 216 ) && ( k <= 216 ) ) return TRUE;