From 8face536fd97d0633e96cd49123891e1bbe4caf4 Mon Sep 17 00:00:00 2001 From: managalv Date: Wed, 10 Feb 2021 05:39:24 +0530 Subject: [PATCH] Modified blas interface of TRSM to call TRSV whenever m=1 or n=1. TRSM API: AX = B, where X=B Case1: Call TRSV when matrix B is vector & A is matrix, When n = 1 for left side and when m = 1 for right side Case2: Divide B/A when matrix B is vector & A is scalar(Diagonal element), When m = 1 for left side and when n = 1 for right side For right side, Transpose complete operation, Change upper to lower and vice versa when A is being transposed Change-Id: Ib020f2a568f04a6e8d8f75bfc38adbfd7c5d175a --- frame/compat/bla_trsm.c | 47 ++++++++++++++-------------- kernels/zen/1f/bli_dotxf_zen_int_8.c | 28 +++-------------- 2 files changed, 28 insertions(+), 47 deletions(-) diff --git a/frame/compat/bla_trsm.c b/frame/compat/bla_trsm.c index 8c8ff39be..fff59a351 100644 --- a/frame/compat/bla_trsm.c +++ b/frame/compat/bla_trsm.c @@ -187,29 +187,30 @@ void PASTEF77(ch,blasname) \ const inc_t cs_b = *ldb; \ const num_t dt = PASTEMAC(ch,type); \ \ - /* ---------------------------------------------------------- */ \ - /* CALL TRSV when C & B are vector and when A is Matrix */ \ - /* Case 1: LEFT : TRSM, C(mxn) = A(mxm) * B(mxn) */ \ - /* Case 2: RIGHT : TRSM, C(mxn) = B(mxn) * A(nxn) */ \ - /* |--------|-------|-------|-------|-----------------------| */ \ - /* | | C | A | B | Implementation | */ \ - /* |--------|-------|-------|-------|-----------------------| */ \ - /* | LEFT | mxn | mxm | mxn | | */ \ - /* |--------|-------|-------|-------|-----------------------| */ \ - /* | n = 1 | mx1 | mxm | mx1 | TRSV | */ \ - /* | m = 1 | 1xn | 1x1 | 1xn | INVSCALS | */ \ - /* |--------|-------|-------|-------|-----------------------| */ \ - /* |--------|-------|-------|-------|-----------------------| */ \ - /* | | C | B | A | Implementation | */ \ - /* |--------|-------|-------|-------|-----------------------| */ \ - /* | Right | mxn | mxn | nxn | | */ \ - /* |--------|-------|-------|-------|-----------------------| */ \ - /* | n = 1 | mx1 | mx1 | 1x1 | Transpose and INVSCALS| */ \ - /* | m = 1 | 1xn | 1xn | nxn | Transpose and TRSV | */ \ - /* |----------------|-------|-------|-----------------------| */ \ - /* If Transpose(A) uplo = lower then uplo = higher */ \ - /* If Transpose(A) uplo = higher then uplo = lower */ \ - /* ---------------------------------------------------------- */ \ + /* ----------------------------------------------------------- */ \ + /* TRSM API: AX = B, where X = B */ \ + /* CALL TRSV when X & B are vector and when A is Matrix */ \ + /* Case 1: LEFT : TRSM, C(mxn) = A(mxm) * B(mxn) */ \ + /* Case 2: RIGHT : TRSM, C(mxn) = B(mxn) * A(nxn) */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | | A | X | B | Implementation | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | LEFT | mxm | mxn | mxn | | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | n = 1 | mxm | mx1 | mx1 | TRSV | */ \ + /* | m = 1 | 1x1 | 1xn | 1xn | INVSCALS | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | | X | A | B | Implementation | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | RIGHT | mxn | nxn | mxn | | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | n = 1 | mx1 | 1x1 | mx1 | Transpose and INVSCALS| */ \ + /* | m = 1 | 1xn | nxn | 1xn | Transpose and TRSV | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* If Transpose(A) uplo = lower then uplo = higher */ \ + /* If Transpose(A) uplo = higher then uplo = lower */ \ + /* ----------------------------------------------------------- */ \ \ if( n0 == 1 ) \ { \ diff --git a/kernels/zen/1f/bli_dotxf_zen_int_8.c b/kernels/zen/1f/bli_dotxf_zen_int_8.c index 573ca71b2..c566cb436 100644 --- a/kernels/zen/1f/bli_dotxf_zen_int_8.c +++ b/kernels/zen/1f/bli_dotxf_zen_int_8.c @@ -79,12 +79,7 @@ void bli_sdotxf_zen_int_8 if ( bli_zero_dim1( m ) || PASTEMAC(s,eq0)( *alpha ) ) { -#ifdef BLIS_CONFIG_EPYC - sscalv_ker_ft f = bli_sscalv_zen_int10; -#else - sscalv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_SCALV_KER, cntx ); -#endif - f + bli_sscalv_zen_int10 ( BLIS_NO_CONJUGATE, b_n, @@ -99,18 +94,13 @@ void bli_sdotxf_zen_int_8 // operation as a loop over dotxv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - sdotxv_ker_ft f = bli_sdotxv_zen_int; -#else - sdotxv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_DOTXV_KER, cntx ); -#endif for ( dim_t i = 0; i < b_n; ++i ) { float* a1 = a + (0 )*inca + (i )*lda; float* x1 = x + (0 )*incx; float* psi1 = y + (i )*incy; - f + bli_sdotxv_zen_int ( conjat, conjx, @@ -475,12 +465,7 @@ void bli_ddotxf_zen_int_8 // simplifies to updating y. if ( bli_zero_dim1( m ) || PASTEMAC(d,eq0)( *alpha ) ) { -#ifdef BLIS_CONFIG_EPYC - dscalv_ker_ft f = bli_dscalv_zen_int10; -#else - dscalv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_SCALV_KER, cntx ); -#endif - f + bli_dscalv_zen_int10 ( BLIS_NO_CONJUGATE, b_n, @@ -495,18 +480,13 @@ void bli_ddotxf_zen_int_8 // operation as a loop over dotxv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - ddotxv_ker_ft f = bli_ddotxv_zen_int; -#else - bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_DOTXV_KER, cntx ); -#endif for ( dim_t i = 0; i < b_n; ++i ) { double* a1 = a + (0 )*inca + (i )*lda; double* x1 = x + (0 )*incx; double* psi1 = y + (i )*incy; - f + bli_ddotxv_zen_int ( conjat, conjx,