From 1a3428ddfc106d1925bcd76f02de79d4e84babf3 Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Fri, 8 Apr 2022 13:19:34 +0530 Subject: [PATCH] Parallelization of dtrsm_small routine 1. Parallelized dtrsm_small across m-dimension or n-dimension based on side(Left/Right). 2. Fine-tuning with AOCL_DYNAMIC to achieve better performance. AMD-Internal: [CPUPL-2103] Change-Id: I6be6a2b579de7df9a3141e0d68bdf3e8a869a005 --- frame/base/bli_rntm.c | 15 +++- frame/compat/bla_trsm_amd.c | 41 ++++++++- kernels/zen/3/bli_trsm_small.c | 147 ++++++++++++++++++++++++++++++--- kernels/zen/bli_kernels_zen.h | 14 +++- 4 files changed, 201 insertions(+), 16 deletions(-) diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index f8e00c620..c15650e91 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -631,13 +631,22 @@ void bli_nthreads_optimum( else n_threads_ideal = n_threads; } - else if( family == BLIS_TRSM && bli_obj_is_double(c)) + else if( family == BLIS_TRSM && bli_obj_is_double(c) ) { dim_t m = bli_obj_length(c); dim_t n = bli_obj_width(c); - if(m<=512 && n<=512) - n_threads_ideal = 4; +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + if ( (m <= 300) && (n <= 300) ) + n_threads_ideal = 8; + else if ( (m <= 400) && (n <= 400) ) + n_threads_ideal = 16; + else if ( (m <= 900) && (n <= 900) ) + n_threads_ideal = 32; +#else + if ( (m <= 512) && (n <= 512) ) + n_threads_ideal = 4; +#endif } else if( family == BLIS_TRSM && bli_obj_is_dcomplex(c)) { diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index e1a2fffaf..3b3850928 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -395,7 +395,7 @@ void strsm_ ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) - AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'd', + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 's', *side, *uploa,*transa, *diaga, *m, *n, (void*)alpha,*lda, *ldb); @@ -886,8 +886,45 @@ void dtrsm_ return; } } -#endif + + //bli_trsm_small_mt is performing better than native multithread + //for certain sizes of m & n. +#ifdef BLIS_ENABLE_OPENMP + rntm_t rntm; + bli_rntm_init_from_global( &rntm ); + + // Query the total number of threads from the rntm_t object. + dim_t n_threads = bli_rntm_num_threads( &rntm ); + if ( ( (n_threads > 1) && (m0 <= 1500) && (n0 <= 1500) ) || + ( (n_threads == 32) && (m0 <= 2300) && (n0 <= 2300) ) || + ( (n_threads == 16) && (m0 <= 3800) && (n0 <= 3800) ) || + ( (n_threads == 8) && (m0 <= 2800) && (n0 <= 2800) ) || + ( (n_threads == 4) && (m0 <= 2000) && (n0 <= 2000) ) || + ( (n_threads == 2) && (m0 <= 2000) && (n0 <= 2000) ) ) + { + err_t status; + status = bli_trsm_small_mt + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + + if ( status == BLIS_SUCCESS ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } } +#endif// BLIS_ENABLE_OPENMP +#endif// END of BLIS_ENABLE_SMALL_MATRIX_TRSM + } + bli_trsmnat ( blis_side, diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index 07077010f..f8c0ea591 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -3821,15 +3821,22 @@ err_t bli_trsm_small num_t dt = bli_obj_dt(a); switch(dt) { - case BLIS_DOUBLE: - case BLIS_FLOAT: - case BLIS_SCOMPLEX: - { - if(m > 1000 || n > 1000) { + case BLIS_DOUBLE: + { + bool nt = bli_thread_get_is_parallel(); + if((nt == 0) && (m > 1000 || n > 1000)) { + return BLIS_NOT_YET_IMPLEMENTED; + } + break; + } + case BLIS_FLOAT: + case BLIS_SCOMPLEX: + { + if(m > 1000 || n > 1000) { return BLIS_NOT_YET_IMPLEMENTED; } break; - } + } case BLIS_DCOMPLEX: { if(m > 500 || n > 500) { @@ -3886,6 +3893,126 @@ err_t bli_trsm_small return err; }; +#ifdef BLIS_ENABLE_OPENMP +/* + * Parallelized dtrsm_small across m-dimension or n-dimension based on side(Left/Right) + */ + +err_t bli_trsm_small_mt +( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + rntm_t rntm; + gint_t m = bli_obj_length( b ); // number of rows of matrix b + gint_t n = bli_obj_width( b ); // number of columns of Matrix b + dim_t d_mr = 8,d_nr = 6; + + num_t dt = bli_obj_dt(a); + switch(dt) + { + case BLIS_DOUBLE: + { + d_mr = 8,d_nr = 6; + break; + } + default: + { + return BLIS_NOT_YET_IMPLEMENTED; + break; + } + } + + #ifdef AOCL_DYNAMIC + // If dynamic-threading is enabled, calculate optimum number + // of threads. + // rntm will be updated with optimum number of threads. + if( bli_obj_is_double(b)) + { + bli_nthreads_optimum(a, b, b, BLIS_TRSM, &rntm); + } + #endif + + bli_rntm_init_from_global( &rntm ); + + // Query the total number of threads from the rntm_t object. + dim_t n_threads = bli_rntm_num_threads( &rntm ); + + if (n_threads < 0 ) n_threads = 1; + + err_t status = BLIS_SUCCESS; + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Query the thread's id from OpenMP. + const dim_t tid = omp_get_thread_num(); + + obj_t b_t; + dim_t start; // Each thread start Index + dim_t end; // Each thread end Index + thrinfo_t thread; + + thread.n_way = n_threads; + thread.work_id = tid; + thread.ocomm_id = tid; + + + // Compute start and end indexes of matrix partitioning for each thread + if ( bli_is_right( side ) ) + { + bli_thread_range_sub ( &thread, + m, + d_mr,// Need to decide based on type + FALSE, + &start, + &end + ); + // For each thread acquire matrix block on which they operate + // Data-based parallelism + + bli_acquire_mpart_mdim(BLIS_FWD, BLIS_SUBPART1, start, end-start, b, &b_t); + } + else + { + bli_thread_range_sub ( &thread, + n, + d_nr,// Need to decide based on type + FALSE, + &start, + &end + ); + // For each thread acquire matrix block on which they operate + // Data-based parallelism + + bli_acquire_mpart_ndim(BLIS_FWD, BLIS_SUBPART1, start, end-start, b, &b_t); + } + + // Parallelism is only across m-dimension/n-dimension - therefore matrix a is common to + // all threads + err_t status_l = BLIS_SUCCESS; + + status_l = bli_trsm_small + ( + side, + alpha, + a, + &b_t, + NULL, + NULL + ); + // To capture the error populated from any of the threads + _Pragma( "omp critical" ) + status = (status != BLIS_NOT_YET_IMPLEMENTED)?status_l:status; + } + + return status; +}// End of function +#endif + /* * ZTRSM utilities and kernel functions */ @@ -6105,7 +6232,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha double* restrict L = a->buffer; //pointer to matrix A - double* restrict B = b->buffer; //pointer to matrix B + double *B = bli_obj_buffer_at_off(b); //pointer to matrix B double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks @@ -8565,7 +8692,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha double* restrict L = a->buffer; //pointer to matrix A - double* restrict B = b->buffer; //pointer to matrix B + double *B = bli_obj_buffer_at_off(b); //pointer to matrix B double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks @@ -10909,7 +11036,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B + double *B = bli_obj_buffer_at_off(b); //pointer to matrix B //pointers that point to blocks for GEMM and TRSM double *a10, *a11, *b01, *b11; @@ -12889,7 +13016,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B + double *B = bli_obj_buffer_at_off(b); //pointer to matrix B double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index ff97ca9ea..4bba0b22f 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -321,7 +321,7 @@ void bli_dgemm_ref_k1_nn double* c, const inc_t ldc ); - err_t bli_trsm_small +err_t bli_trsm_small ( side_t side, obj_t* alpha, @@ -331,6 +331,18 @@ void bli_dgemm_ref_k1_nn cntl_t* cntl ); +#ifdef BLIS_ENABLE_OPENMP +err_t bli_trsm_small_mt + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); +#endif + // threshold functions bool bli_cntx_gemmtsup_thresh_is_met_zen (