Modified the BLAS interface of TRSM to call TRSV whenever m = 1 or n = 1.

TRSM API: solves AX = B (left side) or XA = B (right side); the solution X overwrites B.
  Case 1: Call TRSV when B is a vector and A is a matrix,
          i.e. when n = 1 for the left side and when m = 1 for the right side.
  Case 2: Divide B by A when B is a vector and A is a scalar (its diagonal element),
          i.e. when m = 1 for the left side and when n = 1 for the right side.
  For the right side, the complete operation is transposed; upper is changed to lower
                  and vice versa when A is transposed.

Change-Id: Ib020f2a568f04a6e8d8f75bfc38adbfd7c5d175a
This commit is contained in:
managalv
2021-02-10 05:39:24 +05:30
parent 1ff4981203
commit 8face536fd
2 changed files with 28 additions and 47 deletions

View File

@@ -187,29 +187,30 @@ void PASTEF77(ch,blasname) \
const inc_t cs_b = *ldb; \
const num_t dt = PASTEMAC(ch,type); \
\
/* ---------------------------------------------------------- */ \
/* CALL TRSV when C & B are vector and when A is Matrix */ \
/* Case 1: LEFT : TRSM, C(mxn) = A(mxm) * B(mxn) */ \
/* Case 2: RIGHT : TRSM, C(mxn) = B(mxn) * A(nxn) */ \
/* |--------|-------|-------|-------|-----------------------| */ \
/* | | C | A | B | Implementation | */ \
/* |--------|-------|-------|-------|-----------------------| */ \
/* | LEFT | mxn | mxm | mxn | | */ \
/* |--------|-------|-------|-------|-----------------------| */ \
/* | n = 1 | mx1 | mxm | mx1 | TRSV | */ \
/* | m = 1 | 1xn | 1x1 | 1xn | INVSCALS | */ \
/* |--------|-------|-------|-------|-----------------------| */ \
/* |--------|-------|-------|-------|-----------------------| */ \
/* | | C | B | A | Implementation | */ \
/* |--------|-------|-------|-------|-----------------------| */ \
/* | Right | mxn | mxn | nxn | | */ \
/* |--------|-------|-------|-------|-----------------------| */ \
/* | n = 1 | mx1 | mx1 | 1x1 | Transpose and INVSCALS| */ \
/* | m = 1 | 1xn | 1xn | nxn | Transpose and TRSV | */ \
/* |----------------|-------|-------|-----------------------| */ \
/* If Transpose(A) uplo = lower then uplo = higher */ \
/* If Transpose(A) uplo = higher then uplo = lower */ \
/* ---------------------------------------------------------- */ \
/* ----------------------------------------------------------- */ \
/* TRSM API: AX = B, where X = B */ \
/* CALL TRSV when X & B are vector and when A is Matrix */ \
/* Case 1: LEFT : TRSM, A(mxm) * X(mxn) = B(mxn) */ \
/* Case 2: RIGHT : TRSM, X(mxn) * A(nxn) = B(mxn) */ \
/* |--------|-------|-------|-------|------------------------| */ \
/* | | A | X | B | Implementation | */ \
/* |--------|-------|-------|-------|------------------------| */ \
/* | LEFT | mxm | mxn | mxn | | */ \
/* |--------|-------|-------|-------|------------------------| */ \
/* | n = 1 | mxm | mx1 | mx1 | TRSV | */ \
/* | m = 1 | 1x1 | 1xn | 1xn | INVSCALS | */ \
/* |--------|-------|-------|-------|------------------------| */ \
/* |--------|-------|-------|-------|------------------------| */ \
/* | | X | A | B | Implementation | */ \
/* |--------|-------|-------|-------|------------------------| */ \
/* | RIGHT | mxn | nxn | mxn | | */ \
/* |--------|-------|-------|-------|------------------------| */ \
/* | n = 1 | mx1 | 1x1 | mx1 | Transpose and INVSCALS| */ \
/* | m = 1 | 1xn | nxn | 1xn | Transpose and TRSV | */ \
/* |--------|-------|-------|-------|------------------------| */ \
/* If Transpose(A) uplo = lower then uplo = upper */ \
/* If Transpose(A) uplo = upper then uplo = lower */ \
/* ----------------------------------------------------------- */ \
\
if( n0 == 1 ) \
{ \

View File

@@ -79,12 +79,7 @@ void bli_sdotxf_zen_int_8
if ( bli_zero_dim1( m ) || PASTEMAC(s,eq0)( *alpha ) )
{
#ifdef BLIS_CONFIG_EPYC
sscalv_ker_ft f = bli_sscalv_zen_int10;
#else
sscalv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_SCALV_KER, cntx );
#endif
f
bli_sscalv_zen_int10
(
BLIS_NO_CONJUGATE,
b_n,
@@ -99,18 +94,13 @@ void bli_sdotxf_zen_int_8
// operation as a loop over dotxv.
if ( b_n != fuse_fac )
{
#ifdef BLIS_CONFIG_EPYC
sdotxv_ker_ft f = bli_sdotxv_zen_int;
#else
sdotxv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_DOTXV_KER, cntx );
#endif
for ( dim_t i = 0; i < b_n; ++i )
{
float* a1 = a + (0 )*inca + (i )*lda;
float* x1 = x + (0 )*incx;
float* psi1 = y + (i )*incy;
f
bli_sdotxv_zen_int
(
conjat,
conjx,
@@ -475,12 +465,7 @@ void bli_ddotxf_zen_int_8
// simplifies to updating y.
if ( bli_zero_dim1( m ) || PASTEMAC(d,eq0)( *alpha ) )
{
#ifdef BLIS_CONFIG_EPYC
dscalv_ker_ft f = bli_dscalv_zen_int10;
#else
dscalv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_SCALV_KER, cntx );
#endif
f
bli_dscalv_zen_int10
(
BLIS_NO_CONJUGATE,
b_n,
@@ -495,18 +480,13 @@ void bli_ddotxf_zen_int_8
// operation as a loop over dotxv.
if ( b_n != fuse_fac )
{
#ifdef BLIS_CONFIG_EPYC
ddotxv_ker_ft f = bli_ddotxv_zen_int;
#else
bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_DOTXV_KER, cntx );
#endif
for ( dim_t i = 0; i < b_n; ++i )
{
double* a1 = a + (0 )*inca + (i )*lda;
double* x1 = x + (0 )*incx;
double* psi1 = y + (i )*incy;
f
bli_ddotxv_zen_int
(
conjat,
conjx,