mirror of
https://github.com/amd/blis.git
synced 2026-04-20 15:48:50 +00:00
Thread factorization improvements (ic ways) for BF16 LPGEMM API.
-Currently when m is small compared to n, even if MR blks (m / MR) > 1, and total work blocks (MR blks * NR blks) < available threads, the number of threads assigned for m dimension (ic ways) is 1. This results in sub par performance in bandwidth bound cases. To address this, the thread factorization is updated to increase ic ways for these cases. AMD-Internal: [SWLCSG-3333] Change-Id: Ife3eafc282a2b62eb212af615edb7afa40d09ae9
This commit is contained in:
committed by
MithunMohan KadavilMadanaMohanan
parent
ea93d2e2c9
commit
0701a4388a
@@ -424,6 +424,10 @@ BLIS_INLINE void lpgemm_s32o32_get_threading
|
||||
dim_t MR = lpgemm_get_block_size_MR_global_cntx( op_type );
|
||||
dim_t mr_blks = ( m + MR - 1 ) / MR;
|
||||
dim_t nr_blks = ( n + NR - 1 ) / NR;
|
||||
dim_t mrxnr_blks = mr_blks * nr_blks;
|
||||
dim_t mr_blks_adj_n_threads = ( ( *n_threads ) / mr_blks ) * mr_blks;
|
||||
dim_t delta_mr_blks_adj = ( *n_threads ) - mr_blks_adj_n_threads;
|
||||
const dim_t low_freq_thres = 6;
|
||||
|
||||
if ( n <= NR )
|
||||
{
|
||||
@@ -437,6 +441,14 @@ BLIS_INLINE void lpgemm_s32o32_get_threading
|
||||
( *ic_ways ) = 1;
|
||||
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else if ( ( ( n % NR ) == 0 ) &&
|
||||
( mrxnr_blks <= ( *n_threads ) ) &&
|
||||
( delta_mr_blks_adj < low_freq_thres ) )
|
||||
{
|
||||
( *ic_ways ) = mr_blks;
|
||||
( *jc_ways ) = ( *n_threads ) / ( *ic_ways );
|
||||
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else
|
||||
{
|
||||
// If BLIS_NUM_THREADS are set, generate jc,ic from the same.
|
||||
@@ -644,6 +656,10 @@ BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading
|
||||
dim_t MR = lpgemm_get_block_size_MR_global_cntx( BF16BF16F32OF32 );
|
||||
dim_t mr_blks = ( m + MR - 1 ) / MR;
|
||||
dim_t nr_blks = ( n + NR - 1 ) / NR;
|
||||
dim_t mrxnr_blks = mr_blks * nr_blks;
|
||||
dim_t mr_blks_adj_n_threads = ( ( *n_threads ) / mr_blks ) * mr_blks;
|
||||
dim_t delta_mr_blks_adj = ( *n_threads ) - mr_blks_adj_n_threads;
|
||||
const dim_t low_freq_thres = 6;
|
||||
|
||||
if ( n <= NR )
|
||||
{
|
||||
@@ -657,6 +673,14 @@ BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading
|
||||
( *ic_ways ) = 1;
|
||||
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else if ( ( ( n % NR ) == 0 ) &&
|
||||
( mrxnr_blks <= ( *n_threads ) ) &&
|
||||
( delta_mr_blks_adj < low_freq_thres ) )
|
||||
{
|
||||
( *ic_ways ) = mr_blks;
|
||||
( *jc_ways ) = ( *n_threads ) / ( *ic_ways );
|
||||
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else
|
||||
{
|
||||
// If BLIS_NUM_THREADS are set, generate jc,ic from the same.
|
||||
@@ -789,6 +813,10 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading
|
||||
{
|
||||
dim_t mr_blks = ( m + MR - 1 ) / MR;
|
||||
dim_t nr_blks = ( n + NR - 1 ) / NR;
|
||||
dim_t mrxnr_blks = mr_blks * nr_blks;
|
||||
dim_t mr_blks_adj_n_threads = ( ( *n_threads ) / mr_blks ) * mr_blks;
|
||||
dim_t delta_mr_blks_adj = ( *n_threads ) - mr_blks_adj_n_threads;
|
||||
const dim_t low_freq_thres = 6;
|
||||
|
||||
if ( n <= NR )
|
||||
{
|
||||
@@ -802,6 +830,14 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading
|
||||
( *ic_ways ) = 1;
|
||||
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else if ( ( ( n % NR ) == 0 ) &&
|
||||
( mrxnr_blks <= ( *n_threads ) ) &&
|
||||
( delta_mr_blks_adj < low_freq_thres ) )
|
||||
{
|
||||
( *ic_ways ) = mr_blks;
|
||||
( *jc_ways ) = ( *n_threads ) / ( *ic_ways );
|
||||
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else
|
||||
{
|
||||
// If BLIS_NUM_THREADS are set, generate jc,ic from the same.
|
||||
@@ -957,7 +993,7 @@ BLIS_INLINE AOCL_TID_DISTR_TYPE lpgemm_get_tid_distr_type
|
||||
dim_t nr_blks = ( n + NR - 1 ) / NR;
|
||||
dim_t mr_x_nr_blks = mr_blks * nr_blks;
|
||||
|
||||
dim_t low_util_n_thread_thres = ( 2 * n_threads ) / 3;
|
||||
dim_t low_util_n_thread_thres = n_threads / 2;
|
||||
dim_t mr_x_nr_blks_fringe = mr_x_nr_blks % n_threads;
|
||||
|
||||
lpgemm_thread_attrs_t* thr_attrs = lpgemm_get_thread_attrs();
|
||||
|
||||
@@ -1563,7 +1563,7 @@ int main( int argc, char** argv )
|
||||
dim_t stride_a, stride_b, stride_c;
|
||||
|
||||
const dim_t len_list_omp_cores_for_testing = 2;
|
||||
const dim_t list_omp_cores_for_testing[2] = { 1, 64 };
|
||||
const dim_t list_omp_cores_for_testing[2] = { 128, 1 };
|
||||
|
||||
dim_t core_index = 0;
|
||||
bool can_run = TRUE;
|
||||
|
||||
Reference in New Issue
Block a user