From aa9f5b8b3733928a2d47a12939ab85106146eb4b Mon Sep 17 00:00:00 2001 From: Kiran Varaganti Date: Thu, 27 May 2021 21:40:25 +0530 Subject: [PATCH] SUP GEMM - Enable only block panel (var2m) Completely disable the supvar1n (Panel Block) gemm variant to simplify things. supvar1n performs better only when m is very large (m >>) and n and k are both small (<10). This simplification will improve performance for dgemm with m = n shapes. Change-Id: I523fcb211e8ab92718ea7367f9707a38275e24b1 --- frame/3/bli_l3_sup_int.c | 96 +++++++++++++++++++++++++++++++++++----- 1 file changed, 84 insertions(+), 12 deletions(-) diff --git a/frame/3/bli_l3_sup_int.c b/frame/3/bli_l3_sup_int.c index 2869d76ec..8d2d91945 100644 --- a/frame/3/bli_l3_sup_int.c +++ b/frame/3/bli_l3_sup_int.c @@ -47,9 +47,20 @@ err_t bli_gemmsup_int ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4); -// AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_4, alpha, a, b, beta, c); -#if 0 +#ifdef BLIS_CONFIG_EPYC + const num_t dt = bli_obj_dt( c ); + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); + const bool auto_factor = bli_rntm_auto_factor( rntm ); + const dim_t n_threads = bli_rntm_num_threads( rntm ); + + dim_t jc_new; + dim_t ic_new; + + //bli_gemmsup_ref_var2 //bli_gemmsup_ref_var1 #if 0 @@ -61,23 +72,82 @@ err_t bli_gemmsup_int stor_id == BLIS_RRC || stor_id == BLIS_RCR || stor_id == BLIS_CRR ); + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var2m primary\n" ); + #endif + + // Don't use the small/unpacked implementation if one of the matrices + // uses general stride. 
+ if ( stor_id == BLIS_XXX ) { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_4, "SUP doesn't support general stide."); + return BLIS_FAILURE; + } + + if ( is_rrr_rrc_rcr_crr ) { - bli_gemmsup_ref_var2m - ( - BLIS_NO_TRANSPOSE, alpha, a, b, beta, c, stor_id, cntx, rntm - ); + // This branch handles: + // - rrr rrc rcr crr for row-preferential kernels + // - rcc crc ccr ccc for column-preferential kernels + // - Currently only row-preferential kernels are only supported. + + // calculate number of micropanels in m and n dimensions and + // recalculate the automatic thread factorization based on these number of micropanels + const dim_t mu = m / MR; + const dim_t nu = n / NR; + + // If the parallel thread factorization was automatic, we update it + // with a new factorization based on the matrix dimensions in units + // of micropanels. + if ( auto_factor ) + { + // In the block-panel algorithm, the m dimension is parallelized + // with ic_nt and the n dimension is parallelized with jc_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); + + // Update the ways of parallelism for the jc and ic loops, and then + // update the current thread's root thrinfo_t node according to the + // new ways of parallelism value for the jc loop. + bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); + bli_l3_sup_thrinfo_update_root( rntm, thread ); + } + + bli_gemmsup_ref_var2m( BLIS_NO_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); } else { - bli_gemmsup_ref_var2m - ( - BLIS_TRANSPOSE, alpha, a, b, beta, c, stor_id, cntx, rntm - ); + // This branch handles: + // - rrr rrc rcr crr for column-preferential kernels + // - rcc crc ccr ccc for row-preferential kernels + // - Currently only row-preferential kernels are only supported. 
+ const dim_t mu = n / MR; // the n becomes m after a transposition + const dim_t nu = m / NR; // the m becomes n after a transposition + + if ( auto_factor ) + { + // In the block-panel algorithm, the m dimension is parallelized + // with ic_nt and the n dimension is parallelized with jc_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); + + // Update the ways of parallelism for the jc and ic loops, and then + // update the current thread's root thrinfo_t node according to the + // new ways of parallelism value for the jc loop. + bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); + bli_l3_sup_thrinfo_update_root( rntm, thread ); + } + + bli_gemmsup_ref_var2m( BLIS_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4); return BLIS_SUCCESS; -#endif + +#else // #ifdef BLIS_CONFIG_EPYC const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); @@ -92,6 +162,7 @@ err_t bli_gemmsup_int stor_id == BLIS_RRC || stor_id == BLIS_RCR || stor_id == BLIS_CRR ); + const bool is_rcc_crc_ccr_ccc = !is_rrr_rrc_rcr_crr; const num_t dt = bli_obj_dt( c ); @@ -110,7 +181,6 @@ err_t bli_gemmsup_int dim_t jc_new; dim_t ic_new; - if ( is_primary ) { // This branch handles: @@ -252,6 +322,8 @@ err_t bli_gemmsup_int // Return success so that the caller knows that we computed the solution. AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) return BLIS_SUCCESS; + +#endif } // -----------------------------------------------------------------------------