diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index 25134c207..9ce8ba25e 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -1201,6 +1201,692 @@ void bli_nthreads_optimum( } } } + else if( family == BLIS_GEMM && bli_obj_is_scomplex( c ) ) + { + // Acquire the input dimensions for ideal thread selection + dim_t m = bli_obj_length(c); + dim_t n = bli_obj_width(c); + dim_t k = bli_obj_width_after_trans(a); + + // Query the architecture ID + arch_t id = bli_arch_query_id(); + if( id == BLIS_ARCH_ZEN5 || id == BLIS_ARCH_ZEN4 ) + { + /* + The logic for ideal thread selection is as follows: + Every GEMM kernel performs matrix multiplication(single-threaded) on an + MRxNR block of C, MRxk block of A, and kxNR block of B. + + Thus, the upper bound on the number of threads is ceil(m/MR) * ceil(n/NR). + This is because the framework will block the data into ceil(m/MR) panels along + the "m" direction and ceil(n/NR) panels along the "n" direction. + + In reality, the ideal number of threads could be lesser than ceil(m/MR) * ceil(n/NR), + based on the 'k' value. For small 'k', each MR×NR tile requires less computation per + data loaded (low arithmetic intensity), so memory bandwidth is a limiting factor. + Too many threads may saturate memory bandwidth, causing contention and reducing + efficiency. For large 'k', arithmetic intensity increases. Threads spend more time + computing per unit of data loaded, so you may be able to utilize more threads efficiently. + + Also, we have a candidate set of discrete thread values from which we choose the + best number of threads. That is, from the list of [1, 2, 4, 8, 16, 32, 48, 64, 96, 192]. + + Thus, for any value of (m,n,k), we first determine the upper bound(theoretical best), + based on MR and NR. We find the value in the thread list that is just greater than or + equal to the theoretical best. + Ex : theoretical_threads = 70 will be subjected to inspection inside the condition + "theoretical_threads <= 96". + + Inside this condition, the optimal numer of threads is decided among all the thread + values <= 96 in the list, based on patterns seen in 'k', 'm' and 'n'. + + P.S : This logic can further be experimented on, by considering a continuous list of + threads values rather than discrete. This way, the only limiting factor will be 'k'. + We could also experiment with different thread factorization methods, to achieve + optimal work distribution. + */ + // Set the kernel dimensions + dim_t MR = 24; + dim_t NR = 4; + // Calculate theoretical threads for constraint checking + dim_t theoretical_threads = ( ( m + MR - 1 ) / MR ) * ( ( n + NR - 1 ) / NR ); + + // Cascading constraint-based rules + if ( theoretical_threads <= 2 ) + { + if ( k <= 48 ) + { + n_threads_ideal = 1; + } + else + { + n_threads_ideal = 2; + } + } + else if ( theoretical_threads <= 4 ) + { + if ( k <= 24 ) + { + n_threads_ideal = 1; + } + else + { + if ( n <= 12 ) + { + if ( k <= 48 ) + { + n_threads_ideal = 1; + } + else + { + n_threads_ideal = 4; + } + } + else + { + if ( n <= 16 ) + { + if ( k <= 96 ) + { + n_threads_ideal = 2; + } + else + { + n_threads_ideal = 4; + } + } + else + { + n_threads_ideal = 4; + } + } + } + } + else if ( theoretical_threads <= 8 ) + { + if ( k <= 24 ) + { + if ( k <= 12 ) + { + n_threads_ideal = 1; + } + else + { + if ( n <= 24 ) + { + if ( m <= 72 ) + { + n_threads_ideal = 1; + } + else + { + n_threads_ideal = 4; + } + } + else + { + n_threads_ideal = 4; + } + } + } + else + { + if ( n <= 12 ) + { + if ( k <= 48 ) + { + if ( n <= 8 ) + { + n_threads_ideal = 8; + } + else + { + n_threads_ideal = 8; + } + } + else + { + n_threads_ideal = 8; + } + } + else + { + if ( k <= 384 ) + { + if ( n <= 32 ) + { + n_threads_ideal = 4; + } + else + { + n_threads_ideal = 8; + } + } + else + { + if ( n <= 32 ) + { + n_threads_ideal = 8; + } + else + { + n_threads_ideal = 8; + } + } + } + } + } + else if ( theoretical_threads <= 16 ) + { + if ( k <= 192 ) + { + if ( k <= 12 ) + { + if ( n <= 16 ) + { + n_threads_ideal = 4; + } + else + { + n_threads_ideal = 4; + } + } + else + { + if ( k <= 24 ) + { + if ( m <= 240 ) + { + n_threads_ideal = 8; + } + else + { + n_threads_ideal = 8; + } + } + else + { + n_threads_ideal = 8; + } + } + } + else + { + if ( n <= 18 ) + { + if ( n <= 12 ) + { + if ( k <= 8192 ) + { + n_threads_ideal = 16; + } + else + { + n_threads_ideal = 16; + } + } + else + { + if ( m <= 96 ) + { + n_threads_ideal = 16; + } + else + { + n_threads_ideal = 16; + } + } + } + else + { + if ( k <= 384 ) + { + n_threads_ideal = 8; + } + else + { + if ( n <= 64 ) + { + n_threads_ideal = 16; + } + else + { + n_threads_ideal = 16; + } + } + } + } + } + else if ( theoretical_threads <= 32 ) + { + if ( k <= 96 ) + { + if ( k <= 12 ) + { + if ( n <= 12 ) + { + n_threads_ideal = 8; + } + else + { + if ( n <= 36 ) + { + n_threads_ideal = 8; + } + else + { + n_threads_ideal = 8; + } + } + } + else + { + n_threads_ideal = 8; + } + } + else + { + if ( n <= 20 ) + { + if ( k <= 192 ) + { + n_threads_ideal = 8; + } + else + { + if ( n <= 12 ) + { + n_threads_ideal = 32; + } + else + { + n_threads_ideal = 32; + } + } + } + else + { + if ( n <= 116 ) + { + if ( k <= 192 ) + { + n_threads_ideal = 8; + } + else + { + n_threads_ideal = 32; + } + } + else + { + n_threads_ideal = 16; + } + } + } + } + else if ( theoretical_threads <= 48 ) + { + if ( k <= 96 ) + { + if ( k <= 48 ) + { + n_threads_ideal = 8; + } + else + { + if ( n <= 24 ) + { + n_threads_ideal = 8; + } + else + { + n_threads_ideal = 8; + } + } + } + else + { + if ( n <= 8 ) + { + if ( k <= 4096 ) + { + if ( m <= 1032 ) + { + n_threads_ideal = 32; + } + else + { + n_threads_ideal = 32; + } + } + else + { + n_threads_ideal = 32; + } + } + else + { + if ( k <= 192 ) + { + if ( n <= 88 ) + { + n_threads_ideal = 32; + } + else + { + n_threads_ideal = 32; + } + } + else + { + if ( k <= 384 ) + { + n_threads_ideal = 48; + } + else + { + n_threads_ideal = 48; + } + } + } + } + } + else if ( theoretical_threads <= 96 ) + { + if ( k <= 96 ) + { + if ( k <= 48 ) + { + if ( n <= 48 ) + { + if ( n <= 16 ) + { + n_threads_ideal = 16; + } + else + { + n_threads_ideal = 16; + } + } + else + { + n_threads_ideal = 16; + } + } + else + { + if ( n <= 144 ) + { + if ( n <= 12 ) + { + n_threads_ideal = 16; + } + else + { + n_threads_ideal = 32; + } + } + else + { + n_threads_ideal = 16; + } + } + } + else + { + if ( k <= 768 ) + { + if ( n <= 160 ) + { + if ( m <= 1152 ) + { + n_threads_ideal = 96; + } + else + { + n_threads_ideal = 48; + } + } + else + { + if ( n <= 294 ) + { + n_threads_ideal = 48; + } + else + { + n_threads_ideal = 48; + } + } + } + else + { + if ( m <= 1596 ) + { + if ( n <= 98 ) + { + n_threads_ideal = 96; + } + else + { + n_threads_ideal = 96; + } + } + else + { + n_threads_ideal = 48; + } + } + } + } + else if ( theoretical_threads <= 192 ) + { + if ( k <= 96 ) + { + if ( k <= 48 ) + { + if ( m <= 480 ) + { + if ( n <= 272 ) + { + n_threads_ideal = 32; + } + else + { + n_threads_ideal = 32; + } + } + else + { + if ( n <= 16 ) + { + n_threads_ideal = 32; + } + else + { + n_threads_ideal = 32; + } + } + } + else + { + if ( n <= 252 ) + { + if ( m <= 408 ) + { + n_threads_ideal = 32; + } + else + { + n_threads_ideal = 48; + } + } + else + { + n_threads_ideal = 32; + } + } + } + else + { + if ( k <= 768 ) + { + if ( k <= 192 ) + { + if ( n <= 12 ) + { + n_threads_ideal = 48; + } + else + { + n_threads_ideal = 48; + } + } + else + { + if ( k <= 384 ) + { + n_threads_ideal = 96; + } + else + { + n_threads_ideal = 96; + } + } + } + else + { + if ( m <= 2304 ) + { + if ( m <= 120 ) + { + n_threads_ideal = 96; + } + else + { + n_threads_ideal = 192; + } + } + else + { + if ( m <= 3936 ) + { + n_threads_ideal = 96; + } + else + { + n_threads_ideal = 96; + } + } + } + } + } + // In case the theoretical threads is greater than 192, we subject the inputs + // to a set of heuristics derived based on patterns in the inputs. + else + { + if ( k <= 96 ) + { + if ( k <= 48 ) + { + if ( m <= 480 ) + { + if ( n <= 268 ) + { + n_threads_ideal = 32; + } + else + { + n_threads_ideal = 32; + } + } + else + { + if ( n <= 16 ) + { + n_threads_ideal = 32; + } + else + { + n_threads_ideal = 32; + } + } + } + else + { + if ( n <= 252 ) + { + if ( m <= 408 ) + { + n_threads_ideal = 32; + } + else + { + n_threads_ideal = 48; + } + } + else + { + n_threads_ideal = 32; + } + } + } + else + { + if ( k <= 768 ) + { + if ( k <= 192 ) + { + if ( n <= 12 ) + { + n_threads_ideal = 48; + } + else + { + n_threads_ideal = 48; + } + } + else + { + if ( k <= 384 ) + { + n_threads_ideal = 96; + } + else + { + n_threads_ideal = 96; + } + } + } + else + { + if ( m <= 2304 ) + { + if ( m <= 120 ) + { + n_threads_ideal = 96; + } + else + { + n_threads_ideal = 192; + } + } + else + { + if ( m <= 3936 ) + { + n_threads_ideal = 96; + } + else + { + n_threads_ideal = 96; + } + } + } + } + } + } + } else if( family == BLIS_GEMM && bli_obj_is_dcomplex(c)) { dim_t m = bli_obj_length(c);