STRSM small kernel implementation

Details:
-- AMD Internal Id: [CPUPL-1702]
-- Uses a 16x6 SGEMM kernel with vector FMA instructions on ymm registers
-- Packs matrix A so that it is cached and reused effectively
-- Implements the kernels using a macro-based, modular approach
-- Handles the --disable_pre_inversion configuration
-- Consolidates the 16 STRSM combinations of TRSM into 4 kernels

Change-Id: I30a1551967c36f6bae33be3b7ae5b7fcc7c905ea
This commit is contained in:
satish kumar nuggu
2021-09-24 20:51:22 +05:30
committed by Dipal M Zambare
parent a3d04a21a0
commit 23278627f4
2 changed files with 24227 additions and 5585 deletions

View File

@@ -381,6 +381,264 @@ void PASTEF77(ch,blasname) \
#ifdef BLIS_ENABLE_BLAS
#ifdef BLIS_CONFIG_EPYC
void strsm_
(
const f77_char* side,
const f77_char* uploa,
const f77_char* transa,
const f77_char* diaga,
const f77_int* m,
const f77_int* n,
const float* alpha,
const float* a, const f77_int* lda,
float* b, const f77_int* ldb
)
{
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO)
AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'd',
*side, *uploa,*transa, *diaga, *m, *n,
(void*)alpha,*lda, *ldb);
side_t blis_side;
uplo_t blis_uploa;
trans_t blis_transa;
diag_t blis_diaga;
dim_t m0, n0;
conj_t conja = BLIS_NO_CONJUGATE ;
/* Initialize BLIS. */
bli_init_auto();
/* Perform BLAS parameter checking. */
PASTEBLACHK(trsm)
(
MKSTR(s),
MKSTR(trsm),
side,
uploa,
transa,
diaga,
m,
n,
lda,
ldb
);
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_side( *side, &blis_side );
bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa );
bli_param_map_netlib_to_blis_trans( *transa, &blis_transa );
bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga );
/* Typecast BLAS integers to BLIS integers. */
bli_convert_blas_dim1( *m, m0 );
bli_convert_blas_dim1( *n, n0 );
/* Set the row and column strides of the matrix operands. */
const inc_t rs_a = 1;
const inc_t cs_a = *lda;
const inc_t rs_b = 1;
const inc_t cs_b = *ldb;
const num_t dt = BLIS_FLOAT;
if( n0 == 1 )
{
if( blis_side == BLIS_LEFT )
{
if(bli_is_notrans(blis_transa))
{
bli_strsv_unf_var2
(
blis_uploa,
blis_transa,
blis_diaga,
m0,
(float*)alpha,
(float*)a, rs_a, cs_a,
(float*)b, rs_b,
NULL
);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
return;
}
else if(bli_is_trans(blis_transa))
{
bli_strsv_unf_var1
(
blis_uploa,
blis_transa,
blis_diaga,
m0,
(float*)alpha,
(float*)a, rs_a, cs_a,
(float*)b, rs_b,
NULL
);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
return;
}
}
else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) )
{
/* b = alpha * b; */
bli_sscalv_ex
(
conja,
m0,
(float*)alpha,
b, rs_b,
NULL,
NULL
);
if(blis_diaga == BLIS_NONUNIT_DIAG)
{
float inva = 1.0/ *a;
for(int indx = 0; indx < m0; indx ++)
{
b[indx] = ( inva * b[indx] );
}
}
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
return;
}
}
else if( m0 == 1 )
{
if(blis_side == BLIS_RIGHT)
{
if(bli_is_notrans(blis_transa))
{
if(blis_uploa == BLIS_UPPER)
blis_uploa = BLIS_LOWER;
else
blis_uploa = BLIS_UPPER;
bli_strsv_unf_var1
(
blis_uploa,
blis_transa,
blis_diaga,
n0,
(float*)alpha,
(float*)a, cs_a, rs_a,
(float*)b, cs_b,
NULL
);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
return;
}
else if(bli_is_trans(blis_transa))
{
if(blis_uploa == BLIS_UPPER)
blis_uploa = BLIS_LOWER;
else
blis_uploa = BLIS_UPPER;
bli_strsv_unf_var2
(
blis_uploa,
blis_transa,
blis_diaga,
n0,
(float*)alpha,
(float*)a, cs_a, rs_a,
(float*)b, cs_b,
NULL
);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
return;
}
}
else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 ))
{
/* b = alpha * b; */
bli_sscalv_ex
(
conja,
n0,
(float*)alpha,
b, cs_b,
NULL,
NULL
);
if(blis_diaga == BLIS_NONUNIT_DIAG)
{
float inva = 1.0/ *a;
for(int indx = 0; indx < n0; indx ++)
{
b[indx*cs_b] = (inva * b[indx*cs_b] );
}
}
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
return;
}
}
const struc_t struca = BLIS_TRIANGULAR;
obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1;
obj_t ao = BLIS_OBJECT_INITIALIZER;
obj_t bo = BLIS_OBJECT_INITIALIZER;
dim_t mn0_a;
bli_set_dim_with_side( blis_side, m0, n0, &mn0_a );
bli_obj_init_finish_1x1( dt, (float*)alpha, &alphao );
bli_obj_init_finish( dt, mn0_a, mn0_a, (float*)a, rs_a, cs_a, &ao );
bli_obj_init_finish( dt, m0, n0, (float*)b, rs_b, cs_b, &bo );
bli_obj_set_uplo( blis_uploa, &ao );
bli_obj_set_diag( blis_diaga, &ao );
bli_obj_set_conjtrans( blis_transa, &ao );
bli_obj_set_struc( struca, &ao );
#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM
/* bli_strsm_small is performing better existing native
* implementations for [m,n]<=1000 for single thread.
* In case of multithread when [m,n]<=128 sinlge thread implemenation
* is doing better than native multithread */
bool nt = bli_thread_get_is_parallel();
if((nt==0 && m0<=1000 && n0<=1000) ||
(nt && (m0+n0)<320) )
{
err_t status;
status = bli_trsm_small
(
blis_side,
&alphao,
&ao,
&bo,
NULL,
NULL
);
if (status == BLIS_SUCCESS)
{
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
/* Finalize BLIS. */
bli_finalize_auto();
return;
}
}
#endif
bli_trsmnat
(
blis_side,
&alphao,
&ao,
&bo,
NULL,
NULL
);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO)
/* Finalize BLIS. */
bli_finalize_auto();
}
void dtrsm_
(
const f77_char* side,
@@ -662,7 +920,7 @@ void ztrsm_
trans_t blis_transa;
diag_t blis_diaga;
dim_t m0, n0;
conj_t conja = BLIS_NO_CONJUGATE ;
//conj_t conja = BLIS_NO_CONJUGATE ;
/* Initialize BLIS. */
bli_init_auto();
@@ -937,8 +1195,6 @@ void ztrsm_
bli_finalize_auto();
}
GENTFUNC( float, s, trsm, trsm )
GENTFUNC( scomplex, c, trsm, trsm )
#else
INSERT_GENTFUNC_BLAS( trsm, trsm )

File diff suppressed because it is too large Load Diff