From 15c44a6f8cad1bf87ee2fe734bfad69a76eae2af Mon Sep 17 00:00:00 2001 From: "Balasubramanian, Vignesh" Date: Wed, 25 Jun 2025 10:05:40 +0530 Subject: [PATCH] Adding dynamic thread-setting logic for CGEMM(AOCL_DYNAMIC) (#48) - Added a set of thresholds (based on input dimensions) that determine and set the ideal number of threads to be used for CGEMM (on ZEN4 and ZEN5 architectures). - The thread-setting logic is as follows: - The underlying kernels (single-threaded) work on blocks of MRxk of A, kxNR of B and MRxNR of C. Thus, it is initially assumed that the optimal number of threads is ceil(m/MR)*ceil(n/NR). This is the upper bound on the ideal number of threads. - The actual ideal thread count could be lower than the upper bound, based on the work that every thread receives. This is mainly determined by the value of 'k'. - If 'k' is small, the arithmetic intensity (AI) is low and memory bandwidth becomes the limiting factor, thus favoring smaller thread counts. In contrast, if 'k' is high, the AI is high and the workload scales well with higher thread counts. - So, we limit the number of threads when 'k' is small to avoid bandwidth contention. Using fewer threads ensures each thread gets more bandwidth, improving efficiency. In contrast, we allow more threads when 'k' is large, as the computation becomes more compute-bound and less limited by memory bandwidth, thereby benefiting from a higher thread count. - The new logic will now set the upper bound for the optimal number of threads (based on the number of tiles), and then further reduce it based on the values of 'm', 'n' and 'k'. This comes under the 'AOCL_DYNAMIC' feature for CGEMM, specifically for ZEN4 and ZEN5 architectures. 
AMD-Internal: [CPUPL-6498] Co-authored-by: Vignesh Balasubramanian Co-authored-by: Varaganti, Kiran --- frame/base/bli_rntm.c | 686 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 686 insertions(+) diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index 25134c207..9ce8ba25e 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -1201,6 +1201,692 @@ void bli_nthreads_optimum( } } } + else if( family == BLIS_GEMM && bli_obj_is_scomplex( c ) ) + { + // Acquire the input dimensions for ideal thread selection + dim_t m = bli_obj_length(c); + dim_t n = bli_obj_width(c); + dim_t k = bli_obj_width_after_trans(a); + + // Query the architecture ID + arch_t id = bli_arch_query_id(); + if( id == BLIS_ARCH_ZEN5 || id == BLIS_ARCH_ZEN4 ) + { + /* + The logic for ideal thread selection is as follows: + Every GEMM kernel performs matrix multiplication(single-threaded) on an + MRxNR block of C, MRxk block of A, and kxNR block of B. + + Thus, the upper bound on the number of threads is ceil(m/MR) * ceil(n/NR). + This is because the framework will block the data into ceil(m/MR) panels along + the "m" direction and ceil(n/NR) panels along the "n" direction. + + In reality, the ideal number of threads could be lesser than ceil(m/MR) * ceil(n/NR), + based on the 'k' value. For small 'k', each MR×NR tile requires less computation per + data loaded (low arithmetic intensity), so memory bandwidth is a limiting factor. + Too many threads may saturate memory bandwidth, causing contention and reducing + efficiency. For large 'k', arithmetic intensity increases. Threads spend more time + computing per unit of data loaded, so you may be able to utilize more threads efficiently. + + Also, we have a candidate set of discrete thread values from which we choose the + best number of threads. That is, from the list of [1, 2, 4, 8, 16, 32, 48, 64, 96, 192]. + + Thus, for any value of (m,n,k), we first determine the upper bound(theoretical best), + based on MR and NR. 
We find the value in the thread list that is just greater than or + equal to the theoretical best. + Ex : theoretical_threads = 70 will be subjected to inspection inside the condition + "theoretical_threads <= 96". + + Inside this condition, the optimal numer of threads is decided among all the thread + values <= 96 in the list, based on patterns seen in 'k', 'm' and 'n'. + + P.S : This logic can further be experimented on, by considering a continuous list of + threads values rather than discrete. This way, the only limiting factor will be 'k'. + We could also experiment with different thread factorization methods, to achieve + optimal work distribution. + */ + // Set the kernel dimensions + dim_t MR = 24; + dim_t NR = 4; + // Calculate theoretical threads for constraint checking + dim_t theoretical_threads = ( ( m + MR - 1 ) / MR ) * ( ( n + NR - 1 ) / NR ); + + // Cascading constraint-based rules + if ( theoretical_threads <= 2 ) + { + if ( k <= 48 ) + { + n_threads_ideal = 1; + } + else + { + n_threads_ideal = 2; + } + } + else if ( theoretical_threads <= 4 ) + { + if ( k <= 24 ) + { + n_threads_ideal = 1; + } + else + { + if ( n <= 12 ) + { + if ( k <= 48 ) + { + n_threads_ideal = 1; + } + else + { + n_threads_ideal = 4; + } + } + else + { + if ( n <= 16 ) + { + if ( k <= 96 ) + { + n_threads_ideal = 2; + } + else + { + n_threads_ideal = 4; + } + } + else + { + n_threads_ideal = 4; + } + } + } + } + else if ( theoretical_threads <= 8 ) + { + if ( k <= 24 ) + { + if ( k <= 12 ) + { + n_threads_ideal = 1; + } + else + { + if ( n <= 24 ) + { + if ( m <= 72 ) + { + n_threads_ideal = 1; + } + else + { + n_threads_ideal = 4; + } + } + else + { + n_threads_ideal = 4; + } + } + } + else + { + if ( n <= 12 ) + { + if ( k <= 48 ) + { + if ( n <= 8 ) + { + n_threads_ideal = 8; + } + else + { + n_threads_ideal = 8; + } + } + else + { + n_threads_ideal = 8; + } + } + else + { + if ( k <= 384 ) + { + if ( n <= 32 ) + { + n_threads_ideal = 4; + } + else + { + n_threads_ideal 
= 8; + } + } + else + { + if ( n <= 32 ) + { + n_threads_ideal = 8; + } + else + { + n_threads_ideal = 8; + } + } + } + } + } + else if ( theoretical_threads <= 16 ) + { + if ( k <= 192 ) + { + if ( k <= 12 ) + { + if ( n <= 16 ) + { + n_threads_ideal = 4; + } + else + { + n_threads_ideal = 4; + } + } + else + { + if ( k <= 24 ) + { + if ( m <= 240 ) + { + n_threads_ideal = 8; + } + else + { + n_threads_ideal = 8; + } + } + else + { + n_threads_ideal = 8; + } + } + } + else + { + if ( n <= 18 ) + { + if ( n <= 12 ) + { + if ( k <= 8192 ) + { + n_threads_ideal = 16; + } + else + { + n_threads_ideal = 16; + } + } + else + { + if ( m <= 96 ) + { + n_threads_ideal = 16; + } + else + { + n_threads_ideal = 16; + } + } + } + else + { + if ( k <= 384 ) + { + n_threads_ideal = 8; + } + else + { + if ( n <= 64 ) + { + n_threads_ideal = 16; + } + else + { + n_threads_ideal = 16; + } + } + } + } + } + else if ( theoretical_threads <= 32 ) + { + if ( k <= 96 ) + { + if ( k <= 12 ) + { + if ( n <= 12 ) + { + n_threads_ideal = 8; + } + else + { + if ( n <= 36 ) + { + n_threads_ideal = 8; + } + else + { + n_threads_ideal = 8; + } + } + } + else + { + n_threads_ideal = 8; + } + } + else + { + if ( n <= 20 ) + { + if ( k <= 192 ) + { + n_threads_ideal = 8; + } + else + { + if ( n <= 12 ) + { + n_threads_ideal = 32; + } + else + { + n_threads_ideal = 32; + } + } + } + else + { + if ( n <= 116 ) + { + if ( k <= 192 ) + { + n_threads_ideal = 8; + } + else + { + n_threads_ideal = 32; + } + } + else + { + n_threads_ideal = 16; + } + } + } + } + else if ( theoretical_threads <= 48 ) + { + if ( k <= 96 ) + { + if ( k <= 48 ) + { + n_threads_ideal = 8; + } + else + { + if ( n <= 24 ) + { + n_threads_ideal = 8; + } + else + { + n_threads_ideal = 8; + } + } + } + else + { + if ( n <= 8 ) + { + if ( k <= 4096 ) + { + if ( m <= 1032 ) + { + n_threads_ideal = 32; + } + else + { + n_threads_ideal = 32; + } + } + else + { + n_threads_ideal = 32; + } + } + else + { + if ( k <= 192 ) + { + if ( n <= 
88 ) + { + n_threads_ideal = 32; + } + else + { + n_threads_ideal = 32; + } + } + else + { + if ( k <= 384 ) + { + n_threads_ideal = 48; + } + else + { + n_threads_ideal = 48; + } + } + } + } + } + else if ( theoretical_threads <= 96 ) + { + if ( k <= 96 ) + { + if ( k <= 48 ) + { + if ( n <= 48 ) + { + if ( n <= 16 ) + { + n_threads_ideal = 16; + } + else + { + n_threads_ideal = 16; + } + } + else + { + n_threads_ideal = 16; + } + } + else + { + if ( n <= 144 ) + { + if ( n <= 12 ) + { + n_threads_ideal = 16; + } + else + { + n_threads_ideal = 32; + } + } + else + { + n_threads_ideal = 16; + } + } + } + else + { + if ( k <= 768 ) + { + if ( n <= 160 ) + { + if ( m <= 1152 ) + { + n_threads_ideal = 96; + } + else + { + n_threads_ideal = 48; + } + } + else + { + if ( n <= 294 ) + { + n_threads_ideal = 48; + } + else + { + n_threads_ideal = 48; + } + } + } + else + { + if ( m <= 1596 ) + { + if ( n <= 98 ) + { + n_threads_ideal = 96; + } + else + { + n_threads_ideal = 96; + } + } + else + { + n_threads_ideal = 48; + } + } + } + } + else if ( theoretical_threads <= 192 ) + { + if ( k <= 96 ) + { + if ( k <= 48 ) + { + if ( m <= 480 ) + { + if ( n <= 272 ) + { + n_threads_ideal = 32; + } + else + { + n_threads_ideal = 32; + } + } + else + { + if ( n <= 16 ) + { + n_threads_ideal = 32; + } + else + { + n_threads_ideal = 32; + } + } + } + else + { + if ( n <= 252 ) + { + if ( m <= 408 ) + { + n_threads_ideal = 32; + } + else + { + n_threads_ideal = 48; + } + } + else + { + n_threads_ideal = 32; + } + } + } + else + { + if ( k <= 768 ) + { + if ( k <= 192 ) + { + if ( n <= 12 ) + { + n_threads_ideal = 48; + } + else + { + n_threads_ideal = 48; + } + } + else + { + if ( k <= 384 ) + { + n_threads_ideal = 96; + } + else + { + n_threads_ideal = 96; + } + } + } + else + { + if ( m <= 2304 ) + { + if ( m <= 120 ) + { + n_threads_ideal = 96; + } + else + { + n_threads_ideal = 192; + } + } + else + { + if ( m <= 3936 ) + { + n_threads_ideal = 96; + } + else + { + n_threads_ideal 
= 96; + } + } + } + } + } + // In case the theoretical threads is greater than 192, we subject the inputs + // to a set of heuristics derived based on patterns in the inputs. + else + { + if ( k <= 96 ) + { + if ( k <= 48 ) + { + if ( m <= 480 ) + { + if ( n <= 268 ) + { + n_threads_ideal = 32; + } + else + { + n_threads_ideal = 32; + } + } + else + { + if ( n <= 16 ) + { + n_threads_ideal = 32; + } + else + { + n_threads_ideal = 32; + } + } + } + else + { + if ( n <= 252 ) + { + if ( m <= 408 ) + { + n_threads_ideal = 32; + } + else + { + n_threads_ideal = 48; + } + } + else + { + n_threads_ideal = 32; + } + } + } + else + { + if ( k <= 768 ) + { + if ( k <= 192 ) + { + if ( n <= 12 ) + { + n_threads_ideal = 48; + } + else + { + n_threads_ideal = 48; + } + } + else + { + if ( k <= 384 ) + { + n_threads_ideal = 96; + } + else + { + n_threads_ideal = 96; + } + } + } + else + { + if ( m <= 2304 ) + { + if ( m <= 120 ) + { + n_threads_ideal = 96; + } + else + { + n_threads_ideal = 192; + } + } + else + { + if ( m <= 3936 ) + { + n_threads_ideal = 96; + } + else + { + n_threads_ideal = 96; + } + } + } + } + } + } + } else if( family == BLIS_GEMM && bli_obj_is_dcomplex(c)) { dim_t m = bli_obj_length(c);