Threshold tuning for code-paths and optimal thread selection for ZGEMM(ZEN5)

- Updated the thresholds to enter the AVX512 SUP codepath in
  ZGEMM (on ZEN5). This caters to inputs that scale well with
  multithreaded execution (in the SUP path).

- Also updated the thresholds to decide ideal threads, based on
  'm', 'n' and 'k' values. The thread-setting logic involves
  determining the number of tiles for computation, and using them
  to further tune for the optimal number of threads.

- This logic builds on the assumption that the current thread
  factorization logic is optimal. Thus, an additional data analysis
  was performed (on the existing ZEN4 and the new ZEN5 thresholds)
  to also cover the corner cases where this assumption doesn't hold
  true.

- As part of the future work, we could reimplement the thread
  factorization for GEMM, which would additionally require a new
  round of threshold tuning for every datatype.

AMD-Internal: [CPUPL-7028]

Co-authored-by: Vignesh Balasubramanian <vignbala@amd.com>
This commit is contained in:
Balasubramanian, Vignesh
2025-08-01 16:02:12 +05:30
committed by GitHub
parent 1bb1160061
commit c96e7eb197
4 changed files with 468 additions and 237 deletions

View File

@@ -52,13 +52,22 @@ void bli_gemm_front
#ifdef AOCL_DYNAMIC
// If dynamic-threading is enabled, calculate optimum number
// of threads.
// rntm will be updated with optimum number of threads.
if( bli_obj_is_dcomplex(c))// This will enable for ZGEMM
{
bli_nthreads_optimum(a, b, c, BLIS_GEMM, rntm);
}
#endif
// of threads. rntm will be updated with optimum number of threads.
// In case of ZEN3/ZEN2/ZEN architecture, we are using the aocl_dynamic
// logic to decide the optimal number of threads.
// Ideally, the native path is intended to be used solely for compute
// bound sizes, without any need for dynamic threading.
// TODO : As part of future work, we have to retune the entry conditions
// to native(ZEN3/ZEN2/ZEN), and remove the need for dynamic threading
// here (GitHub Issue #114).
arch_t arch_id = bli_arch_query_id();
if( bli_obj_is_dcomplex( c ) && ( ( arch_id == BLIS_ARCH_ZEN3 ) ||
( arch_id == BLIS_ARCH_ZEN2 ) || ( arch_id == BLIS_ARCH_ZEN ) ) )
{
bli_nthreads_optimum(a, b, c, BLIS_GEMM, rntm);
}
#endif
obj_t a_local;
obj_t b_local;

View File

