CPUPL-929: Improve Complex GEMM performance - Support all storage formats and non-Transpose/Conjugate matrices

A failure was seen in the libflame function FLASH_UDdate_UT_inc,
caused by typecasting a double complex pointer as a double pointer
before applying the pointer offset arithmetic.
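
For illustration (an editorial sketch, not part of the original commit message): with a 16-byte dcomplex, casting to double* before the offset arithmetic halves the element stride, so (double *)tC + 2 lands one complex element ahead instead of two. A minimal standalone example with a hypothetical buffer:

    #include <stdio.h>

    typedef struct { double real, imag; } dcomplex;

    int main(void)
    {
        dcomplex c[4] = { {0,0}, {1,1}, {2,2}, {3,3} };
        dcomplex *tC = c;

        /* Buggy: cast first, then offset -> advances 2 doubles = 1 dcomplex. */
        double *wrong = (double *)tC + 2;
        /* Fixed: offset first, then cast -> advances 2 dcomplex = 4 doubles. */
        double *right = (double *)(tC + 2);

        printf("wrong -> %g, right -> %g\n", *wrong, *right); /* prints 1 and 2 */
        return 0;
    }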

Change-Id: If6e2f4663575450a13a9a07dddd5622628f5c6b0
Author: managalv
Date:   2020-06-02 18:26:30 +05:30
Parent: 6f01cd2c54
Commit: b4e599ecc2
2 files changed, 49 insertions(+), 50 deletions(-)


@@ -36,7 +36,6 @@
 #include "blis.h"
 #include "immintrin.h"
-//GENTFUNC( scomplex, c, gemmsup_r_zen_ref_3x1, 3 )
 /*
 rrr:
  -------- ------ --------
@@ -640,7 +639,7 @@ void bli_cgemmsup_rv_zen_asm_2x8n
 /* (ar + ai) x AB */
 ymm0 = _mm256_broadcast_ss((float const *)(alpha)); // load alpha_r and duplicate
-ymm1 = _mm256_broadcast_ss((float const *)&(alpha->imag)); // load alpha_i and duplicate
+ymm1 = _mm256_broadcast_ss((float const *)(&alpha->imag)); // load alpha_i and duplicate
 ymm3 = _mm256_permute_ps(ymm4, 0xb1);
 ymm4 = _mm256_mul_ps(ymm0, ymm4);
@@ -697,8 +696,8 @@ void bli_cgemmsup_rv_zen_asm_2x8n
 }
 else{
-ymm1 = _mm256_broadcast_ss((float const *)beta); // load alpha_r and duplicate
-ymm2 = _mm256_broadcast_ss((float const *)&beta->imag); // load alpha_i and duplicate
+ymm1 = _mm256_broadcast_ss((float const *)(beta)); // load beta_r and duplicate
+ymm2 = _mm256_broadcast_ss((float const *)(&beta->imag)); // load beta_i and duplicate
 //Multiply ymm4 with beta
 xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) (tC)) ;
@@ -943,7 +942,7 @@ void bli_cgemmsup_rv_zen_asm_1x8n
 // This loop is processing MR x K
 ymm0 = _mm256_loadu_ps((float const *)(tB + tb_inc_row * k_iter));
 ymm1 = _mm256_loadu_ps((float const *)(tB + tb_inc_row * k_iter + 4));
 //broadcasted matrix B elements are multiplied
 //with matrix A columns.
 ymm2 = _mm256_broadcast_ss((float const *)(tA));
@@ -969,7 +968,7 @@ void bli_cgemmsup_rv_zen_asm_1x8n
 /* (ar + ai) x AB */
 ymm0 = _mm256_broadcast_ss((float const *)(alpha)); // load alpha_r and duplicate
-ymm1 = _mm256_broadcast_ss((float const *)&(alpha->imag)); // load alpha_i and duplicate
+ymm1 = _mm256_broadcast_ss((float const *)(&alpha->imag)); // load alpha_i and duplicate
 ymm3 = _mm256_permute_ps(ymm4, 0xb1);
 ymm4 = _mm256_mul_ps(ymm0, ymm4);
@@ -1238,7 +1237,7 @@ void bli_cgemmsup_rv_zen_asm_3x4
 /* (ar + ai) x AB */
 ymm0 = _mm256_broadcast_ss((float const *)(alpha)); // load alpha_r and duplicate
-ymm1 = _mm256_broadcast_ss((float const *)&(alpha->imag)); // load alpha_i and duplicate
+ymm1 = _mm256_broadcast_ss((float const *)(&alpha->imag)); // load alpha_i and duplicate
 ymm3 = _mm256_permute_ps(ymm4, 0xb1);
 ymm4 = _mm256_mul_ps(ymm0, ymm4);
@@ -1261,8 +1260,8 @@ void bli_cgemmsup_rv_zen_asm_3x4
 {
 //transpose 3x4
 ymm0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd (ymm4), _mm256_castps_pd (ymm8)));
-_mm_storeu_ps((float *)tC, _mm256_castps256_ps128(ymm0));
-_mm_storel_pi((__m64 *)tC+2, _mm256_castps256_ps128(ymm12));
+_mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm0));
+_mm_storel_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm12));
 ymm1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd (ymm4) , _mm256_castps_pd(ymm8)));
 tC += tc_inc_col;
@@ -1307,10 +1306,10 @@ void bli_cgemmsup_rv_zen_asm_3x4
 ymm8 = _mm256_add_ps(ymm8, ymm0);
 //Multiply ymm12 with beta
-xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)tC + 2) ;
-xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)tC + 2 + tc_inc_col) ;
-xmm3 = _mm_loadl_pi(xmm3, (__m64 const *)tC + 2 + tc_inc_col*2) ;
-xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)tC + 2 + tc_inc_col*3) ;
+xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC + 2)) ;
+xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC + 2 + tc_inc_col)) ;
+xmm3 = _mm_loadl_pi(xmm3, (__m64 const *)(tC + 2 + tc_inc_col*2)) ;
+xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + 2 + tc_inc_col*3)) ;
 ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ;
 ymm3 = _mm256_permute_ps(ymm0, 0xb1);
 ymm0 = _mm256_mul_ps(ymm1, ymm0);
@@ -1429,27 +1428,27 @@ void bli_cgemmsup_rv_zen_asm_3x2
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
// This loop is processing MR x K
xmm0 = _mm_loadu_ps((float const *)tB + tb_inc_row * k_iter);
xmm0 = _mm_loadu_ps((float const *)(tB + tb_inc_row * k_iter));
//broadcasted matrix B elements are multiplied
//with matrix A columns.
xmm2 = _mm_broadcast_ss((float const *)tA);
xmm2 = _mm_broadcast_ss((float const *)(tA));
xmm4 = _mm_fmadd_ps(xmm0, xmm2, xmm4);
xmm2 = _mm_broadcast_ss((float const *)tA + ta_inc_row);
xmm2 = _mm_broadcast_ss((float const *)(tA + ta_inc_row));
xmm8 = _mm_fmadd_ps(xmm0, xmm2, xmm8);
xmm2 = _mm_broadcast_ss((float const *)tA + ta_inc_row*2);
xmm2 = _mm_broadcast_ss((float const *)(tA + ta_inc_row*2));
xmm12 = _mm_fmadd_ps(xmm0, xmm2, xmm12);
//Compute imag values
xmm2 = _mm_broadcast_ss((float const *)tAimag );
xmm2 = _mm_broadcast_ss((float const *)(tAimag ));
xmm6 = _mm_fmadd_ps(xmm0, xmm2, xmm6);
xmm2 = _mm_broadcast_ss((float const *)tAimag + ta_inc_row *2);
xmm2 = _mm_broadcast_ss((float const *)(tAimag + ta_inc_row *2));
xmm10 = _mm_fmadd_ps(xmm0, xmm2, xmm10);
xmm2 = _mm_broadcast_ss((float const *)tAimag + ta_inc_row *4);
xmm2 = _mm_broadcast_ss((float const *)(tAimag + ta_inc_row *4));
xmm14 = _mm_fmadd_ps(xmm0, xmm2, xmm14);
tA += ta_inc_col;
tAimag += ta_inc_col*2;
@@ -1468,8 +1467,8 @@ void bli_cgemmsup_rv_zen_asm_3x2
 // alpha, beta multiplication.
 /* (ar + ai) x AB */
-xmm0 = _mm_broadcast_ss((float const *)alpha); // load alpha_r and duplicate
-xmm1 = _mm_broadcast_ss((float const *)&alpha->imag); // load alpha_i and duplicate
+xmm0 = _mm_broadcast_ss((float const *)(alpha)); // load alpha_r and duplicate
+xmm1 = _mm_broadcast_ss((float const *)(&alpha->imag)); // load alpha_i and duplicate
 xmm3 = _mm_permute_ps(xmm4, 0xb1);
 xmm4 = _mm_mul_ps(xmm0, xmm4);
@@ -1492,21 +1491,21 @@ void bli_cgemmsup_rv_zen_asm_3x2
 {
 //transpose 3x2
 xmm0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd (xmm4), _mm_castps_pd (xmm8)));
-_mm_storeu_ps((float *)tC, xmm0);
-_mm_storel_pi((__m64 *)tC+2, xmm12);
+_mm_storeu_ps((float *)(tC ), xmm0);
+_mm_storel_pi((__m64 *)(tC+2), xmm12);
 xmm1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd (xmm4) , _mm_castps_pd(xmm8)));
 tC += tc_inc_col;
-_mm_storeu_ps((float *)tC, xmm1);
-_mm_storeh_pi((__m64 *)tC+2, xmm12);
+_mm_storeu_ps((float *)(tC ), xmm1);
+_mm_storeh_pi((__m64 *)(tC+2), xmm12);
 }
 else{
-xmm1 = _mm_broadcast_ss((float const *)beta); // load alpha_r and duplicate
-xmm2 = _mm_broadcast_ss((float const *)&beta->imag); // load alpha_i and duplicate
+xmm1 = _mm_broadcast_ss((float const *)(beta)); // load beta_r and duplicate
+xmm2 = _mm_broadcast_ss((float const *)(&beta->imag)); // load beta_i and duplicate
 //Multiply xmm4 with beta
-xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) tC) ;
-xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) tC + tc_inc_col);
+xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) (tC)) ;
+xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) (tC + tc_inc_col));
 xmm3 = _mm_permute_ps(xmm0, 0xb1);
 xmm0 = _mm_mul_ps(xmm1, xmm0);
 xmm3 = _mm_mul_ps(xmm2, xmm3);
@@ -1514,8 +1513,8 @@ void bli_cgemmsup_rv_zen_asm_3x2
 xmm4 = _mm_add_ps(xmm4, xmm0);
 //Multiply xmm8 with beta
-xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)tC + 1) ;
-xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)tC + 1 + tc_inc_col) ;
+xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC + 1)) ;
+xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC + 1 + tc_inc_col)) ;
 xmm3 = _mm_permute_ps(xmm0, 0xb1);
 xmm0 = _mm_mul_ps(xmm1, xmm0);
 xmm3 = _mm_mul_ps(xmm2, xmm3);
@@ -1523,8 +1522,8 @@ void bli_cgemmsup_rv_zen_asm_3x2
 xmm8 = _mm_add_ps(xmm8, xmm0);
 //Multiply xmm12 with beta
-xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)tC + 2) ;
-xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)tC + 2 + tc_inc_col) ;
+xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC + 2)) ;
+xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC + 2 + tc_inc_col)) ;
 xmm3 = _mm_permute_ps(xmm0, 0xb1);
 xmm0 = _mm_mul_ps(xmm1, xmm0);
 xmm3 = _mm_mul_ps(xmm2, xmm3);
@@ -1533,13 +1532,13 @@ void bli_cgemmsup_rv_zen_asm_3x2
 //transpose 3x2
 xmm0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd (xmm4), _mm_castps_pd (xmm8)));
-_mm_storeu_ps((float *)tC, xmm0);
-_mm_storel_pi((__m64 *)tC+2, xmm12);
+_mm_storeu_ps((float *)(tC ), xmm0);
+_mm_storel_pi((__m64 *)(tC+2), xmm12);
 xmm3 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd (xmm4) , _mm_castps_pd(xmm8)));
 tC += tc_inc_col;
-_mm_storeu_ps((float *)tC, xmm3);
-_mm_storeh_pi((__m64 *)tC+2, xmm12);
+_mm_storeu_ps((float *)(tC ), xmm3);
+_mm_storeh_pi((__m64 *)(tC+2), xmm12);
 }
 }
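
Editorial note on the half-register stores in the transposed 3x2 path above: _mm_storeu_ps writes two interleaved complex floats at once, while _mm_storel_pi/_mm_storeh_pi copy only the low or high 64 bits (one complex float), which is why the third element is addressed as (__m64 *)(tC+2). A self-contained sketch with hypothetical names and values:

    #include <immintrin.h>
    #include <stdio.h>

    typedef struct { float real, imag; } scomplex;

    int main(void)
    {
        scomplex row[3];
        __m128 two  = _mm_setr_ps(1.0f, 2.0f, 3.0f, 4.0f); /* (1+2i), (3+4i) */
        __m128 pair = _mm_setr_ps(5.0f, 6.0f, 7.0f, 8.0f); /* low half: (5+6i) */

        _mm_storeu_ps((float *)(row), two);      /* writes row[0] and row[1] */
        _mm_storel_pi((__m64 *)(row + 2), pair); /* low 64 bits -> row[2] */
        /* _mm_storeh_pi would write the high half, (7+8i), instead. */

        printf("row[2] = %g + %gi\n", row[2].real, row[2].imag);
        return 0;
    }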
@@ -1547,36 +1546,36 @@ void bli_cgemmsup_rv_zen_asm_3x2
 {
 if(beta->real == 0.0 && beta->imag == 0.0)
 {
-_mm_storeu_ps((float *)tC, xmm4);
-_mm_storeu_ps((float *)tC + tc_inc_row , xmm8);
-_mm_storeu_ps((float *)tC + tc_inc_row *2, xmm12);
+_mm_storeu_ps((float *)(tC), xmm4);
+_mm_storeu_ps((float *)(tC + tc_inc_row) , xmm8);
+_mm_storeu_ps((float *)(tC + tc_inc_row *2), xmm12);
 }
 else{
 /* (br + bi) C + (ar + ai) AB */
-xmm0 = _mm_broadcast_ss((float const *)beta); // load beta_r and duplicate
-xmm1 = _mm_broadcast_ss((float const *)&beta->imag); // load beta_i and duplicate
+xmm0 = _mm_broadcast_ss((float const *)(beta)); // load beta_r and duplicate
+xmm1 = _mm_broadcast_ss((float const *)(&beta->imag)); // load beta_i and duplicate
-xmm2 = _mm_loadu_ps((float const *)tC);
+xmm2 = _mm_loadu_ps((float const *)(tC));
 xmm3 = _mm_permute_ps(xmm2, 0xb1);
 xmm2 = _mm_mul_ps(xmm0, xmm2);
 xmm3 = _mm_mul_ps(xmm1, xmm3);
 xmm4 = _mm_add_ps(xmm4, _mm_addsub_ps(xmm2, xmm3));
-xmm2 = _mm_loadu_ps((float const *)tC+tc_inc_row);
+xmm2 = _mm_loadu_ps((float const *)(tC+tc_inc_row));
 xmm3 = _mm_permute_ps(xmm2, 0xb1);
 xmm2 = _mm_mul_ps(xmm0, xmm2);
 xmm3 = _mm_mul_ps(xmm1, xmm3);
 xmm8 = _mm_add_ps(xmm8, _mm_addsub_ps(xmm2, xmm3));
-xmm2 = _mm_loadu_ps((float const *)tC+tc_inc_row*2);
+xmm2 = _mm_loadu_ps((float const *)(tC+tc_inc_row*2));
 xmm3 = _mm_permute_ps(xmm2, 0xb1);
 xmm2 = _mm_mul_ps(xmm0, xmm2);
 xmm3 = _mm_mul_ps(xmm1, xmm3);
 xmm12 = _mm_add_ps(xmm12, _mm_addsub_ps(xmm2, xmm3));
-_mm_storeu_ps((float *)tC, xmm4);
-_mm_storeu_ps((float *)tC + tc_inc_row , xmm8);
-_mm_storeu_ps((float *)tC + tc_inc_row *2, xmm12);;
+_mm_storeu_ps((float *)(tC), xmm4);
+_mm_storeu_ps((float *)(tC + tc_inc_row) , xmm8);
+_mm_storeu_ps((float *)(tC + tc_inc_row *2), xmm12);
 }
 }
 }
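
Editorial note: the (br + bi) C updates above all use the same permute/multiply/addsub idiom for complex multiplication: swap the real/imag lanes, scale by the real and imaginary parts separately, and let _mm_addsub_ps merge the two products with the correct signs. A self-contained sketch with hypothetical values (_mm_broadcast_ss and _mm_permute_ps require AVX):

    #include <immintrin.h>
    #include <stdio.h>

    int main(void)
    {
        /* Two complex numbers (1 + 2i) and (3 + 4i), stored interleaved. */
        float c[4] = { 1.0f, 2.0f, 3.0f, 4.0f };
        float br = 2.0f, bi = 0.5f; /* beta = 2 + 0.5i */

        __m128 xmm1 = _mm_broadcast_ss(&br); /* (br, br, br, br) */
        __m128 xmm2 = _mm_broadcast_ss(&bi); /* (bi, bi, bi, bi) */
        __m128 xmm0 = _mm_loadu_ps(c);       /* (cr0, ci0, cr1, ci1) */

        /* Swap real/imag within each pair: (ci0, cr0, ci1, cr1). */
        __m128 xmm3 = _mm_permute_ps(xmm0, 0xb1);
        xmm0 = _mm_mul_ps(xmm1, xmm0); /* (br*cr0, br*ci0, ...) */
        xmm3 = _mm_mul_ps(xmm2, xmm3); /* (bi*ci0, bi*cr0, ...) */

        /* Even lanes subtract, odd lanes add:
           (br*cr - bi*ci, br*ci + bi*cr, ...) = beta * C. */
        __m128 res = _mm_addsub_ps(xmm0, xmm3);

        float out[4];
        _mm_storeu_ps(out, res);
        printf("(%g + %gi), (%g + %gi)\n", out[0], out[1], out[2], out[3]);
        /* Expect (1 + 4.5i) and (4 + 9.5i). */
        return 0;
    }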


@@ -1176,7 +1176,7 @@ void bli_zgemmsup_rv_zen_asm_3x2
 ymm3 =_mm256_mul_pd(ymm1, ymm3);
 ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3));
-ymm2 = _mm256_loadu_pd((double const *)tC+tc_inc_row);
+ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row));
 ymm3 = _mm256_permute_pd(ymm2, 5);
 ymm2 = _mm256_mul_pd(ymm0, ymm2);
 ymm3 = _mm256_mul_pd(ymm1, ymm3);