From dfc95d29fc680a95e390e32f33538a78d1901a60 Mon Sep 17 00:00:00 2001 From: Shubham Date: Tue, 7 Mar 2023 02:51:33 +0530 Subject: [PATCH] Enable DTRSM small multithreading path for BLAS interface - Enabled the DTRSM small multithreaded (mt) path for sizes where it performs better than the single-threaded small path or the native path. - Updated the threshold tuning for the small path. - Made the function signature of bli_trsm_small_mt match that of bli_trsm_small so that a single function pointer type can be used for all small kernels. - Removed the early-return condition in DTRSM small for sizes > 1000 so that the decision of which sizes take the small path is made in the BLA layer instead of inside the kernel. AMD-Internal: [CPUPL-2735] Change-Id: Ieea31343dc660517acc18c92713381a8b84d3a2f --- frame/3/bli_l3_sup_ker_prot.h | 4 +- frame/compat/bla_trsm_amd.c | 124 ++++++++++++++--------- kernels/zen/3/bli_trsm_small.c | 134 ++++++++++++++----------- kernels/zen/bli_kernels_zen.h | 3 +- kernels/zen4/3/bli_trsm_small_AVX512.c | 20 ++-- kernels/zen4/bli_kernels_zen4.h | 14 ++- 6 files changed, 177 insertions(+), 122 deletions(-) diff --git a/frame/3/bli_l3_sup_ker_prot.h b/frame/3/bli_l3_sup_ker_prot.h index afa8f8ace..9643e04bd 100644 --- a/frame/3/bli_l3_sup_ker_prot.h +++ b/frame/3/bli_l3_sup_ker_prot.h @@ -56,7 +56,7 @@ void PASTEMAC(ch,opname) \ -#define TRSMSUP_PROT( opname ) \ +#define TRSMSMALL_PROT( opname ) \ \ err_t PASTEMAC0(opname) \ ( \ @@ -69,7 +69,7 @@ err_t PASTEMAC0(opname) \ bool is_parallel \ ); -#define TRSMSUP_KER_PROT( ch, opname ) \ +#define TRSMSMALL_KER_PROT( ch, opname ) \ \ BLIS_INLINE err_t PASTEMAC(ch,opname) \ ( \ diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index d59bb7bab..1e88c9536 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -941,81 +941,113 @@ void dtrsm_blis_impl bli_obj_set_struc( struca, &ao ); #ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + // This function is invoked on all architectures including ‘generic’. 
// Non-AVX platforms will use the kernels derived from the context. if (bli_cpuid_is_avx_supported() == TRUE) { - /* bli_dtrsm_small is performing better existing native - * implementations for [m,n]<=1000 for single thread. - * In case of multithread when [m,n]<=128 single thread implementation - * is doing better than native multithread */ + // typedef for trsm small kernel function pointer + typedef err_t (*dtrsm_small_ker_ft) + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl, + bool is_parallel + ); + err_t status = BLIS_NOT_YET_IMPLEMENTED; + + // trsm small kernel function pointer definition + dtrsm_small_ker_ft ker_ft = NULL; + + // Query the architecture ID + arch_t id = bli_arch_query_id(); + + // dimensions of triangular matrix + // for left variants, dim_a is m0, + // for right variants, dim_a is n0 + dim_t dim_a = n0; + if (blis_side == BLIS_LEFT) + dim_a = m0; + + // size of output matrix(B) + dim_t size_b = m0*n0; + + /* bli_dtrsm_small is performing better than existing native + * implementations for dim_a<1500 and m0*n0<5e6 for single thread. + * In case of multithread when [m+n]<320 single thread implementation + * is doing better than small multithread and native multithread */ bool is_parallel = bli_thread_get_is_parallel(); - if ((!is_parallel && m0<=1000 && n0<=1000) || + if ((!is_parallel && ((dim_a < 1500) && (size_b < 5e6)) ) || (is_parallel && (m0+n0)<320)) { - err_t status; - - // Query the architecture ID - arch_t id = bli_arch_query_id(); switch(id) { case BLIS_ARCH_ZEN4: #if defined(BLIS_KERNELS_ZEN4) // check if variant is RUN[N/U] or RLT[N/U] // this is a temporary fix, will be removed when all variants are added - - // for n < 200 avx2 kernels are performing better, but if - // n is a multiple of 8 then there will be no fringe case for avx512, - // in such cases avx512 kernels will perform better. 
if( (blis_side == BLIS_RIGHT) && ((n0 > 300) && (m0 > 50))) { - status = bli_trsm_small_AVX512( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL, - is_parallel); + ker_ft = bli_trsm_small_AVX512; } else { - status = bli_trsm_small( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL, - is_parallel); + ker_ft = bli_trsm_small; } break; -#endif - +#endif // BLIS_KERNELS_ZEN4 case BLIS_ARCH_ZEN: case BLIS_ARCH_ZEN2: case BLIS_ARCH_ZEN3: - status = bli_trsm_small( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL, - is_parallel); - break; default: - status = BLIS_NOT_YET_IMPLEMENTED; + ker_ft = bli_trsm_small; + break; } - if (status == BLIS_SUCCESS) + } + +#ifdef BLIS_ENABLE_OPENMP + if( (ker_ft == NULL) && (is_parallel) && + ((dim_a < 2500) && (size_b < 5e6)) ) + { + switch(id) { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; + case BLIS_ARCH_ZEN4: +#if defined(BLIS_KERNELS_ZEN4) + if( (blis_side == BLIS_RIGHT) ) + { + ker_ft = bli_trsm_small_mt_AVX512; + } + else + { + ker_ft = bli_trsm_small_mt; + } + break; +#endif// BLIS_KERNELS_ZEN4 + case BLIS_ARCH_ZEN: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN3: + default: + ker_ft = bli_trsm_small_mt; + break; } } + +#endif// BLIS_ENABLE_OPENMP + if(ker_ft) + { + status = ker_ft(blis_side, &alphao, &ao, &bo, NULL, NULL, is_parallel); + } + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. 
*/ + bli_finalize_auto(); + return; + } } // bli_cpuid_is_avx_supported #endif// END of BLIS_ENABLE_SMALL_MATRIX_TRSM diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index 12c0ee729..8cb7bb786 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -5264,9 +5264,6 @@ err_t bli_trsm_small { case BLIS_DOUBLE: { - if((!is_parallel) && (m > 1000 || n > 1000)) { - return BLIS_NOT_YET_IMPLEMENTED; - } break; } case BLIS_FLOAT: @@ -5345,7 +5342,8 @@ err_t bli_trsm_small_mt obj_t* a, obj_t* b, cntx_t* cntx, - cntl_t* cntl + cntl_t* cntl, + bool is_parallel ) { gint_t m = bli_obj_length( b ); // number of rows of matrix b @@ -5390,73 +5388,97 @@ err_t bli_trsm_small_mt if (n_threads < 0 ) n_threads = 1; - bool is_parallel = bli_thread_get_is_parallel(); - err_t status = BLIS_SUCCESS; _Pragma( "omp parallel num_threads(n_threads)" ) { // Query the thread's id from OpenMP. const dim_t tid = omp_get_thread_num(); + const dim_t nt_real = omp_get_num_threads(); - obj_t b_t; - dim_t start; // Each thread start Index - dim_t end; // Each thread end Index - thrinfo_t thread; - - thread.n_way = n_threads; - thread.work_id = tid; - thread.ocomm_id = tid; - - - // Compute start and end indexes of matrix partitioning for each thread - if ( bli_is_right( side ) ) + if(nt_real != n_threads) { - bli_thread_range_sub ( &thread, + if(tid == 0) + { + bli_trsm_small + ( + side, + alpha, + a, + b, + cntx, + cntl, + is_parallel + ); + } + } + else + { + obj_t b_t; + dim_t start; // Each thread start Index + dim_t end; // Each thread end Index + thrinfo_t thread; + + thread.n_way = n_threads; + thread.work_id = tid; + thread.ocomm_id = tid; + + + // Compute start and end indexes of matrix partitioning for each thread + if ( bli_is_right( side ) ) + { + bli_thread_range_sub + ( + &thread, m, d_mr,// Need to decide based on type FALSE, &start, &end - ); - // For each thread acquire matrix block on which they operate - // Data-based 
parallelism + ); + // For each thread acquire matrix block on which they operate + // Data-based parallelism - bli_acquire_mpart_mdim(BLIS_FWD, BLIS_SUBPART1, start, end-start, b, &b_t); + bli_acquire_mpart_mdim(BLIS_FWD, BLIS_SUBPART1, start, end-start, b, &b_t); + } + else + { + bli_thread_range_sub + ( + &thread, + n, + d_nr,// Need to decide based on type + FALSE, + &start, + &end + ); + // For each thread acquire matrix block on which they operate + // Data-based parallelism + + bli_acquire_mpart_ndim(BLIS_FWD, BLIS_SUBPART1, start, end-start, b, &b_t); + } + + // Parallelism is only across m-dimension/n-dimension - therefore matrix a is common to + // all threads + err_t status_l = BLIS_SUCCESS; + + status_l = bli_trsm_small + ( + side, + alpha, + a, + &b_t, + NULL, + NULL, + is_parallel + ); + // To capture the error populated from any of the threads + if ( status_l != BLIS_SUCCESS ) + { + _Pragma("omp critical") + status = (status != BLIS_NOT_YET_IMPLEMENTED) ? status_l : status; + } } - else - { - bli_thread_range_sub ( &thread, - n, - d_nr,// Need to decide based on type - FALSE, - &start, - &end - ); - // For each thread acquire matrix block on which they operate - // Data-based parallelism - - bli_acquire_mpart_ndim(BLIS_FWD, BLIS_SUBPART1, start, end-start, b, &b_t); - } - - // Parallelism is only across m-dimension/n-dimension - therefore matrix a is common to - // all threads - err_t status_l = BLIS_SUCCESS; - - status_l = bli_trsm_small - ( - side, - alpha, - a, - &b_t, - NULL, - NULL, - is_parallel - ); - // To capture the error populated from any of the threads - _Pragma( "omp critical" ) - status = (status != BLIS_NOT_YET_IMPLEMENTED)?status_l:status; } - return status; }// End of function #endif diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 3fc156a26..5c14b161a 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -345,7 +345,8 @@ err_t bli_trsm_small_mt obj_t* a, obj_t* b, cntx_t* 
cntx, - cntl_t* cntl + cntl_t* cntl, + bool is_parallel ); void bli_multi_sgemv_4x2 diff --git a/kernels/zen4/3/bli_trsm_small_AVX512.c b/kernels/zen4/3/bli_trsm_small_AVX512.c index 0a70ef5f0..b9c4bd3b4 100644 --- a/kernels/zen4/3/bli_trsm_small_AVX512.c +++ b/kernels/zen4/3/bli_trsm_small_AVX512.c @@ -416,8 +416,6 @@ err_t bli_trsm_small_AVX512 ) { err_t err; - dim_t m = bli_obj_length(b); - dim_t n = bli_obj_width(b); bool uplo = bli_obj_is_upper(a); bool transa = bli_obj_has_trans(a); @@ -427,10 +425,6 @@ err_t bli_trsm_small_AVX512 { case BLIS_DOUBLE: { - if ((!is_parallel) && (m > 1200 || n > 1200)) - { - return BLIS_NOT_YET_IMPLEMENTED; - } break; } case BLIS_FLOAT: @@ -490,7 +484,8 @@ err_t bli_trsm_small_mt_AVX512 obj_t* a, obj_t* b, cntx_t* cntx, - cntl_t* cntl + cntl_t* cntl, + bool is_parallel ) { gint_t m = bli_obj_length(b); // number of rows of matrix b @@ -531,8 +526,6 @@ err_t bli_trsm_small_mt_AVX512 if (n_threads < 0) n_threads = 1; - bool is_parallel = bli_thread_get_is_parallel(); - err_t status = BLIS_SUCCESS; _Pragma("omp parallel num_threads(n_threads)") { @@ -546,7 +539,7 @@ err_t bli_trsm_small_mt_AVX512 { if(tid == 0) { - bli_trsm_small + bli_trsm_small_AVX512 ( side, alpha, @@ -618,8 +611,11 @@ err_t bli_trsm_small_mt_AVX512 is_parallel ); // To capture the error populated from any of the threads - _Pragma("omp critical") - status = (status != BLIS_NOT_YET_IMPLEMENTED) ? status_l : status; + if ( status_l != BLIS_SUCCESS ) + { + _Pragma("omp critical") + status = (status != BLIS_NOT_YET_IMPLEMENTED) ? 
status_l : status; + } } } diff --git a/kernels/zen4/bli_kernels_zen4.h b/kernels/zen4/bli_kernels_zen4.h index 1abb9a8a1..fe29057ec 100644 --- a/kernels/zen4/bli_kernels_zen4.h +++ b/kernels/zen4/bli_kernels_zen4.h @@ -110,8 +110,12 @@ GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_3x32_avx512 ) GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_2x32_avx512 ) GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_1x32_avx512 ) -TRSMSUP_PROT(trsm_small_AVX512) -TRSMSUP_KER_PROT( d, trsm_small_AutXB_AlXB_AVX512 ) -TRSMSUP_KER_PROT( d, trsm_small_XAltB_XAuB_AVX512 ) -TRSMSUP_KER_PROT( d, trsm_small_XAutB_XAlB_AVX512 ) -TRSMSUP_KER_PROT( d, trsm_small_AltXB_AuXB_AVX512 ) +TRSMSMALL_PROT(trsm_small_AVX512) +TRSMSMALL_KER_PROT( d, trsm_small_AutXB_AlXB_AVX512 ) +TRSMSMALL_KER_PROT( d, trsm_small_XAltB_XAuB_AVX512 ) +TRSMSMALL_KER_PROT( d, trsm_small_XAutB_XAlB_AVX512 ) +TRSMSMALL_KER_PROT( d, trsm_small_AltXB_AuXB_AVX512 ) + +#ifdef BLIS_ENABLE_OPENMP +TRSMSMALL_PROT(trsm_small_mt_AVX512) +#endif \ No newline at end of file