From 09b70de6352ea5fa3b621d4b60320e37d9cd08a9 Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Fri, 29 Apr 2022 17:13:29 +0530 Subject: [PATCH] Performance Improvement for ztrsm small sizes Details: - Handled Overflow and Underflow Vulnerabilities in ztrsm small right implementations. - Fixed failures observed in Scalapack testing. AMD-Internal: [CPUPL-2115] Change-Id: I22c1ba583e0ba14d1a4684a85fa1ca6e152e8439 --- kernels/zen/3/bli_trsm_small.c | 121 ++++++++++----------------------- 1 file changed, 35 insertions(+), 86 deletions(-) diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index f8c0ea591..d7192a062 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -34922,38 +34922,21 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB { if(transa) { - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (a11+cs_a*1 + 1)); + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a, + d11_pack,n_remainder); } else { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (a11+rs_a*1 + 1)); + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a, + d11_pack,n_remainder); } - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - /*Taking denomerator multiplication of real & - * imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); -#endif } else { ymm1 = _mm256_broadcast_pd((__m128d const*)&ones); + _mm256_storeu_pd((double *)(d11_pack), ymm1); } - _mm256_storeu_pd((double 
*)(d11_pack), ymm1); + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction { a01 = D_A_pack; @@ -35340,30 +35323,23 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB } if(!is_unitdiag) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - /*Taking denomerator multiplication of real & - * imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); -#endif + if(transa) + { + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a, + d11_pack,n_remainder); + } + else + { + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a, + d11_pack,n_remainder); + } } else { ymm1 = _mm256_broadcast_pd((__m128d const*)&ones); + _mm256_storeu_pd((double *)(d11_pack), ymm1); } - _mm256_storeu_pd((double *)(d11_pack), ymm1); + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction { a01 = D_A_pack; @@ -36374,39 +36350,20 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB { if(transa) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (a11+cs_a*1 + 1)); + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a, + d11_pack,n_remainder); } else { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (a11+rs_a*1 + 1)); + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a, + d11_pack,n_remainder); } - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); -#ifdef 
BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - /*Taking denomerator multiplication of real & - * imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); -#endif } else { ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); + _mm256_storeu_pd((double *)(d11_pack), ymm1); } - _mm256_storeu_pd((double *)(d11_pack), ymm1); for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction { @@ -36793,30 +36750,22 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB } if(!is_unitdiag) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - /*Taking denomerator multiplication of real & - * imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); -#endif + if(transa) + { + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a, + d11_pack,n_remainder); + } + else + { + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a, + d11_pack,n_remainder); + } } else { ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); - } - _mm256_storeu_pd((double *)(d11_pack), ymm1); + _mm256_storeu_pd((double *)(d11_pack), ymm1); + } for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' 
direction {