Enable DTRSM small multithreading path for BLAS interface

- Enabled the DTRSM small multithreaded path for sizes where its
  performance is better than the small (single-threaded) or native paths.
- Threshold tuning for the small path has been updated.
- Function signature for bli_trsm_small_mt has been made similar
  to bli_trsm_small so that one function pointer can be used for
  all functions.
- The early-return condition in DTRSM small for sizes > 1000 has been
  removed, so that the decision of which sizes take the small path is
  made at the BLAS layer instead of inside the kernel.

AMD-Internal: [CPUPL-2735]
Change-Id: Ieea31343dc660517acc18c92713381a8b84d3a2f
This commit is contained in:
Shubham
2023-03-07 02:51:33 +05:30
committed by Shubham Sharma
parent 873b4f93fd
commit dfc95d29fc
6 changed files with 177 additions and 122 deletions

View File

@@ -56,7 +56,7 @@ void PASTEMAC(ch,opname) \
#define TRSMSUP_PROT( opname ) \
#define TRSMSMALL_PROT( opname ) \
\
err_t PASTEMAC0(opname) \
( \
@@ -69,7 +69,7 @@ err_t PASTEMAC0(opname) \
bool is_parallel \
);
#define TRSMSUP_KER_PROT( ch, opname ) \
#define TRSMSMALL_KER_PROT( ch, opname ) \
\
BLIS_INLINE err_t PASTEMAC(ch,opname) \
( \

View File

@@ -941,81 +941,113 @@ void dtrsm_blis_impl
bli_obj_set_struc( struca, &ao );
#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM
// This function is invoked on all architectures including generic.
// Non-AVX platforms will use the kernels derived from the context.
if (bli_cpuid_is_avx_supported() == TRUE)
{
/* bli_dtrsm_small is performing better existing native
* implementations for [m,n]<=1000 for single thread.
* In case of multithread when [m,n]<=128 single thread implementation
* is doing better than native multithread */
// typedef for trsm small kernel function pointer
typedef err_t (*dtrsm_small_ker_ft)
(
side_t side,
obj_t* alpha,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl,
bool is_parallel
);
err_t status = BLIS_NOT_YET_IMPLEMENTED;
// trsm small kernel function pointer definition
dtrsm_small_ker_ft ker_ft = NULL;
// Query the architecture ID
arch_t id = bli_arch_query_id();
// dimensions of triangular matrix
// for left variants, dim_a is m0,
// for right variants, dim_a is n0
dim_t dim_a = n0;
if (blis_side == BLIS_LEFT)
dim_a = m0;
// size of output matrix(B)
dim_t size_b = m0*n0;
/* bli_dtrsm_small is performing better than existing native
* implementations for dim_a<1500 and m0*n0<5e6 for single thread.
* In case of multithread when [m+n]<320 single thread implementation
* is doing better than small multithread and native multithread */
bool is_parallel = bli_thread_get_is_parallel();
if ((!is_parallel && m0<=1000 && n0<=1000) ||
if ((!is_parallel && ((dim_a < 1500) && (size_b < 5e6)) ) ||
(is_parallel && (m0+n0)<320))
{
err_t status;
// Query the architecture ID
arch_t id = bli_arch_query_id();
switch(id)
{
case BLIS_ARCH_ZEN4:
#if defined(BLIS_KERNELS_ZEN4)
// check if variant is RUN[N/U] or RLT[N/U]
// this is a temporary fix, will be removed when all variants are added
// for n < 200 avx2 kernels are performing better, but if
// n is a multiple of 8 then there will be no fringe case for avx512,
// in such cases avx512 kernels will perform better.
if( (blis_side == BLIS_RIGHT) &&
((n0 > 300) && (m0 > 50)))
{
status = bli_trsm_small_AVX512(
blis_side,
&alphao,
&ao,
&bo,
NULL,
NULL,
is_parallel);
ker_ft = bli_trsm_small_AVX512;
}
else
{
status = bli_trsm_small(
blis_side,
&alphao,
&ao,
&bo,
NULL,
NULL,
is_parallel);
ker_ft = bli_trsm_small;
}
break;
#endif
#endif // BLIS_KERNELS_ZEN4
case BLIS_ARCH_ZEN:
case BLIS_ARCH_ZEN2:
case BLIS_ARCH_ZEN3:
status = bli_trsm_small(
blis_side,
&alphao,
&ao,
&bo,
NULL,
NULL,
is_parallel);
break;
default:
status = BLIS_NOT_YET_IMPLEMENTED;
ker_ft = bli_trsm_small;
break;
}
if (status == BLIS_SUCCESS)
}
#ifdef BLIS_ENABLE_OPENMP
if( (ker_ft == NULL) && (is_parallel) &&
((dim_a < 2500) && (size_b < 5e6)) )
{
switch(id)
{
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
/* Finalize BLIS. */
bli_finalize_auto();
return;
case BLIS_ARCH_ZEN4:
#if defined(BLIS_KERNELS_ZEN4)
if( (blis_side == BLIS_RIGHT) )
{
ker_ft = bli_trsm_small_mt_AVX512;
}
else
{
ker_ft = bli_trsm_small_mt;
}
break;
#endif// BLIS_KERNELS_ZEN4
case BLIS_ARCH_ZEN:
case BLIS_ARCH_ZEN2:
case BLIS_ARCH_ZEN3:
default:
ker_ft = bli_trsm_small_mt;
break;
}
}
#endif// BLIS_ENABLE_OPENMP
if(ker_ft)
{
status = ker_ft(blis_side, &alphao, &ao, &bo, NULL, NULL, is_parallel);
}
if (status == BLIS_SUCCESS)
{
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
/* Finalize BLIS. */
bli_finalize_auto();
return;
}
} // bli_cpuid_is_avx_supported
#endif// END of BLIS_ENABLE_SMALL_MATRIX_TRSM

View File

@@ -5264,9 +5264,6 @@ err_t bli_trsm_small
{
case BLIS_DOUBLE:
{
if((!is_parallel) && (m > 1000 || n > 1000)) {
return BLIS_NOT_YET_IMPLEMENTED;
}
break;
}
case BLIS_FLOAT:
@@ -5345,7 +5342,8 @@ err_t bli_trsm_small_mt
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
cntl_t* cntl,
bool is_parallel
)
{
gint_t m = bli_obj_length( b ); // number of rows of matrix b
@@ -5390,73 +5388,97 @@ err_t bli_trsm_small_mt
if (n_threads < 0 ) n_threads = 1;
bool is_parallel = bli_thread_get_is_parallel();
err_t status = BLIS_SUCCESS;
_Pragma( "omp parallel num_threads(n_threads)" )
{
// Query the thread's id from OpenMP.
const dim_t tid = omp_get_thread_num();
const dim_t nt_real = omp_get_num_threads();
obj_t b_t;
dim_t start; // Each thread start Index
dim_t end; // Each thread end Index
thrinfo_t thread;
thread.n_way = n_threads;
thread.work_id = tid;
thread.ocomm_id = tid;
// Compute start and end indexes of matrix partitioning for each thread
if ( bli_is_right( side ) )
if(nt_real != n_threads)
{
bli_thread_range_sub ( &thread,
if(tid == 0)
{
bli_trsm_small
(
side,
alpha,
a,
b,
cntx,
cntl,
is_parallel
);
}
}
else
{
obj_t b_t;
dim_t start; // Each thread start Index
dim_t end; // Each thread end Index
thrinfo_t thread;
thread.n_way = n_threads;
thread.work_id = tid;
thread.ocomm_id = tid;
// Compute start and end indexes of matrix partitioning for each thread
if ( bli_is_right( side ) )
{
bli_thread_range_sub
(
&thread,
m,
d_mr,// Need to decide based on type
FALSE,
&start,
&end
);
// For each thread acquire matrix block on which they operate
// Data-based parallelism
);
// For each thread acquire matrix block on which they operate
// Data-based parallelism
bli_acquire_mpart_mdim(BLIS_FWD, BLIS_SUBPART1, start, end-start, b, &b_t);
bli_acquire_mpart_mdim(BLIS_FWD, BLIS_SUBPART1, start, end-start, b, &b_t);
}
else
{
bli_thread_range_sub
(
&thread,
n,
d_nr,// Need to decide based on type
FALSE,
&start,
&end
);
// For each thread acquire matrix block on which they operate
// Data-based parallelism
bli_acquire_mpart_ndim(BLIS_FWD, BLIS_SUBPART1, start, end-start, b, &b_t);
}
// Parallelism is only across m-dimension/n-dimension - therefore matrix a is common to
// all threads
err_t status_l = BLIS_SUCCESS;
status_l = bli_trsm_small
(
side,
alpha,
a,
&b_t,
NULL,
NULL,
is_parallel
);
// To capture the error populated from any of the threads
if ( status_l != BLIS_SUCCESS )
{
_Pragma("omp critical")
status = (status != BLIS_NOT_YET_IMPLEMENTED) ? status_l : status;
}
}
else
{
bli_thread_range_sub ( &thread,
n,
d_nr,// Need to decide based on type
FALSE,
&start,
&end
);
// For each thread acquire matrix block on which they operate
// Data-based parallelism
bli_acquire_mpart_ndim(BLIS_FWD, BLIS_SUBPART1, start, end-start, b, &b_t);
}
// Parallelism is only across m-dimension/n-dimension - therefore matrix a is common to
// all threads
err_t status_l = BLIS_SUCCESS;
status_l = bli_trsm_small
(
side,
alpha,
a,
&b_t,
NULL,
NULL,
is_parallel
);
// To capture the error populated from any of the threads
_Pragma( "omp critical" )
status = (status != BLIS_NOT_YET_IMPLEMENTED)?status_l:status;
}
return status;
}// End of function
#endif

View File

@@ -345,7 +345,8 @@ err_t bli_trsm_small_mt
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
cntl_t* cntl,
bool is_parallel
);
void bli_multi_sgemv_4x2

View File

@@ -416,8 +416,6 @@ err_t bli_trsm_small_AVX512
)
{
err_t err;
dim_t m = bli_obj_length(b);
dim_t n = bli_obj_width(b);
bool uplo = bli_obj_is_upper(a);
bool transa = bli_obj_has_trans(a);
@@ -427,10 +425,6 @@ err_t bli_trsm_small_AVX512
{
case BLIS_DOUBLE:
{
if ((!is_parallel) && (m > 1200 || n > 1200))
{
return BLIS_NOT_YET_IMPLEMENTED;
}
break;
}
case BLIS_FLOAT:
@@ -490,7 +484,8 @@ err_t bli_trsm_small_mt_AVX512
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
cntl_t* cntl,
bool is_parallel
)
{
gint_t m = bli_obj_length(b); // number of rows of matrix b
@@ -531,8 +526,6 @@ err_t bli_trsm_small_mt_AVX512
if (n_threads < 0)
n_threads = 1;
bool is_parallel = bli_thread_get_is_parallel();
err_t status = BLIS_SUCCESS;
_Pragma("omp parallel num_threads(n_threads)")
{
@@ -546,7 +539,7 @@ err_t bli_trsm_small_mt_AVX512
{
if(tid == 0)
{
bli_trsm_small
bli_trsm_small_AVX512
(
side,
alpha,
@@ -618,8 +611,11 @@ err_t bli_trsm_small_mt_AVX512
is_parallel
);
// To capture the error populated from any of the threads
_Pragma("omp critical")
status = (status != BLIS_NOT_YET_IMPLEMENTED) ? status_l : status;
if ( status_l != BLIS_SUCCESS )
{
_Pragma("omp critical")
status = (status != BLIS_NOT_YET_IMPLEMENTED) ? status_l : status;
}
}
}

View File

@@ -110,8 +110,12 @@ GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_3x32_avx512 )
GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_2x32_avx512 )
GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_1x32_avx512 )
TRSMSUP_PROT(trsm_small_AVX512)
TRSMSUP_KER_PROT( d, trsm_small_AutXB_AlXB_AVX512 )
TRSMSUP_KER_PROT( d, trsm_small_XAltB_XAuB_AVX512 )
TRSMSUP_KER_PROT( d, trsm_small_XAutB_XAlB_AVX512 )
TRSMSUP_KER_PROT( d, trsm_small_AltXB_AuXB_AVX512 )
TRSMSMALL_PROT(trsm_small_AVX512)
TRSMSMALL_KER_PROT( d, trsm_small_AutXB_AlXB_AVX512 )
TRSMSMALL_KER_PROT( d, trsm_small_XAltB_XAuB_AVX512 )
TRSMSMALL_KER_PROT( d, trsm_small_XAutB_XAlB_AVX512 )
TRSMSMALL_KER_PROT( d, trsm_small_AltXB_AuXB_AVX512 )
#ifdef BLIS_ENABLE_OPENMP
TRSMSMALL_PROT(trsm_small_mt_AVX512)
#endif