Improved thread balancing for aocl_gemm f32 API

Description:
1. Updated the thread partition logic for aocl_gemm_f32f32f32of32
   for m<MR, n<NR cases and also balanced thread in m, n directions
   such that each thread gets equal amount of work and not to span
   thread without any work.
2. Disabled dynamic enabling of packing of a and b matrixes for
   smaller sizes for genoa architecture.

AMD-Internal: [SWLCSG-2353 , SWLCSG-2391]
Change-Id: I03b2c50e592c2e9d336ea84c0e0394af63a34cec
This commit is contained in:
Bhaskar Nallani
2023-11-22 00:53:59 +05:30
committed by Nallani Bhaskar
parent 2676ac8249
commit 21d6ab6a21

View File

@@ -547,24 +547,53 @@ BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading
dim_t NR = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 );
dim_t MR = lpgemm_get_block_size_MR_global_cntx( BF16BF16F32OF32 );
dim_t mr_blks = ( m + MR - 1 ) / MR;
dim_t nr_blks = ( n + NR - 1 ) / NR;
if ( n <= NR )
{
// If n is less than micro panel dimension, allocating all threads
// to ic resulted in gains.
( *ic_ways ) = ( *n_threads );
( *ic_ways ) = ( mr_blks < ( *n_threads ) ) ? mr_blks : ( *n_threads );
( *jc_ways ) = 1;
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
}
else if ( m <= MR )
{
( *jc_ways ) = ( nr_blks < ( *n_threads ) ) ? nr_blks : ( *n_threads );
( *ic_ways ) = 1;
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
}
else
{
// If BLIS_NUM_THREADS are set, generate jc,ic from the same.
bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways );
lpgemm_pnl_wrk_heur_adjust_ic_jc_ways
(
MR, NR, m, n,
n_threads, ic_ways, jc_ways
);
if ( ( mr_blks < ( *ic_ways ) ) && ( nr_blks < ( *jc_ways ) ) )
{
( *ic_ways ) = mr_blks;
( *jc_ways ) = nr_blks;
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
}
else if ( mr_blks < ( *ic_ways ) )
{
( *ic_ways ) = mr_blks;
dim_t rem_jc_ways = ( dim_t )( ( *n_threads ) / ( *ic_ways ) );
( *jc_ways ) = ( rem_jc_ways < nr_blks ) ? rem_jc_ways : nr_blks;
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
}
else if ( nr_blks < ( *jc_ways ) )
{
( *jc_ways ) = nr_blks;
dim_t rem_ic_ways = ( dim_t )( ( *n_threads ) / ( *jc_ways ) );
( *ic_ways ) = ( rem_ic_ways < mr_blks ) ? rem_ic_ways : mr_blks;
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
}
else
{
lpgemm_pnl_wrk_heur_adjust_ic_jc_ways
(
MR, NR, m, n,
n_threads, ic_ways, jc_ways
);
}
}
}
else
@@ -624,15 +653,55 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading
}
else if ( ( *n_threads ) > 1 )
{
// If BLIS_NUM_THREADS are set, generate jc,ic from the same.
bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways );
dim_t mr_blks = ( m + MR - 1 ) / MR;
dim_t nr_blks = ( n + NR - 1 ) / NR;
lpgemm_adjust_ic_jc_ways
(
m, n, k,
MC, NC, KC, MR, NR,
n_threads, ic_ways, jc_ways, 5
);
if ( n <= NR )
{
( *ic_ways ) = ( mr_blks < ( *n_threads ) ) ? mr_blks : ( *n_threads );
( *jc_ways ) = 1;
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
}
else if ( m <= MR )
{
( *jc_ways ) = ( nr_blks < ( *n_threads ) ) ? nr_blks : ( *n_threads );
( *ic_ways ) = 1;
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
}
else
{
// If BLIS_NUM_THREADS are set, generate jc,ic from the same.
bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways );
if ( ( mr_blks < ( *ic_ways ) ) && ( nr_blks < ( *jc_ways ) ) )
{
( *ic_ways ) = mr_blks;
( *jc_ways ) = nr_blks;
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
}
else if ( mr_blks < ( *ic_ways ) )
{
( *ic_ways ) = mr_blks;
dim_t rem_jc_ways = ( dim_t )( ( *n_threads ) / ( *ic_ways ) );
( *jc_ways ) = ( rem_jc_ways < nr_blks ) ? rem_jc_ways : nr_blks;
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
}
else if ( nr_blks < ( *jc_ways ) )
{
( *jc_ways ) = nr_blks;
dim_t rem_ic_ways = ( dim_t )( ( *n_threads ) / ( *jc_ways ) );
( *ic_ways ) = ( rem_ic_ways < mr_blks ) ? rem_ic_ways : mr_blks;
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
}
else
{
lpgemm_adjust_ic_jc_ways
(
m, n, k,
MC, NC, KC, MR, NR,
n_threads, ic_ways, jc_ways, 5
);
}
}
}
else
{
@@ -652,9 +721,8 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading
if ( ( m >= MT ) && ( n >= NT ) && ( k >= KT ) )
{
if ( ( k > page_size_b_floatx2 ) ||
( ( k <= page_size_b_floatx2 ) &&
( m_ic > MT_2 ) && ( n_jc >= NT ) ) )
if (((k <= page_size_b_floatx2) && (m_ic > MT_2) && (n_jc >= NT)) ||
((bli_cpuid_is_avx512_supported() == FALSE) && (k > page_size_b_floatx2)))
{
bli_rntm_set_pack_b( 1, rntm_g );
bli_rntm_set_pack_a( 1, rntm_g );