Performance Improvement for ztrsm small sizes

Details:
    - Handled Overflow and Underflow Vulnerabilites in
      ztrsm small right implementations.
    - Fixed failures observed in Scalapack testing.

    AMD-Internal: [CPUPL-2115]

Change-Id: I22c1ba583e0ba14d1a4684a85fa1ca6e152e8439
This commit is contained in:
satish kumar nuggu
2022-04-29 17:13:29 +05:30
committed by Dipal M Zambare
parent 2acb3f6ed0
commit 09b70de635

View File

@@ -34922,38 +34922,21 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB
{
if(transa)
{
ymm0 = _mm256_broadcast_pd((__m128d const *)(a11));
ymm1 = _mm256_broadcast_pd((__m128d const *)
(a11+cs_a*1 + 1));
ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,
d11_pack,n_remainder);
}
else
{
//broadcast diagonal elements of A11
ymm0 = _mm256_broadcast_pd((__m128d const *)(a11));
ymm1 = _mm256_broadcast_pd((__m128d const *)
(a11+rs_a*1 + 1));
ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,
d11_pack,n_remainder);
}
ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C);
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);
/*Taking denomerator multiplication of real &
* imaginary components*/
ymm4 = _mm256_mul_pd(ymm1, ymm1);
/*Swapping real & imaginary component position for addition with
* respective components*/
ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1);
ymm4 = _mm256_add_pd(ymm4, ymm6);
/*Negating imaginary component of numerator*/
ymm1 = _mm256_mul_pd(ymm1, ymm7);
/*Dividing numerator by denominator*/
ymm1 = _mm256_div_pd(ymm1, ymm4);
#endif
}
else
{
ymm1 = _mm256_broadcast_pd((__m128d const*)&ones);
_mm256_storeu_pd((double *)(d11_pack), ymm1);
}
_mm256_storeu_pd((double *)(d11_pack), ymm1);
for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction
{
a01 = D_A_pack;
@@ -35340,30 +35323,23 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB
}
if(!is_unitdiag)
{
//broadcast diagonal elements of A11
ymm0 = _mm256_broadcast_pd((__m128d const *)(a11));
ymm1 = _mm256_broadcast_pd((__m128d const *)&ones);
ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C);
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);
/*Taking denomerator multiplication of real &
* imaginary components*/
ymm4 = _mm256_mul_pd(ymm1, ymm1);
/*Swapping real & imaginary component position for addition with
* respective components*/
ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1);
ymm4 = _mm256_add_pd(ymm4, ymm6);
/*Negating imaginary component of numerator*/
ymm1 = _mm256_mul_pd(ymm1, ymm7);
/*Dividing numerator by denominator*/
ymm1 = _mm256_div_pd(ymm1, ymm4);
#endif
if(transa)
{
ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,
d11_pack,n_remainder);
}
else
{
ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,
d11_pack,n_remainder);
}
}
else
{
ymm1 = _mm256_broadcast_pd((__m128d const*)&ones);
_mm256_storeu_pd((double *)(d11_pack), ymm1);
}
_mm256_storeu_pd((double *)(d11_pack), ymm1);
for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction
{
a01 = D_A_pack;
@@ -36374,39 +36350,20 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB
{
if(transa)
{
//broadcast diagonal elements of A11
ymm0 = _mm256_broadcast_pd((__m128d const *)(a11));
ymm1 = _mm256_broadcast_pd((__m128d const *)
(a11+cs_a*1 + 1));
ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,
d11_pack,n_remainder);
}
else
{
//broadcast diagonal elements of A11
ymm0 = _mm256_broadcast_pd((__m128d const *)(a11));
ymm1 = _mm256_broadcast_pd((__m128d const *)
(a11+rs_a*1 + 1));
ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,
d11_pack,n_remainder);
}
ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C);
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);
/*Taking denomerator multiplication of real &
* imaginary components*/
ymm4 = _mm256_mul_pd(ymm1, ymm1);
/*Swapping real & imaginary component position for addition with
* respective components*/
ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1);
ymm4 = _mm256_add_pd(ymm4, ymm6);
/*Negating imaginary component of numerator*/
ymm1 = _mm256_mul_pd(ymm1, ymm7);
/*Dividing numerator by denominator*/
ymm1 = _mm256_div_pd(ymm1, ymm4);
#endif
}
else
{
ymm1 = _mm256_broadcast_pd((__m128d const *)&ones);
_mm256_storeu_pd((double *)(d11_pack), ymm1);
}
_mm256_storeu_pd((double *)(d11_pack), ymm1);
for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction
{
@@ -36793,30 +36750,22 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB
}
if(!is_unitdiag)
{
//broadcast diagonal elements of A11
ymm0 = _mm256_broadcast_pd((__m128d const *)(a11));
ymm1 = _mm256_broadcast_pd((__m128d const *)&ones);
ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C);
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);
/*Taking denomerator multiplication of real &
* imaginary components*/
ymm4 = _mm256_mul_pd(ymm1, ymm1);
/*Swapping real & imaginary component position for addition with
* respective components*/
ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1);
ymm4 = _mm256_add_pd(ymm4, ymm6);
/*Negating imaginary component of numerator*/
ymm1 = _mm256_mul_pd(ymm1, ymm7);
/*Dividing numerator by denominator*/
ymm1 = _mm256_div_pd(ymm1, ymm4);
#endif
if(transa)
{
ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,
d11_pack,n_remainder);
}
else
{
ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,
d11_pack,n_remainder);
}
}
else
{
ymm1 = _mm256_broadcast_pd((__m128d const *)&ones);
}
_mm256_storeu_pd((double *)(d11_pack), ymm1);
_mm256_storeu_pd((double *)(d11_pack), ymm1);
}
for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction
{