@@ -1644,7 +1644,7 @@ void bli_nthreads_optimum(
// Query the architecture ID
arch_t id = bli_arch_query_id();
if( id == BLIS_ARCH_ZEN5 || id == BLIS_ARCH_ZEN4 )
if( id == BLIS_ARCH_ZEN5 )
{
/*
The logic for ideal thread selection is as follows:
@@ -1663,321 +1663,547 @@ void bli_nthreads_optimum(
computing per unit of data loaded, so you may be able to utilize more threads efficiently.
Also, we have a candidate set of discrete thread values from which we choose the
best number of threads. That is, from the list of [1, 2, 4, 8, 16, 32, 48, 64, 96, 192].
best number of threads. That is, from the list of [1, 2, 4, 8, 16, 32, 48, 64, 128, 256].
Thus, for any value of (m,n,k), we first determine the upper bound (theoretical best),
based on MR and NR. We find the value in the thread list that is just greater than or
equal to the theoretical best.
Ex : theoretical_threads = 70 will be subjected to inspection inside the condition
"theoretical_threads <= 96".
"theoretical_threads <= 128".
Inside this condition, the optimal number of threads is decided among all the thread
values <= 96 in the list, based on patterns seen in 'k', 'm' and 'n'.
values <= 128 in the list, based on patterns seen in 'k', 'm' and 'n'.
P.S : This logic can further be experimented on, by considering a continuous list of
threads values rather than discrete. This way, the only limiting factor will be 'k'.
We could also experiment with different thread factorization methods, to achieve
optimal work distribution.
P.S : This logic is built over the assumption that the thread factorization is optimal.
As of now, we use the bli_thread_partition_2x2_fast( ... ) function for partitioning,
which is not always optimal. Thus, there would be scenarios where the optimal number
could be greater than the theoretical best. An alternative for this fast partitioning
is the bli_thread_partition_2x2_slow( ... ) function, which always gives the ideal
thread factorization.
*/
// Set the kernel dimensions
dim_t MR = 12, NR = 4;
// Calculate theoretical threads for constraint checking
dim_t theoretical_threads = ( ( m + MR - 1 ) / MR ) * ( ( n + NR - 1 ) / NR );
// Abstracting common subexpressions for constraints
dim_t k_le_192 = ( k <= 192 );
dim_t k_le_96 = ( k <= 96 );
dim_t k_le_48 = ( k <= 48 );
dim_t k_le_768 = ( k <= 768 );
dim_t k_le_384 = ( k <= 384 );
dim_t k_le_16 = ( k <= 16 );
dim_t k_le_24 = ( k <= 24 );
dim_t n_le_24 = ( n <= 24 );
dim_t n_le_16 = ( n <= 16 );
dim_t n_le_8 = ( n <= 8 );
dim_t n_le_128 = ( n <= 128 );
dim_t n_le_808 = ( n <= 808 );
dim_t m_le_24 = ( m <= 24 );
dim_t m_le_552 = ( m <= 552 );
// Cascading constraint-based rules
if ( theoretical_threads <= 2 )
{
if ( k <= 96 )
if ( k_le_192 || m_le_24 )
n_threads_ideal = 1;
else
n_threads_ideal = 2;
}
else if ( theoretical_threads <= 4 )
{
if ( k <= 24 )
n_threads_ideal = 1;
else if ( k <= 48 )
if ( k_le_48 && ( m <= 48 ) )
n_threads_ideal = 2;
else if ( k_le_192 )
n_threads_ideal = 1;
else
n_threads_ideal = 4;
}
else if ( theoretical_threads <= 8 )
{
if ( k <= 12 )
n_threads_ideal = 2;
else if ( k <= 24 )
if ( k_le_192 )
n_threads_ideal = 1;
else if ( n <= 32 )
n_threads_ideal = 4;
else
{
if ( ( ( n <= 12 ) && ( k <= 48 ) ) )
n_threads_ideal = 4;
else
n_threads_ideal = 8;
}
n_threads_ideal = 8;
}
else if ( theoretical_threads <= 16 )
{
if ( k <= 384 )
if ( ( k_le_96 && n_le_16 ) ||
( k_le_192 && n_le_16 && ( m <= 120 ) ) )
n_threads_ideal = 2;
else if ( k_le_24 && k_le_192 )
n_threads_ideal = 4;
else if ( ( k > 192 ) && ( k <= 3072 ) && ( m > 168 ) )
n_threads_ideal = 16;
else
n_threads_ideal = 8;
}
else if ( theoretical_threads <= 32 )
{
if ( ( ( k_le_96 && n_le_24 && ( m <= 192 ) && n_le_16 ) ||
( k_le_16 && n_le_24 ) ||
( k_le_192 && n_le_24 && ( m <= 120 ) ) ||
k_le_16 ) )
n_threads_ideal = 4;
else if ( ( ( k <= 3072 ) && n_le_24 ) ||
( n_le_24 && ( m <= 288 ) ) ||
k_le_192 )
n_threads_ideal = 8;
else if ( n_le_24 ||
( ( n <= 120 ) && ( k <= 1536 ) ) ||
( n <= 80 ) )
n_threads_ideal = 16;
else if ( ( k_le_96 && n_le_24 && ( m <= 192 ) ) ||
k_le_24 ||
( n <= 120 ) )
n_threads_ideal = 32;
else
n_threads_ideal = 48;
}
else if ( theoretical_threads <= 48 )
{
if ( k_le_192 )
n_threads_ideal = 8;
else if ( ( ( n_le_8 && m_le_552 && k_le_384 ) ||
( ( m <= 48 ) && m_le_552 && k_le_384 ) ||
( ( n <= 72 ) && m_le_552 ) ||
( ( n <= 144 ) && k_le_384 ) ) )
n_threads_ideal = 48;
else if ( ( ( n_le_16 && m_le_552 && k_le_384 ) ||
( m_le_552 && k_le_384 ) ||
( m <= 576 ) ) )
n_threads_ideal = 32;
else if ( ( ( n <= 80 ) && m_le_552 ) || ( n <= 88 ) )
n_threads_ideal = 64;
else
n_threads_ideal = 32;
}
else if ( theoretical_threads <= 64 )
{
if ( k_le_48 )
n_threads_ideal = 8;
else if ( k_le_96 )
n_threads_ideal = 16;
else if ( ( ( n_le_8 && ( m <= 648 ) ) ||
( ( n <= 40 ) && k_le_192 ) ||
k_le_192 ||
k_le_768 ) )
n_threads_ideal = 32;
else if ( ( m <= 696 ) ||
k_le_384 ||
( ( n <= 92 ) && k_le_768 ) )
n_threads_ideal = 48;
else if ( n_le_8 ||
( ( n <= 32 ) && ( m <= 96 ) ) )
n_threads_ideal = 64;
else if ( ( n <= 40 ) ||
( n <= 32 ) ||
n_le_128 )
n_threads_ideal = 128;
else
n_threads_ideal = 48;
}
else if ( theoretical_threads <= 128 )
{
if ( k_le_96 )
n_threads_ideal = 8;
else if ( k_le_192 )
n_threads_ideal = 32;
else if ( ( ( ( n <= 224 ) && ( n <= 264 ) && ( n <= 480 ) ) ||
( n <= 480 ) ||
( ( n <= 384 ) && ( n <= 448 ) && m_le_24 ) ||
( ( m <= 1152 ) && n_le_128 ) ) )
n_threads_ideal = 48;
else if ( ( ( ( n <= 264 ) && ( n <= 480 ) ) ||
k_le_768 ||
( ( n <= 328 ) && m_le_24 ) ||
m_le_24 ||
( ( k <= 1536 ) && n_le_128 ) ||
( ( m <= 576 ) && n_le_128 ) ||
( ( m <= 864 ) && n_le_128 ) ) )
n_threads_ideal = 64;
else if ( ( ( n <= 448 ) && m_le_24 ) ||
( ( m <= 768 ) && n_le_128 ) ||
n_le_128 )
n_threads_ideal = 128;
else
n_threads_ideal = 128;
}
else if ( theoretical_threads <= 256 )
{
if ( ( ( k_le_16 && ( m <= 624 ) ) ||
k_le_16 ||
k_le_48 ) )
n_threads_ideal = 16;
else if ( k_le_96 )
n_threads_ideal = 32;
else if ( ( m <= 2304 ) && k_le_192 )
n_threads_ideal = 48;
else if ( ( k_le_192 ||
( n_le_8 && k_le_768 ) ||
( m_le_24 && k_le_768 ) ||
( m_le_24 && n_le_808 ) ||
( n_le_8 && ( m <= 2304 ) ) ) )
n_threads_ideal = 64;
else if ( ( k_le_768 ||
( n_le_8 && ( m <= 2304 ) && ( k <= 1536 ) ) ||
n_le_8 ||
( ( k <= 1536 ) && n_le_808 ) ) )
n_threads_ideal = 128;
else if ( ( n_le_8 && ( m <= 2448 ) ) ||
n_le_808 )
n_threads_ideal = 256;
else
n_threads_ideal = 128;
}
else
{
if ( ( theoretical_threads <= 40 ) && k_le_48 )
n_threads_ideal = 4;
else if ( ( ( ( theoretical_threads <= 56 ) && k_le_48 ) ||
( ( theoretical_threads <= 56 ) && k_le_192 ) ||
( ( theoretical_threads <= 40 ) && n_le_24 ) ) )
n_threads_ideal = 8;
else if ( k_le_48 )
n_threads_ideal = 16;
else if ( k_le_96 )
n_threads_ideal = 32;
else if ( k_le_192 ||
( theoretical_threads <= 56 ) )
n_threads_ideal = 48;
else if ( ( ( ( theoretical_threads <= 192 ) && k_le_768 ) ||
( ( n_le_8 || m_le_24 ) && k_le_768 ) ||
( theoretical_threads <= 192 ) ||
( m_le_24 && n_le_808 ) ) )
n_threads_ideal = 64;
else if ( k_le_768 ||
( n_le_8 && n_le_808 ) )
n_threads_ideal = 128;
else
n_threads_ideal = 256;
}
}
else if( id == BLIS_ARCH_ZEN4 )
{
// Set the kernel dimensions
dim_t MR = 12, NR = 4;
// Calculate theoretical threads for constraint checking
dim_t theoretical_threads = ( ( m + MR - 1 ) / MR ) * ( ( n + NR - 1 ) / NR );
// Abstracting common subexpressions for constraints
dim_t k_le_192 = ( k <= 192 );
dim_t k_le_96 = ( k <= 96 );
dim_t k_le_48 = ( k <= 48 );
dim_t k_le_768 = ( k <= 768 );
dim_t k_le_384 = ( k <= 384 );
dim_t k_le_24 = ( k <= 24 );
dim_t k_le_16 = ( k <= 16 );
dim_t n_le_24 = ( n <= 24 );
dim_t n_le_16 = ( n <= 16 );
dim_t n_le_8 = ( n <= 8 );
dim_t n_le_48 = ( n <= 48 );
dim_t n_le_32 = ( n <= 32 );
dim_t m_le_24 = ( m <= 24 );
dim_t m_le_96 = ( m <= 96 );
dim_t m_le_192 = ( m <= 192 );
dim_t m_le_48 = ( m <= 48 );
// Cascading constraint-based rules
if ( theoretical_threads <= 2 )
{
if ( ( k_le_96 ) || ( n > 8 ) )
n_threads_ideal = 2;
else
n_threads_ideal = 1;
}
else if ( theoretical_threads <= 4 )
{
if ( k_le_24 || ( k_le_48 && ( n > 16 ) ) )
n_threads_ideal = 2;
else if ( k_le_48 )
n_threads_ideal = 1;
else if ( n_le_8 )
n_threads_ideal = 4;
else
n_threads_ideal = 4;
}
else if ( theoretical_threads <= 8 )
{
if ( k_le_24 && ( n_le_8 || m_le_24 ) )
n_threads_ideal = 1;
else if ( k_le_24 )
n_threads_ideal = 2;
else if ( n_le_8 || k_le_48 )
n_threads_ideal = 1;
else if ( k_le_384 )
n_threads_ideal = 4;
else if ( n_le_24 )
n_threads_ideal = 16;
else
n_threads_ideal = 8;
}
else if ( theoretical_threads <= 16 )
{
if ( k_le_384 )
{
if ( k <= 12 )
if ( k_le_16 )
{
if ( ( n <= 48 ) && ( m <= 48 ) )
if ( n_le_8 )
n_threads_ideal = 4;
else if ( ( n <= 40 ) )
n_threads_ideal = 2;
else
n_threads_ideal = 4;
}
else
{
if ( ( k <= 192 ) || ( m <= 36 ) || ( n <= 8 ) )
n_threads_ideal = 8;
else
n_threads_ideal = 16;
}
}
else
{
if ( n <= 20 )
else if ( k_le_24 && n_le_16 && ( m <= 72 ) )
n_threads_ideal = 2;
else if ( k_le_24 )
n_threads_ideal = 8;
else if ( k_le_192 || m_le_24 || ( m > 96 ) )
n_threads_ideal = 8;
else
n_threads_ideal = 16;
}
else if ( n_le_16 && ( n_le_8 || m_le_48 ) )
n_threads_ideal = 16;
else if ( n_le_16 )
n_threads_ideal = 32;
else
n_threads_ideal = 32;
}
else if ( theoretical_threads <= 32 )
{
if ( k <= 192 )
if ( k_le_192 )
{
if ( ( k <= 96 ) || ( m <= 300 ) )
if ( k_le_96 )
n_threads_ideal = 8;
else if ( ( m <= 288 ) )
{
if ( ( m <= 168 ) )
{
if ( ( n_le_48 && m_le_48 ) || ( n > 104 ) )
n_threads_ideal = 16;
else
n_threads_ideal = 8;
}
else
n_threads_ideal = 8;
}
else
n_threads_ideal = 32;
}
else
else if ( n_le_16 && ( n_le_8 || m_le_96 ) )
n_threads_ideal = 32;
else if ( n_le_16 )
n_threads_ideal = 48;
else if ( ( n <= 112 ) )
{
if ( n <= 20 )
n_threads_ideal = 8;
else if ( n <= 116 )
if ( n_le_32 && ( m_le_48 || ( n > 24 ) ) )
n_threads_ideal = 32;
else if ( n_le_32 )
n_threads_ideal = 48;
else if ( k_le_384 && ( n <= 56 ) )
n_threads_ideal = 48;
else if ( k_le_384 )
n_threads_ideal = 16;
else
n_threads_ideal = 32;
n_threads_ideal = 48;
}
else if ( ( n <= 128 ) && ( k <= 1536 ) )
n_threads_ideal = 16;
else if ( ( n <= 128 ) )
n_threads_ideal = 48;
else
n_threads_ideal = 32;
}
else if ( theoretical_threads <= 48 )
{
if ( k <= 96 )
{
if ( k_le_96 )
n_threads_ideal = 8;
}
else
else if ( n_le_8 && ( m <= 576 ) )
n_threads_ideal = 32;
else if ( n_le_8 )
n_threads_ideal = 48;
else if ( k_le_384 )
{
if ( n <= 8 )
if ( k_le_192 )
{
if ( m <= 540 )
if ( m_le_24 )
n_threads_ideal = 32;
else
n_threads_ideal = 48;
}
else
{
if ( k <= 384 )
else if ( m_le_192 )
{
if ( n <= 72 )
{
if ( m <= 192 )
n_threads_ideal = 48;
else if ( k <= 192 )
n_threads_ideal = 32;
else
n_threads_ideal = 48;
}
if ( m_le_48 || ( m > 120 ) )
n_threads_ideal = 48;
else
{
if ( k <= 192 )
n_threads_ideal = 32;
else if ( n <= 164 )
n_threads_ideal = 32;
else
n_threads_ideal = 48;
}
n_threads_ideal = 32;
}
else
{
n_threads_ideal = 48;
}
n_threads_ideal = 32;
}
else if ( ( n <= 64 ) && m_le_192 )
n_threads_ideal = 48;
else if ( ( n <= 64 ) )
n_threads_ideal = 96;
else if ( ( n <= 160 ) )
n_threads_ideal = 32;
else
n_threads_ideal = 96;
}
else if ( n_le_24 && m_le_96 )
n_threads_ideal = 48;
else if ( n_le_24 && m_le_192 && n_le_16 )
n_threads_ideal = 48;
else
n_threads_ideal = 96;
}
else if ( theoretical_threads <= 96 )
{
if ( k <= 48 )
{
if ( k_le_48 )
n_threads_ideal = 8;
}
else if ( k <= 768 )
else if ( k_le_768 )
{
if ( k <= 96 )
if ( k_le_192 )
{
n_threads_ideal = 32;
}
else
{
if ( n <= 164 )
{
if ( m <= 576 )
n_threads_ideal = 48;
else
n_threads_ideal = 96;
}
else
{
if ( k <= 192 )
{
if ( n <= 268 )
n_threads_ideal = 32;
else
n_threads_ideal = 48;
}
else
{
n_threads_ideal = 48;
}
}
}
}
else
{
n_threads_ideal = 96;
}
}
else if ( theoretical_threads <= 192 )
{
if ( k <= 96 )
{
if ( k <= 48 )
n_threads_ideal = 8;
else
n_threads_ideal = 32;
}
else if ( k <= 384 )
{
if ( k <= 192 )
{
if ( n <= 384 )
{
if ( n <= 14 )
n_threads_ideal = 48;
else
n_threads_ideal = 96;
}
else
{
n_threads_ideal = 48;
}
}
else
{
n_threads_ideal = 96;
}
}
else
{
if ( m <= 1146 )
{
if ( n <= 270 )
{
if ( k <= 768 )
{
n_threads_ideal = 192;
}
else
{
if ( m <= 54 )
n_threads_ideal = 96;
else
n_threads_ideal = 192;
}
}
else
{
if ( n <= 642 )
n_threads_ideal = 96;
else
n_threads_ideal = 192;
}
}
else
{
n_threads_ideal = 192;
}
}
}
// In case the theoretical threads is greater than 192, we subject the inputs
// to a set of heuristics derived based on patterns in the inputs.
else
{
if ( k <= 192 )
{
if ( k <= 24 )
{
if ( n <= 2376 )
{
n_threads_ideal = 8;
}
else
{
if ( k <= 16 )
n_threads_ideal = 8;
else
n_threads_ideal = 16;
}
}
else if ( k <= 48 )
{
if ( n <= 64 )
{
if ( m <= 2736 )
{
n_threads_ideal = 32;
}
else
{
n_threads_ideal = 96;
}
}
else
{
n_threads_ideal = 96;
}
}
else if ( k <= 96 )
{
if ( n <= 72 )
if ( k_le_96 )
n_threads_ideal = 32;
else if ( n_le_8 || m_le_24 )
n_threads_ideal = 32;
else
n_threads_ideal = 96;
}
else
{
else if ( ( n <= 152 ) && ( m <= 576 ) )
n_threads_ideal = 96;
}
else
n_threads_ideal = 48;
}
else if ( k <= 384 )
else if ( ( m <= 864 ) )
{
if ( m <= 5892 )
if ( ( n <= 104 ) )
{
if ( n <= 16 )
if ( ( m <= 384 ) )
{
if ( n_le_24 && ( m_le_192 || n_le_16 ) )
n_threads_ideal = 96;
else if ( n_le_24 )
n_threads_ideal = 192;
else if ( m_le_96 && n_le_48 )
n_threads_ideal = 96;
else
n_threads_ideal = 192;
}
else
n_threads_ideal = 192;
}
else if ( ( k <= 1536 ) && ( n <= 256 ) && ( n <= 192 ) )
n_threads_ideal = 192;
else if ( ( k <= 1536 ) && ( n <= 256 ) )
n_threads_ideal = 48;
else
n_threads_ideal = 192;
}
else
n_threads_ideal = 48;
}
else if ( theoretical_threads <= 192 )
{
if ( k_le_48 && ( k_le_24 || ( n <= 672 ) ) )
n_threads_ideal = 8;
else if ( k_le_48 )
n_threads_ideal = 16;
else if ( k_le_192 )
{
if ( k_le_96 )
n_threads_ideal = 32;
else if ( ( n <= 380 ) && n_le_16 )
n_threads_ideal = 48;
else if ( ( n <= 380 ) )
n_threads_ideal = 96;
else
n_threads_ideal = 48;
}
else if ( k_le_768 )
{
if ( k_le_384 )
n_threads_ideal = 96;
else if ( ( m <= 1152 ) && ( n <= 344 ) )
n_threads_ideal = 192;
else
n_threads_ideal = 96;
}
else if ( ( m <= 1152 ) && ( n <= 344 ) && ( k > 3072 ) )
n_threads_ideal = 192;
else if ( ( m <= 1152 ) )
n_threads_ideal = 96;
else if ( ( k <= 3072 ) || ( m <= 1512 ) )
n_threads_ideal = 96;
else
n_threads_ideal = 192;
}
else
{
if ( k_le_192 )
{
if ( k_le_24 )
n_threads_ideal = 8;
else if ( k_le_48 )
{
if ( ( n <= 56 ) )
{
if ( ( m <= 2544 ) && n_le_48 )
n_threads_ideal = 8;
else if ( ( m <= 2544 ) )
n_threads_ideal = 32;
else
n_threads_ideal = 16;
}
else
n_threads_ideal = 16;
}
else if ( k_le_96 )
{
if ( n_le_48 && ( m <= 768 ) )
n_threads_ideal = 96;
else if ( n_le_48 )
n_threads_ideal = 32;
else
n_threads_ideal = 32;
}
else if ( ( m <= 2040 ) )
{
if ( ( n <= 1104 ) && m_le_48 )
n_threads_ideal = 48;
else if ( ( n <= 1104 ) && ( m <= 936 ) )
n_threads_ideal = 96;
else if ( ( n <= 1104 ) )
n_threads_ideal = 48;
else
n_threads_ideal = 96;
}
else
n_threads_ideal = 96;
}
else if ( k_le_384 )
{
if ( ( m <= 4896 ) )
{
if ( n_le_8 )
n_threads_ideal = 96;
else if ( ( n <= 1632 ) && m_le_24 )
n_threads_ideal = 96;
else
n_threads_ideal = 192;
}
else
{
n_threads_ideal = 192;
}
}
else
{
n_threads_ideal = 192;
}
}
}
else // Not BLIS_ARCH_ZEN5 or BLIS_ARCH_ZEN4

View File

@@ -87,10 +87,8 @@ bool bli_cntx_gemmsup_thresh_is_met_zen4( obj_t* a, obj_t* b, obj_t* c, cntx_t*
// The threshold for m is a single value, but for n, it is
// also based on the packing size of A, since the kernels are
// column preferential
if( ( ( m <= 1380 ) || ( n <= 1520 ) || ( k <= 128 ) ) ) return TRUE;
if( ( m <= 1380 ) || ( n <= 1520 ) || ( k <= 128 ) ) return TRUE;
// For all combinations in small sizes
if( ( m <= 512 ) && ( n <= 512 ) && ( k <= 512 ) ) return TRUE;
return FALSE;
}
else if( dt == BLIS_SCOMPLEX )

View File

@@ -87,10 +87,8 @@ bool bli_cntx_gemmsup_thresh_is_met_zen5( obj_t* a, obj_t* b, obj_t* c, cntx_t*
// The threshold for m is a single value, but for n, it is
// also based on the packing size of A, since the kernels are
// column preferential
if( ( m <= 60 ) || ( ( n <= 60 ) && ( m <= 960 ) && ( k <= 16384 ) ) || ( k <= 8 ) ) return TRUE;
if( ( m <= 1380 ) || ( n <= 1520 ) || ( k <= 128 ) ) return TRUE;
// For all combinations in small sizes
if( ( m <= 216 ) && ( n <= 216 ) && ( k <= 216 ) ) return TRUE;
return FALSE;
}
else if( dt == BLIS_SCOMPLEX )