mirror of
https://github.com/amd/blis.git
synced 2026-05-04 22:41:11 +00:00
Improved thread balancing for aocl_gemm f32 API
Description: 1. Updated the thread partition logic for aocl_gemm_f32f32f32of32 for m<MR, n<NR cases and also balanced thread in m, n directions such that each thread gets equal amount of work and not to span thread without any work. 2. Disabled dynamic enabling of packing of a and b matrixes for smaller sizes for genoa architecture. AMD-Internal: [SWLCSG-2353 , SWLCSG-2391] Change-Id: I03b2c50e592c2e9d336ea84c0e0394af63a34cec
This commit is contained in:
committed by
Nallani Bhaskar
parent
2676ac8249
commit
21d6ab6a21
@@ -547,24 +547,53 @@ BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading
|
||||
|
||||
dim_t NR = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 );
|
||||
dim_t MR = lpgemm_get_block_size_MR_global_cntx( BF16BF16F32OF32 );
|
||||
dim_t mr_blks = ( m + MR - 1 ) / MR;
|
||||
dim_t nr_blks = ( n + NR - 1 ) / NR;
|
||||
|
||||
if ( n <= NR )
|
||||
{
|
||||
// If n is less than micro panel dimension, allocating all threads
|
||||
// to ic resulted in gains.
|
||||
( *ic_ways ) = ( *n_threads );
|
||||
( *ic_ways ) = ( mr_blks < ( *n_threads ) ) ? mr_blks : ( *n_threads );
|
||||
( *jc_ways ) = 1;
|
||||
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else if ( m <= MR )
|
||||
{
|
||||
( *jc_ways ) = ( nr_blks < ( *n_threads ) ) ? nr_blks : ( *n_threads );
|
||||
( *ic_ways ) = 1;
|
||||
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else
|
||||
{
|
||||
// If BLIS_NUM_THREADS are set, generate jc,ic from the same.
|
||||
bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways );
|
||||
|
||||
lpgemm_pnl_wrk_heur_adjust_ic_jc_ways
|
||||
(
|
||||
MR, NR, m, n,
|
||||
n_threads, ic_ways, jc_ways
|
||||
);
|
||||
if ( ( mr_blks < ( *ic_ways ) ) && ( nr_blks < ( *jc_ways ) ) )
|
||||
{
|
||||
( *ic_ways ) = mr_blks;
|
||||
( *jc_ways ) = nr_blks;
|
||||
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else if ( mr_blks < ( *ic_ways ) )
|
||||
{
|
||||
( *ic_ways ) = mr_blks;
|
||||
dim_t rem_jc_ways = ( dim_t )( ( *n_threads ) / ( *ic_ways ) );
|
||||
( *jc_ways ) = ( rem_jc_ways < nr_blks ) ? rem_jc_ways : nr_blks;
|
||||
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else if ( nr_blks < ( *jc_ways ) )
|
||||
{
|
||||
( *jc_ways ) = nr_blks;
|
||||
dim_t rem_ic_ways = ( dim_t )( ( *n_threads ) / ( *jc_ways ) );
|
||||
( *ic_ways ) = ( rem_ic_ways < mr_blks ) ? rem_ic_ways : mr_blks;
|
||||
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else
|
||||
{
|
||||
lpgemm_pnl_wrk_heur_adjust_ic_jc_ways
|
||||
(
|
||||
MR, NR, m, n,
|
||||
n_threads, ic_ways, jc_ways
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -624,15 +653,55 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading
|
||||
}
|
||||
else if ( ( *n_threads ) > 1 )
|
||||
{
|
||||
// If BLIS_NUM_THREADS are set, generate jc,ic from the same.
|
||||
bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways );
|
||||
dim_t mr_blks = ( m + MR - 1 ) / MR;
|
||||
dim_t nr_blks = ( n + NR - 1 ) / NR;
|
||||
|
||||
lpgemm_adjust_ic_jc_ways
|
||||
(
|
||||
m, n, k,
|
||||
MC, NC, KC, MR, NR,
|
||||
n_threads, ic_ways, jc_ways, 5
|
||||
);
|
||||
if ( n <= NR )
|
||||
{
|
||||
( *ic_ways ) = ( mr_blks < ( *n_threads ) ) ? mr_blks : ( *n_threads );
|
||||
( *jc_ways ) = 1;
|
||||
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else if ( m <= MR )
|
||||
{
|
||||
( *jc_ways ) = ( nr_blks < ( *n_threads ) ) ? nr_blks : ( *n_threads );
|
||||
( *ic_ways ) = 1;
|
||||
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else
|
||||
{
|
||||
// If BLIS_NUM_THREADS are set, generate jc,ic from the same.
|
||||
bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways );
|
||||
if ( ( mr_blks < ( *ic_ways ) ) && ( nr_blks < ( *jc_ways ) ) )
|
||||
{
|
||||
( *ic_ways ) = mr_blks;
|
||||
( *jc_ways ) = nr_blks;
|
||||
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else if ( mr_blks < ( *ic_ways ) )
|
||||
{
|
||||
( *ic_ways ) = mr_blks;
|
||||
dim_t rem_jc_ways = ( dim_t )( ( *n_threads ) / ( *ic_ways ) );
|
||||
( *jc_ways ) = ( rem_jc_ways < nr_blks ) ? rem_jc_ways : nr_blks;
|
||||
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else if ( nr_blks < ( *jc_ways ) )
|
||||
{
|
||||
( *jc_ways ) = nr_blks;
|
||||
dim_t rem_ic_ways = ( dim_t )( ( *n_threads ) / ( *jc_ways ) );
|
||||
( *ic_ways ) = ( rem_ic_ways < mr_blks ) ? rem_ic_ways : mr_blks;
|
||||
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else
|
||||
{
|
||||
lpgemm_adjust_ic_jc_ways
|
||||
(
|
||||
m, n, k,
|
||||
MC, NC, KC, MR, NR,
|
||||
n_threads, ic_ways, jc_ways, 5
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -652,9 +721,8 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading
|
||||
|
||||
if ( ( m >= MT ) && ( n >= NT ) && ( k >= KT ) )
|
||||
{
|
||||
if ( ( k > page_size_b_floatx2 ) ||
|
||||
( ( k <= page_size_b_floatx2 ) &&
|
||||
( m_ic > MT_2 ) && ( n_jc >= NT ) ) )
|
||||
if (((k <= page_size_b_floatx2) && (m_ic > MT_2) && (n_jc >= NT)) ||
|
||||
((bli_cpuid_is_avx512_supported() == FALSE) && (k > page_size_b_floatx2)))
|
||||
{
|
||||
bli_rntm_set_pack_b( 1, rntm_g );
|
||||
bli_rntm_set_pack_a( 1, rntm_g );
|
||||
|
||||
Reference in New Issue
Block a user