mirror of
https://github.com/amd/blis.git
synced 2026-05-11 17:50:00 +00:00
Parallelization of dtrsm_small routine
1. Parallelized dtrsm_small across the m-dimension (side = Right) or the n-dimension (side = Left).
2. Fine-tuning with AOCL_DYNAMIC to achieve better performance.
AMD-Internal: [CPUPL-2103]
Change-Id: I6be6a2b579de7df9a3141e0d68bdf3e8a869a005
This commit is contained in:
committed by
Dipal M Zambare
parent
8e6da6b844
commit
1a3428ddfc
@@ -631,13 +631,22 @@ void bli_nthreads_optimum(
|
||||
else
|
||||
n_threads_ideal = n_threads;
|
||||
}
|
||||
else if( family == BLIS_TRSM && bli_obj_is_double(c))
|
||||
else if( family == BLIS_TRSM && bli_obj_is_double(c) )
|
||||
{
|
||||
dim_t m = bli_obj_length(c);
|
||||
dim_t n = bli_obj_width(c);
|
||||
|
||||
if(m<=512 && n<=512)
|
||||
n_threads_ideal = 4;
|
||||
#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM
|
||||
if ( (m <= 300) && (n <= 300) )
|
||||
n_threads_ideal = 8;
|
||||
else if ( (m <= 400) && (n <= 400) )
|
||||
n_threads_ideal = 16;
|
||||
else if ( (m <= 900) && (n <= 900) )
|
||||
n_threads_ideal = 32;
|
||||
#else
|
||||
if ( (m <= 512) && (n <= 512) )
|
||||
n_threads_ideal = 4;
|
||||
#endif
|
||||
}
|
||||
else if( family == BLIS_TRSM && bli_obj_is_dcomplex(c))
|
||||
{
|
||||
|
||||
@@ -395,7 +395,7 @@ void strsm_
|
||||
)
|
||||
{
|
||||
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO)
|
||||
AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'd',
|
||||
AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 's',
|
||||
*side, *uploa,*transa, *diaga, *m, *n,
|
||||
(void*)alpha,*lda, *ldb);
|
||||
|
||||
@@ -886,8 +886,45 @@ void dtrsm_
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
//bli_trsm_small_mt is performing better than native multithread
|
||||
//for certain sizes of m & n.
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
rntm_t rntm;
|
||||
bli_rntm_init_from_global( &rntm );
|
||||
|
||||
// Query the total number of threads from the rntm_t object.
|
||||
dim_t n_threads = bli_rntm_num_threads( &rntm );
|
||||
if ( ( (n_threads > 1) && (m0 <= 1500) && (n0 <= 1500) ) ||
|
||||
( (n_threads == 32) && (m0 <= 2300) && (n0 <= 2300) ) ||
|
||||
( (n_threads == 16) && (m0 <= 3800) && (n0 <= 3800) ) ||
|
||||
( (n_threads == 8) && (m0 <= 2800) && (n0 <= 2800) ) ||
|
||||
( (n_threads == 4) && (m0 <= 2000) && (n0 <= 2000) ) ||
|
||||
( (n_threads == 2) && (m0 <= 2000) && (n0 <= 2000) ) )
|
||||
{
|
||||
err_t status;
|
||||
status = bli_trsm_small_mt
|
||||
(
|
||||
blis_side,
|
||||
&alphao,
|
||||
&ao,
|
||||
&bo,
|
||||
NULL,
|
||||
NULL
|
||||
);
|
||||
|
||||
if ( status == BLIS_SUCCESS )
|
||||
{
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
|
||||
/* Finalize BLIS. */
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif// BLIS_ENABLE_OPENMP
|
||||
#endif// END of BLIS_ENABLE_SMALL_MATRIX_TRSM
|
||||
}
|
||||
|
||||
bli_trsmnat
|
||||
(
|
||||
blis_side,
|
||||
|
||||
@@ -3821,15 +3821,22 @@ err_t bli_trsm_small
|
||||
num_t dt = bli_obj_dt(a);
|
||||
switch(dt)
|
||||
{
|
||||
case BLIS_DOUBLE:
|
||||
case BLIS_FLOAT:
|
||||
case BLIS_SCOMPLEX:
|
||||
{
|
||||
if(m > 1000 || n > 1000) {
|
||||
case BLIS_DOUBLE:
|
||||
{
|
||||
bool nt = bli_thread_get_is_parallel();
|
||||
if((nt == 0) && (m > 1000 || n > 1000)) {
|
||||
return BLIS_NOT_YET_IMPLEMENTED;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case BLIS_FLOAT:
|
||||
case BLIS_SCOMPLEX:
|
||||
{
|
||||
if(m > 1000 || n > 1000) {
|
||||
return BLIS_NOT_YET_IMPLEMENTED;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
case BLIS_DCOMPLEX:
|
||||
{
|
||||
if(m > 500 || n > 500) {
|
||||
@@ -3886,6 +3893,126 @@ err_t bli_trsm_small
|
||||
return err;
|
||||
};
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
/*
|
||||
* Parallelized dtrsm_small across m-dimension or n-dimension based on side(Left/Right)
|
||||
*/
|
||||
|
||||
err_t bli_trsm_small_mt
|
||||
(
|
||||
side_t side,
|
||||
obj_t* alpha,
|
||||
obj_t* a,
|
||||
obj_t* b,
|
||||
cntx_t* cntx,
|
||||
cntl_t* cntl
|
||||
)
|
||||
{
|
||||
rntm_t rntm;
|
||||
gint_t m = bli_obj_length( b ); // number of rows of matrix b
|
||||
gint_t n = bli_obj_width( b ); // number of columns of Matrix b
|
||||
dim_t d_mr = 8,d_nr = 6;
|
||||
|
||||
num_t dt = bli_obj_dt(a);
|
||||
switch(dt)
|
||||
{
|
||||
case BLIS_DOUBLE:
|
||||
{
|
||||
d_mr = 8,d_nr = 6;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
return BLIS_NOT_YET_IMPLEMENTED;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef AOCL_DYNAMIC
|
||||
// If dynamic-threading is enabled, calculate optimum number
|
||||
// of threads.
|
||||
// rntm will be updated with optimum number of threads.
|
||||
if( bli_obj_is_double(b))
|
||||
{
|
||||
bli_nthreads_optimum(a, b, b, BLIS_TRSM, &rntm);
|
||||
}
|
||||
#endif
|
||||
|
||||
bli_rntm_init_from_global( &rntm );
|
||||
|
||||
// Query the total number of threads from the rntm_t object.
|
||||
dim_t n_threads = bli_rntm_num_threads( &rntm );
|
||||
|
||||
if (n_threads < 0 ) n_threads = 1;
|
||||
|
||||
err_t status = BLIS_SUCCESS;
|
||||
_Pragma( "omp parallel num_threads(n_threads)" )
|
||||
{
|
||||
// Query the thread's id from OpenMP.
|
||||
const dim_t tid = omp_get_thread_num();
|
||||
|
||||
obj_t b_t;
|
||||
dim_t start; // Each thread start Index
|
||||
dim_t end; // Each thread end Index
|
||||
thrinfo_t thread;
|
||||
|
||||
thread.n_way = n_threads;
|
||||
thread.work_id = tid;
|
||||
thread.ocomm_id = tid;
|
||||
|
||||
|
||||
// Compute start and end indexes of matrix partitioning for each thread
|
||||
if ( bli_is_right( side ) )
|
||||
{
|
||||
bli_thread_range_sub ( &thread,
|
||||
m,
|
||||
d_mr,// Need to decide based on type
|
||||
FALSE,
|
||||
&start,
|
||||
&end
|
||||
);
|
||||
// For each thread acquire matrix block on which they operate
|
||||
// Data-based parallelism
|
||||
|
||||
bli_acquire_mpart_mdim(BLIS_FWD, BLIS_SUBPART1, start, end-start, b, &b_t);
|
||||
}
|
||||
else
|
||||
{
|
||||
bli_thread_range_sub ( &thread,
|
||||
n,
|
||||
d_nr,// Need to decide based on type
|
||||
FALSE,
|
||||
&start,
|
||||
&end
|
||||
);
|
||||
// For each thread acquire matrix block on which they operate
|
||||
// Data-based parallelism
|
||||
|
||||
bli_acquire_mpart_ndim(BLIS_FWD, BLIS_SUBPART1, start, end-start, b, &b_t);
|
||||
}
|
||||
|
||||
// Parallelism is only across m-dimension/n-dimension - therefore matrix a is common to
|
||||
// all threads
|
||||
err_t status_l = BLIS_SUCCESS;
|
||||
|
||||
status_l = bli_trsm_small
|
||||
(
|
||||
side,
|
||||
alpha,
|
||||
a,
|
||||
&b_t,
|
||||
NULL,
|
||||
NULL
|
||||
);
|
||||
// To capture the error populated from any of the threads
|
||||
_Pragma( "omp critical" )
|
||||
status = (status != BLIS_NOT_YET_IMPLEMENTED)?status_l:status;
|
||||
}
|
||||
|
||||
return status;
|
||||
}// End of function
|
||||
#endif
|
||||
|
||||
/*
|
||||
* ZTRSM utilities and kernel functions
|
||||
*/
|
||||
@@ -6105,7 +6232,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB
|
||||
|
||||
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
|
||||
double* restrict L = a->buffer; //pointer to matrix A
|
||||
double* restrict B = b->buffer; //pointer to matrix B
|
||||
double *B = bli_obj_buffer_at_off(b); //pointer to matrix B
|
||||
|
||||
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
|
||||
|
||||
@@ -8565,7 +8692,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB
|
||||
|
||||
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
|
||||
double* restrict L = a->buffer; //pointer to matrix A
|
||||
double* restrict B = b->buffer; //pointer to matrix B
|
||||
double *B = bli_obj_buffer_at_off(b); //pointer to matrix B
|
||||
|
||||
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
|
||||
|
||||
@@ -10909,7 +11036,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB
|
||||
|
||||
double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha
|
||||
double *L = a->buffer; //pointer to matrix A
|
||||
double *B = b->buffer; //pointer to matrix B
|
||||
double *B = bli_obj_buffer_at_off(b); //pointer to matrix B
|
||||
|
||||
//pointers that point to blocks for GEMM and TRSM
|
||||
double *a10, *a11, *b01, *b11;
|
||||
@@ -12889,7 +13016,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB
|
||||
|
||||
double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha
|
||||
double *L = a->buffer; //pointer to matrix A
|
||||
double *B = b->buffer; //pointer to matrix B
|
||||
double *B = bli_obj_buffer_at_off(b); //pointer to matrix B
|
||||
|
||||
double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM
|
||||
|
||||
|
||||
@@ -321,7 +321,7 @@ void bli_dgemm_ref_k1_nn
|
||||
double* c, const inc_t ldc
|
||||
);
|
||||
|
||||
err_t bli_trsm_small
|
||||
err_t bli_trsm_small
|
||||
(
|
||||
side_t side,
|
||||
obj_t* alpha,
|
||||
@@ -331,6 +331,18 @@ void bli_dgemm_ref_k1_nn
|
||||
cntl_t* cntl
|
||||
);
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
err_t bli_trsm_small_mt
|
||||
(
|
||||
side_t side,
|
||||
obj_t* alpha,
|
||||
obj_t* a,
|
||||
obj_t* b,
|
||||
cntx_t* cntx,
|
||||
cntl_t* cntl
|
||||
);
|
||||
#endif
|
||||
|
||||
// threshold functions
|
||||
bool bli_cntx_gemmtsup_thresh_is_met_zen
|
||||
(
|
||||
|
||||
Reference in New Issue
Block a user