Thread factorization improvements (ic ways) for BF16 LPGEMM API.

- Currently, when m is small compared to n, the number of threads
assigned to the m dimension (ic ways) is 1, even if the number of MR
blocks (m / MR) is greater than 1 and the total number of work blocks
(MR blks * NR blks) is less than the number of available threads. This
results in subpar performance in bandwidth-bound cases. To address
this, the thread factorization is updated to increase ic ways for
these cases.

AMD-Internal: [SWLCSG-3333]

Change-Id: Ife3eafc282a2b62eb212af615edb7afa40d09ae9
This commit is contained in:
Mithun Mohan
2025-01-13 11:43:39 +00:00
committed by MithunMohan KadavilMadanaMohanan
parent ea93d2e2c9
commit 0701a4388a
2 changed files with 38 additions and 2 deletions

View File

@@ -424,6 +424,10 @@ BLIS_INLINE void lpgemm_s32o32_get_threading
dim_t MR = lpgemm_get_block_size_MR_global_cntx( op_type );
dim_t mr_blks = ( m + MR - 1 ) / MR;
dim_t nr_blks = ( n + NR - 1 ) / NR;
dim_t mrxnr_blks = mr_blks * nr_blks;
dim_t mr_blks_adj_n_threads = ( ( *n_threads ) / mr_blks ) * mr_blks;
dim_t delta_mr_blks_adj = ( *n_threads ) - mr_blks_adj_n_threads;
const dim_t low_freq_thres = 6;
if ( n <= NR )
{
@@ -437,6 +441,14 @@ BLIS_INLINE void lpgemm_s32o32_get_threading
( *ic_ways ) = 1;
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
}
else if ( ( ( n % NR ) == 0 ) &&
( mrxnr_blks <= ( *n_threads ) ) &&
( delta_mr_blks_adj < low_freq_thres ) )
{
( *ic_ways ) = mr_blks;
( *jc_ways ) = ( *n_threads ) / ( *ic_ways );
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
}
else
{
// If BLIS_NUM_THREADS are set, generate jc,ic from the same.
@@ -644,6 +656,10 @@ BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading
dim_t MR = lpgemm_get_block_size_MR_global_cntx( BF16BF16F32OF32 );
dim_t mr_blks = ( m + MR - 1 ) / MR;
dim_t nr_blks = ( n + NR - 1 ) / NR;
dim_t mrxnr_blks = mr_blks * nr_blks;
dim_t mr_blks_adj_n_threads = ( ( *n_threads ) / mr_blks ) * mr_blks;
dim_t delta_mr_blks_adj = ( *n_threads ) - mr_blks_adj_n_threads;
const dim_t low_freq_thres = 6;
if ( n <= NR )
{
@@ -657,6 +673,14 @@ BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading
( *ic_ways ) = 1;
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
}
else if ( ( ( n % NR ) == 0 ) &&
( mrxnr_blks <= ( *n_threads ) ) &&
( delta_mr_blks_adj < low_freq_thres ) )
{
( *ic_ways ) = mr_blks;
( *jc_ways ) = ( *n_threads ) / ( *ic_ways );
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
}
else
{
// If BLIS_NUM_THREADS are set, generate jc,ic from the same.
@@ -789,6 +813,10 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading
{
dim_t mr_blks = ( m + MR - 1 ) / MR;
dim_t nr_blks = ( n + NR - 1 ) / NR;
dim_t mrxnr_blks = mr_blks * nr_blks;
dim_t mr_blks_adj_n_threads = ( ( *n_threads ) / mr_blks ) * mr_blks;
dim_t delta_mr_blks_adj = ( *n_threads ) - mr_blks_adj_n_threads;
const dim_t low_freq_thres = 6;
if ( n <= NR )
{
@@ -802,6 +830,14 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading
( *ic_ways ) = 1;
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
}
else if ( ( ( n % NR ) == 0 ) &&
( mrxnr_blks <= ( *n_threads ) ) &&
( delta_mr_blks_adj < low_freq_thres ) )
{
( *ic_ways ) = mr_blks;
( *jc_ways ) = ( *n_threads ) / ( *ic_ways );
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
}
else
{
// If BLIS_NUM_THREADS are set, generate jc,ic from the same.
@@ -957,7 +993,7 @@ BLIS_INLINE AOCL_TID_DISTR_TYPE lpgemm_get_tid_distr_type
dim_t nr_blks = ( n + NR - 1 ) / NR;
dim_t mr_x_nr_blks = mr_blks * nr_blks;
dim_t low_util_n_thread_thres = ( 2 * n_threads ) / 3;
dim_t low_util_n_thread_thres = n_threads / 2;
dim_t mr_x_nr_blks_fringe = mr_x_nr_blks % n_threads;
lpgemm_thread_attrs_t* thr_attrs = lpgemm_get_thread_attrs();

View File

@@ -1563,7 +1563,7 @@ int main( int argc, char** argv )
dim_t stride_a, stride_b, stride_c;
const dim_t len_list_omp_cores_for_testing = 2;
const dim_t list_omp_cores_for_testing[2] = { 1, 64 };
const dim_t list_omp_cores_for_testing[2] = { 128, 1 };
dim_t core_index = 0;
bool can_run = TRUE;