mirror of
https://github.com/amd/blis.git
synced 2026-05-11 09:39:59 +00:00
Enable DTRSM small multithreading path for BLAS interface
- Enabled DTRSM small mt for sizes where performance is better than small or native. - Threshold Tuning for small path is updated. - Function signature for bli_trsm_small_mt has been made similar to bli_trsm_small so that one function pointer can be used for all functions. - Early return condition in DTRSM small for sizes > 1000 has been removed so that the sizes for which small path to take can be decided on bla layer instead of inside kernel. AMD-Internal: [CPUPL-2735] Change-Id: Ieea31343dc660517acc18c92713381a8b84d3a2f
This commit is contained in:
@@ -56,7 +56,7 @@ void PASTEMAC(ch,opname) \
|
||||
|
||||
|
||||
|
||||
#define TRSMSUP_PROT( opname ) \
|
||||
#define TRSMSMALL_PROT( opname ) \
|
||||
\
|
||||
err_t PASTEMAC0(opname) \
|
||||
( \
|
||||
@@ -69,7 +69,7 @@ err_t PASTEMAC0(opname) \
|
||||
bool is_parallel \
|
||||
);
|
||||
|
||||
#define TRSMSUP_KER_PROT( ch, opname ) \
|
||||
#define TRSMSMALL_KER_PROT( ch, opname ) \
|
||||
\
|
||||
BLIS_INLINE err_t PASTEMAC(ch,opname) \
|
||||
( \
|
||||
|
||||
@@ -941,81 +941,113 @@ void dtrsm_blis_impl
|
||||
bli_obj_set_struc( struca, &ao );
|
||||
|
||||
#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM
|
||||
|
||||
// This function is invoked on all architectures including ‘generic’.
|
||||
// Non-AVX platforms will use the kernels derived from the context.
|
||||
if (bli_cpuid_is_avx_supported() == TRUE)
|
||||
{
|
||||
/* bli_dtrsm_small is performing better existing native
|
||||
* implementations for [m,n]<=1000 for single thread.
|
||||
* In case of multithread when [m,n]<=128 single thread implementation
|
||||
* is doing better than native multithread */
|
||||
// typedef for trsm small kernel function pointer
|
||||
typedef err_t (*dtrsm_small_ker_ft)
|
||||
(
|
||||
side_t side,
|
||||
obj_t* alpha,
|
||||
obj_t* a,
|
||||
obj_t* b,
|
||||
cntx_t* cntx,
|
||||
cntl_t* cntl,
|
||||
bool is_parallel
|
||||
);
|
||||
err_t status = BLIS_NOT_YET_IMPLEMENTED;
|
||||
|
||||
// trsm small kernel function pointer definition
|
||||
dtrsm_small_ker_ft ker_ft = NULL;
|
||||
|
||||
// Query the architecture ID
|
||||
arch_t id = bli_arch_query_id();
|
||||
|
||||
// dimensions of triangular matrix
|
||||
// for left variants, dim_a is m0,
|
||||
// for right variants, dim_a is n0
|
||||
dim_t dim_a = n0;
|
||||
if (blis_side == BLIS_LEFT)
|
||||
dim_a = m0;
|
||||
|
||||
// size of output matrix(B)
|
||||
dim_t size_b = m0*n0;
|
||||
|
||||
/* bli_dtrsm_small is performing better than existing native
|
||||
* implementations for dim_a<1500 and m0*n0<5e6 for single thread.
|
||||
* In case of multithread when [m+n]<320 single thread implementation
|
||||
* is doing better than small multithread and native multithread */
|
||||
bool is_parallel = bli_thread_get_is_parallel();
|
||||
if ((!is_parallel && m0<=1000 && n0<=1000) ||
|
||||
if ((!is_parallel && ((dim_a < 1500) && (size_b < 5e6)) ) ||
|
||||
(is_parallel && (m0+n0)<320))
|
||||
{
|
||||
err_t status;
|
||||
|
||||
// Query the architecture ID
|
||||
arch_t id = bli_arch_query_id();
|
||||
switch(id)
|
||||
{
|
||||
case BLIS_ARCH_ZEN4:
|
||||
#if defined(BLIS_KERNELS_ZEN4)
|
||||
// check if variant is RUN[N/U] or RLT[N/U]
|
||||
// this is a temporary fix, will be removed when all variants are added
|
||||
|
||||
// for n < 200 avx2 kernels are performing better, but if
|
||||
// n is a multiple of 8 then there will be no fringe case for avx512,
|
||||
// in such cases avx512 kernels will perform better.
|
||||
if( (blis_side == BLIS_RIGHT) &&
|
||||
((n0 > 300) && (m0 > 50)))
|
||||
{
|
||||
status = bli_trsm_small_AVX512(
|
||||
blis_side,
|
||||
&alphao,
|
||||
&ao,
|
||||
&bo,
|
||||
NULL,
|
||||
NULL,
|
||||
is_parallel);
|
||||
ker_ft = bli_trsm_small_AVX512;
|
||||
}
|
||||
else
|
||||
{
|
||||
status = bli_trsm_small(
|
||||
blis_side,
|
||||
&alphao,
|
||||
&ao,
|
||||
&bo,
|
||||
NULL,
|
||||
NULL,
|
||||
is_parallel);
|
||||
ker_ft = bli_trsm_small;
|
||||
}
|
||||
break;
|
||||
#endif
|
||||
|
||||
#endif // BLIS_KERNELS_ZEN4
|
||||
case BLIS_ARCH_ZEN:
|
||||
case BLIS_ARCH_ZEN2:
|
||||
case BLIS_ARCH_ZEN3:
|
||||
status = bli_trsm_small(
|
||||
blis_side,
|
||||
&alphao,
|
||||
&ao,
|
||||
&bo,
|
||||
NULL,
|
||||
NULL,
|
||||
is_parallel);
|
||||
break;
|
||||
default:
|
||||
status = BLIS_NOT_YET_IMPLEMENTED;
|
||||
ker_ft = bli_trsm_small;
|
||||
break;
|
||||
}
|
||||
if (status == BLIS_SUCCESS)
|
||||
}
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
if( (ker_ft == NULL) && (is_parallel) &&
|
||||
((dim_a < 2500) && (size_b < 5e6)) )
|
||||
{
|
||||
switch(id)
|
||||
{
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
|
||||
/* Finalize BLIS. */
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
case BLIS_ARCH_ZEN4:
|
||||
#if defined(BLIS_KERNELS_ZEN4)
|
||||
if( (blis_side == BLIS_RIGHT) )
|
||||
{
|
||||
ker_ft = bli_trsm_small_mt_AVX512;
|
||||
}
|
||||
else
|
||||
{
|
||||
ker_ft = bli_trsm_small_mt;
|
||||
}
|
||||
break;
|
||||
#endif// BLIS_KERNELS_ZEN4
|
||||
case BLIS_ARCH_ZEN:
|
||||
case BLIS_ARCH_ZEN2:
|
||||
case BLIS_ARCH_ZEN3:
|
||||
default:
|
||||
ker_ft = bli_trsm_small_mt;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#endif// BLIS_ENABLE_OPENMP
|
||||
if(ker_ft)
|
||||
{
|
||||
status = ker_ft(blis_side, &alphao, &ao, &bo, NULL, NULL, is_parallel);
|
||||
}
|
||||
if (status == BLIS_SUCCESS)
|
||||
{
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
|
||||
/* Finalize BLIS. */
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
}
|
||||
} // bli_cpuid_is_avx_supported
|
||||
#endif// END of BLIS_ENABLE_SMALL_MATRIX_TRSM
|
||||
|
||||
|
||||
@@ -5264,9 +5264,6 @@ err_t bli_trsm_small
|
||||
{
|
||||
case BLIS_DOUBLE:
|
||||
{
|
||||
if((!is_parallel) && (m > 1000 || n > 1000)) {
|
||||
return BLIS_NOT_YET_IMPLEMENTED;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case BLIS_FLOAT:
|
||||
@@ -5345,7 +5342,8 @@ err_t bli_trsm_small_mt
|
||||
obj_t* a,
|
||||
obj_t* b,
|
||||
cntx_t* cntx,
|
||||
cntl_t* cntl
|
||||
cntl_t* cntl,
|
||||
bool is_parallel
|
||||
)
|
||||
{
|
||||
gint_t m = bli_obj_length( b ); // number of rows of matrix b
|
||||
@@ -5390,73 +5388,97 @@ err_t bli_trsm_small_mt
|
||||
|
||||
if (n_threads < 0 ) n_threads = 1;
|
||||
|
||||
bool is_parallel = bli_thread_get_is_parallel();
|
||||
|
||||
err_t status = BLIS_SUCCESS;
|
||||
_Pragma( "omp parallel num_threads(n_threads)" )
|
||||
{
|
||||
// Query the thread's id from OpenMP.
|
||||
const dim_t tid = omp_get_thread_num();
|
||||
const dim_t nt_real = omp_get_num_threads();
|
||||
|
||||
obj_t b_t;
|
||||
dim_t start; // Each thread start Index
|
||||
dim_t end; // Each thread end Index
|
||||
thrinfo_t thread;
|
||||
|
||||
thread.n_way = n_threads;
|
||||
thread.work_id = tid;
|
||||
thread.ocomm_id = tid;
|
||||
|
||||
|
||||
// Compute start and end indexes of matrix partitioning for each thread
|
||||
if ( bli_is_right( side ) )
|
||||
if(nt_real != n_threads)
|
||||
{
|
||||
bli_thread_range_sub ( &thread,
|
||||
if(tid == 0)
|
||||
{
|
||||
bli_trsm_small
|
||||
(
|
||||
side,
|
||||
alpha,
|
||||
a,
|
||||
b,
|
||||
cntx,
|
||||
cntl,
|
||||
is_parallel
|
||||
);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
obj_t b_t;
|
||||
dim_t start; // Each thread start Index
|
||||
dim_t end; // Each thread end Index
|
||||
thrinfo_t thread;
|
||||
|
||||
thread.n_way = n_threads;
|
||||
thread.work_id = tid;
|
||||
thread.ocomm_id = tid;
|
||||
|
||||
|
||||
// Compute start and end indexes of matrix partitioning for each thread
|
||||
if ( bli_is_right( side ) )
|
||||
{
|
||||
bli_thread_range_sub
|
||||
(
|
||||
&thread,
|
||||
m,
|
||||
d_mr,// Need to decide based on type
|
||||
FALSE,
|
||||
&start,
|
||||
&end
|
||||
);
|
||||
// For each thread acquire matrix block on which they operate
|
||||
// Data-based parallelism
|
||||
);
|
||||
// For each thread acquire matrix block on which they operate
|
||||
// Data-based parallelism
|
||||
|
||||
bli_acquire_mpart_mdim(BLIS_FWD, BLIS_SUBPART1, start, end-start, b, &b_t);
|
||||
bli_acquire_mpart_mdim(BLIS_FWD, BLIS_SUBPART1, start, end-start, b, &b_t);
|
||||
}
|
||||
else
|
||||
{
|
||||
bli_thread_range_sub
|
||||
(
|
||||
&thread,
|
||||
n,
|
||||
d_nr,// Need to decide based on type
|
||||
FALSE,
|
||||
&start,
|
||||
&end
|
||||
);
|
||||
// For each thread acquire matrix block on which they operate
|
||||
// Data-based parallelism
|
||||
|
||||
bli_acquire_mpart_ndim(BLIS_FWD, BLIS_SUBPART1, start, end-start, b, &b_t);
|
||||
}
|
||||
|
||||
// Parallelism is only across m-dimension/n-dimension - therefore matrix a is common to
|
||||
// all threads
|
||||
err_t status_l = BLIS_SUCCESS;
|
||||
|
||||
status_l = bli_trsm_small
|
||||
(
|
||||
side,
|
||||
alpha,
|
||||
a,
|
||||
&b_t,
|
||||
NULL,
|
||||
NULL,
|
||||
is_parallel
|
||||
);
|
||||
// To capture the error populated from any of the threads
|
||||
if ( status_l != BLIS_SUCCESS )
|
||||
{
|
||||
_Pragma("omp critical")
|
||||
status = (status != BLIS_NOT_YET_IMPLEMENTED) ? status_l : status;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
bli_thread_range_sub ( &thread,
|
||||
n,
|
||||
d_nr,// Need to decide based on type
|
||||
FALSE,
|
||||
&start,
|
||||
&end
|
||||
);
|
||||
// For each thread acquire matrix block on which they operate
|
||||
// Data-based parallelism
|
||||
|
||||
bli_acquire_mpart_ndim(BLIS_FWD, BLIS_SUBPART1, start, end-start, b, &b_t);
|
||||
}
|
||||
|
||||
// Parallelism is only across m-dimension/n-dimension - therefore matrix a is common to
|
||||
// all threads
|
||||
err_t status_l = BLIS_SUCCESS;
|
||||
|
||||
status_l = bli_trsm_small
|
||||
(
|
||||
side,
|
||||
alpha,
|
||||
a,
|
||||
&b_t,
|
||||
NULL,
|
||||
NULL,
|
||||
is_parallel
|
||||
);
|
||||
// To capture the error populated from any of the threads
|
||||
_Pragma( "omp critical" )
|
||||
status = (status != BLIS_NOT_YET_IMPLEMENTED)?status_l:status;
|
||||
}
|
||||
|
||||
return status;
|
||||
}// End of function
|
||||
#endif
|
||||
|
||||
@@ -345,7 +345,8 @@ err_t bli_trsm_small_mt
|
||||
obj_t* a,
|
||||
obj_t* b,
|
||||
cntx_t* cntx,
|
||||
cntl_t* cntl
|
||||
cntl_t* cntl,
|
||||
bool is_parallel
|
||||
);
|
||||
|
||||
void bli_multi_sgemv_4x2
|
||||
|
||||
@@ -416,8 +416,6 @@ err_t bli_trsm_small_AVX512
|
||||
)
|
||||
{
|
||||
err_t err;
|
||||
dim_t m = bli_obj_length(b);
|
||||
dim_t n = bli_obj_width(b);
|
||||
|
||||
bool uplo = bli_obj_is_upper(a);
|
||||
bool transa = bli_obj_has_trans(a);
|
||||
@@ -427,10 +425,6 @@ err_t bli_trsm_small_AVX512
|
||||
{
|
||||
case BLIS_DOUBLE:
|
||||
{
|
||||
if ((!is_parallel) && (m > 1200 || n > 1200))
|
||||
{
|
||||
return BLIS_NOT_YET_IMPLEMENTED;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case BLIS_FLOAT:
|
||||
@@ -490,7 +484,8 @@ err_t bli_trsm_small_mt_AVX512
|
||||
obj_t* a,
|
||||
obj_t* b,
|
||||
cntx_t* cntx,
|
||||
cntl_t* cntl
|
||||
cntl_t* cntl,
|
||||
bool is_parallel
|
||||
)
|
||||
{
|
||||
gint_t m = bli_obj_length(b); // number of rows of matrix b
|
||||
@@ -531,8 +526,6 @@ err_t bli_trsm_small_mt_AVX512
|
||||
if (n_threads < 0)
|
||||
n_threads = 1;
|
||||
|
||||
bool is_parallel = bli_thread_get_is_parallel();
|
||||
|
||||
err_t status = BLIS_SUCCESS;
|
||||
_Pragma("omp parallel num_threads(n_threads)")
|
||||
{
|
||||
@@ -546,7 +539,7 @@ err_t bli_trsm_small_mt_AVX512
|
||||
{
|
||||
if(tid == 0)
|
||||
{
|
||||
bli_trsm_small
|
||||
bli_trsm_small_AVX512
|
||||
(
|
||||
side,
|
||||
alpha,
|
||||
@@ -618,8 +611,11 @@ err_t bli_trsm_small_mt_AVX512
|
||||
is_parallel
|
||||
);
|
||||
// To capture the error populated from any of the threads
|
||||
_Pragma("omp critical")
|
||||
status = (status != BLIS_NOT_YET_IMPLEMENTED) ? status_l : status;
|
||||
if ( status_l != BLIS_SUCCESS )
|
||||
{
|
||||
_Pragma("omp critical")
|
||||
status = (status != BLIS_NOT_YET_IMPLEMENTED) ? status_l : status;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -110,8 +110,12 @@ GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_3x32_avx512 )
|
||||
GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_2x32_avx512 )
|
||||
GEMMSUP_KER_PROT( float, s, gemmsup_rd_zen_asm_1x32_avx512 )
|
||||
|
||||
TRSMSUP_PROT(trsm_small_AVX512)
|
||||
TRSMSUP_KER_PROT( d, trsm_small_AutXB_AlXB_AVX512 )
|
||||
TRSMSUP_KER_PROT( d, trsm_small_XAltB_XAuB_AVX512 )
|
||||
TRSMSUP_KER_PROT( d, trsm_small_XAutB_XAlB_AVX512 )
|
||||
TRSMSUP_KER_PROT( d, trsm_small_AltXB_AuXB_AVX512 )
|
||||
TRSMSMALL_PROT(trsm_small_AVX512)
|
||||
TRSMSMALL_KER_PROT( d, trsm_small_AutXB_AlXB_AVX512 )
|
||||
TRSMSMALL_KER_PROT( d, trsm_small_XAltB_XAuB_AVX512 )
|
||||
TRSMSMALL_KER_PROT( d, trsm_small_XAutB_XAlB_AVX512 )
|
||||
TRSMSMALL_KER_PROT( d, trsm_small_AltXB_AuXB_AVX512 )
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
TRSMSMALL_PROT(trsm_small_mt_AVX512)
|
||||
#endif
|
||||
Reference in New Issue
Block a user