mirror of
https://github.com/amd/blis.git
synced 2026-05-11 17:50:00 +00:00
CPUPL-929: Improve Complex GEMM performance - Support all storage formats and non Transpose/Conjugate Matrices
A failure was seen in the libflame function FLASH_UDdate_UT_inc due to typecasting a double complex pointer as a double pointer. Change-Id: If6e2f4663575450a13a9a07dddd5622628f5c6b0
This commit is contained in:
@@ -36,7 +36,6 @@
|
||||
#include "blis.h"
|
||||
#include "immintrin.h"
|
||||
|
||||
//GENTFUNC( scomplex, c, gemmsup_r_zen_ref_3x1, 3 )
|
||||
/*
|
||||
rrr:
|
||||
-------- ------ --------
|
||||
@@ -640,7 +639,7 @@ void bli_cgemmsup_rv_zen_asm_2x8n
|
||||
|
||||
/* (ar + ai) x AB */
|
||||
ymm0 = _mm256_broadcast_ss((float const *)(alpha)); // load alpha_r and duplicate
|
||||
ymm1 = _mm256_broadcast_ss((float const *)&(alpha->imag)); // load alpha_i and duplicate
|
||||
ymm1 = _mm256_broadcast_ss((float const *)(&alpha->imag)); // load alpha_i and duplicate
|
||||
|
||||
ymm3 = _mm256_permute_ps(ymm4, 0xb1);
|
||||
ymm4 = _mm256_mul_ps(ymm0, ymm4);
|
||||
@@ -697,8 +696,8 @@ void bli_cgemmsup_rv_zen_asm_2x8n
|
||||
|
||||
}
|
||||
else{
|
||||
ymm1 = _mm256_broadcast_ss((float const *)beta); // load alpha_r and duplicate
|
||||
ymm2 = _mm256_broadcast_ss((float const *)&beta->imag); // load alpha_i and duplicate
|
||||
ymm1 = _mm256_broadcast_ss((float const *)(beta)); // load alpha_r and duplicate
|
||||
ymm2 = _mm256_broadcast_ss((float const *)(&beta->imag)); // load alpha_i and duplicate
|
||||
|
||||
//Multiply ymm4 with beta
|
||||
xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) (tC)) ;
|
||||
@@ -943,7 +942,7 @@ void bli_cgemmsup_rv_zen_asm_1x8n
|
||||
// This loop is processing MR x K
|
||||
ymm0 = _mm256_loadu_ps((float const *)(tB + tb_inc_row * k_iter));
|
||||
ymm1 = _mm256_loadu_ps((float const *)(tB + tb_inc_row * k_iter + 4));
|
||||
|
||||
|
||||
//broadcasted matrix B elements are multiplied
|
||||
//with matrix A columns.
|
||||
ymm2 = _mm256_broadcast_ss((float const *)(tA));
|
||||
@@ -969,7 +968,7 @@ void bli_cgemmsup_rv_zen_asm_1x8n
|
||||
|
||||
/* (ar + ai) x AB */
|
||||
ymm0 = _mm256_broadcast_ss((float const *)(alpha)); // load alpha_r and duplicate
|
||||
ymm1 = _mm256_broadcast_ss((float const *)&(alpha->imag)); // load alpha_i and duplicate
|
||||
ymm1 = _mm256_broadcast_ss((float const *)(&alpha->imag)); // load alpha_i and duplicate
|
||||
|
||||
ymm3 = _mm256_permute_ps(ymm4, 0xb1);
|
||||
ymm4 = _mm256_mul_ps(ymm0, ymm4);
|
||||
@@ -1238,7 +1237,7 @@ void bli_cgemmsup_rv_zen_asm_3x4
|
||||
|
||||
/* (ar + ai) x AB */
|
||||
ymm0 = _mm256_broadcast_ss((float const *)(alpha)); // load alpha_r and duplicate
|
||||
ymm1 = _mm256_broadcast_ss((float const *)&(alpha->imag)); // load alpha_i and duplicate
|
||||
ymm1 = _mm256_broadcast_ss((float const *)(&alpha->imag)); // load alpha_i and duplicate
|
||||
|
||||
ymm3 = _mm256_permute_ps(ymm4, 0xb1);
|
||||
ymm4 = _mm256_mul_ps(ymm0, ymm4);
|
||||
@@ -1261,8 +1260,8 @@ void bli_cgemmsup_rv_zen_asm_3x4
|
||||
{
|
||||
//transpose 3x4
|
||||
ymm0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd (ymm4), _mm256_castps_pd (ymm8)));
|
||||
_mm_storeu_ps((float *)tC, _mm256_castps256_ps128(ymm0));
|
||||
_mm_storel_pi((__m64 *)tC+2, _mm256_castps256_ps128(ymm12));
|
||||
_mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm0));
|
||||
_mm_storel_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm12));
|
||||
|
||||
ymm1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd (ymm4) , _mm256_castps_pd(ymm8)));
|
||||
tC += tc_inc_col;
|
||||
@@ -1307,10 +1306,10 @@ void bli_cgemmsup_rv_zen_asm_3x4
|
||||
ymm8 = _mm256_add_ps(ymm8, ymm0);
|
||||
|
||||
//Multiply ymm12 with beta
|
||||
xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)tC + 2) ;
|
||||
xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)tC + 2 + tc_inc_col) ;
|
||||
xmm3 = _mm_loadl_pi(xmm3, (__m64 const *)tC + 2 + tc_inc_col*2) ;
|
||||
xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)tC + 2 + tc_inc_col*3) ;
|
||||
xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC + 2)) ;
|
||||
xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC + 2 + tc_inc_col)) ;
|
||||
xmm3 = _mm_loadl_pi(xmm3, (__m64 const *)(tC + 2 + tc_inc_col*2)) ;
|
||||
xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + 2 + tc_inc_col*3)) ;
|
||||
ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ;
|
||||
ymm3 = _mm256_permute_ps(ymm0, 0xb1);
|
||||
ymm0 = _mm256_mul_ps(ymm1, ymm0);
|
||||
@@ -1429,27 +1428,27 @@ void bli_cgemmsup_rv_zen_asm_3x2
|
||||
// The inner loop broadcasts the B matrix data and
|
||||
// multiplies it with the A matrix.
|
||||
// This loop is processing MR x K
|
||||
xmm0 = _mm_loadu_ps((float const *)tB + tb_inc_row * k_iter);
|
||||
xmm0 = _mm_loadu_ps((float const *)(tB + tb_inc_row * k_iter));
|
||||
|
||||
//broadcasted matrix B elements are multiplied
|
||||
//with matrix A columns.
|
||||
xmm2 = _mm_broadcast_ss((float const *)tA);
|
||||
xmm2 = _mm_broadcast_ss((float const *)(tA));
|
||||
xmm4 = _mm_fmadd_ps(xmm0, xmm2, xmm4);
|
||||
|
||||
xmm2 = _mm_broadcast_ss((float const *)tA + ta_inc_row);
|
||||
xmm2 = _mm_broadcast_ss((float const *)(tA + ta_inc_row));
|
||||
xmm8 = _mm_fmadd_ps(xmm0, xmm2, xmm8);
|
||||
|
||||
xmm2 = _mm_broadcast_ss((float const *)tA + ta_inc_row*2);
|
||||
xmm2 = _mm_broadcast_ss((float const *)(tA + ta_inc_row*2));
|
||||
xmm12 = _mm_fmadd_ps(xmm0, xmm2, xmm12);
|
||||
|
||||
//Compute imag values
|
||||
xmm2 = _mm_broadcast_ss((float const *)tAimag );
|
||||
xmm2 = _mm_broadcast_ss((float const *)(tAimag ));
|
||||
xmm6 = _mm_fmadd_ps(xmm0, xmm2, xmm6);
|
||||
|
||||
xmm2 = _mm_broadcast_ss((float const *)tAimag + ta_inc_row *2);
|
||||
xmm2 = _mm_broadcast_ss((float const *)(tAimag + ta_inc_row *2));
|
||||
xmm10 = _mm_fmadd_ps(xmm0, xmm2, xmm10);
|
||||
|
||||
xmm2 = _mm_broadcast_ss((float const *)tAimag + ta_inc_row *4);
|
||||
xmm2 = _mm_broadcast_ss((float const *)(tAimag + ta_inc_row *4));
|
||||
xmm14 = _mm_fmadd_ps(xmm0, xmm2, xmm14);
|
||||
tA += ta_inc_col;
|
||||
tAimag += ta_inc_col*2;
|
||||
@@ -1468,8 +1467,8 @@ void bli_cgemmsup_rv_zen_asm_3x2
|
||||
// alpha, beta multiplication.
|
||||
|
||||
/* (ar + ai) x AB */
|
||||
xmm0 = _mm_broadcast_ss((float const *)alpha); // load alpha_r and duplicate
|
||||
xmm1 = _mm_broadcast_ss((float const *)&alpha->imag); // load alpha_i and duplicate
|
||||
xmm0 = _mm_broadcast_ss((float const *)(alpha)); // load alpha_r and duplicate
|
||||
xmm1 = _mm_broadcast_ss((float const *)(&alpha->imag)); // load alpha_i and duplicate
|
||||
|
||||
xmm3 = _mm_permute_ps(xmm4, 0xb1);
|
||||
xmm4 = _mm_mul_ps(xmm0, xmm4);
|
||||
@@ -1492,21 +1491,21 @@ void bli_cgemmsup_rv_zen_asm_3x2
|
||||
{
|
||||
//transpose 3x2
|
||||
xmm0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd (xmm4), _mm_castps_pd (xmm8)));
|
||||
_mm_storeu_ps((float *)tC, xmm0);
|
||||
_mm_storel_pi((__m64 *)tC+2, xmm12);
|
||||
_mm_storeu_ps((float *)(tC ), xmm0);
|
||||
_mm_storel_pi((__m64 *)(tC+2), xmm12);
|
||||
|
||||
xmm1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd (xmm4) , _mm_castps_pd(xmm8)));
|
||||
tC += tc_inc_col;
|
||||
_mm_storeu_ps((float *)tC, xmm1);
|
||||
_mm_storeh_pi((__m64 *)tC+2, xmm12);
|
||||
_mm_storeu_ps((float *)(tC ), xmm1);
|
||||
_mm_storeh_pi((__m64 *)(tC+2), xmm12);
|
||||
}
|
||||
else{
|
||||
xmm1 = _mm_broadcast_ss((float const *)beta); // load alpha_r and duplicate
|
||||
xmm2 = _mm_broadcast_ss((float const *)&beta->imag); // load alpha_i and duplicate
|
||||
xmm1 = _mm_broadcast_ss((float const *)(beta)); // load alpha_r and duplicate
|
||||
xmm2 = _mm_broadcast_ss((float const *)(&beta->imag)); // load alpha_i and duplicate
|
||||
|
||||
//Multiply xmm4 with beta
|
||||
xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) tC) ;
|
||||
xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) tC + tc_inc_col);
|
||||
xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) (tC)) ;
|
||||
xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) (tC + tc_inc_col));
|
||||
xmm3 = _mm_permute_ps(xmm0, 0xb1);
|
||||
xmm0 = _mm_mul_ps(xmm1, xmm0);
|
||||
xmm3 = _mm_mul_ps(xmm2, xmm3);
|
||||
@@ -1514,8 +1513,8 @@ void bli_cgemmsup_rv_zen_asm_3x2
|
||||
xmm4 = _mm_add_ps(xmm4, xmm0);
|
||||
|
||||
//Multiply xmm8 with beta
|
||||
xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)tC + 1) ;
|
||||
xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)tC + 1 + tc_inc_col) ;
|
||||
xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC + 1)) ;
|
||||
xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC + 1 + tc_inc_col)) ;
|
||||
xmm3 = _mm_permute_ps(xmm0, 0xb1);
|
||||
xmm0 = _mm_mul_ps(xmm1, xmm0);
|
||||
xmm3 = _mm_mul_ps(xmm2, xmm3);
|
||||
@@ -1523,8 +1522,8 @@ void bli_cgemmsup_rv_zen_asm_3x2
|
||||
xmm8 = _mm_add_ps(xmm8, xmm0);
|
||||
|
||||
//Multiply xmm12 with beta
|
||||
xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)tC + 2) ;
|
||||
xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)tC + 2 + tc_inc_col) ;
|
||||
xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC + 2)) ;
|
||||
xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC + 2 + tc_inc_col)) ;
|
||||
xmm3 = _mm_permute_ps(xmm0, 0xb1);
|
||||
xmm0 = _mm_mul_ps(xmm1, xmm0);
|
||||
xmm3 = _mm_mul_ps(xmm2, xmm3);
|
||||
@@ -1533,13 +1532,13 @@ void bli_cgemmsup_rv_zen_asm_3x2
|
||||
|
||||
//transpose 3x2
|
||||
xmm0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd (xmm4), _mm_castps_pd (xmm8)));
|
||||
_mm_storeu_ps((float *)tC, xmm0);
|
||||
_mm_storel_pi((__m64 *)tC+2, xmm12);
|
||||
_mm_storeu_ps((float *)(tC ), xmm0);
|
||||
_mm_storel_pi((__m64 *)(tC+2), xmm12);
|
||||
|
||||
xmm3 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd (xmm4) , _mm_castps_pd(xmm8)));
|
||||
tC += tc_inc_col;
|
||||
_mm_storeu_ps((float *)tC, xmm3);
|
||||
_mm_storeh_pi((__m64 *)tC+2, xmm12);
|
||||
_mm_storeu_ps((float *)(tC ), xmm3);
|
||||
_mm_storeh_pi((__m64 *)(tC+2), xmm12);
|
||||
|
||||
}
|
||||
}
|
||||
@@ -1547,36 +1546,36 @@ void bli_cgemmsup_rv_zen_asm_3x2
|
||||
{
|
||||
if(beta->real == 0.0 && beta->imag == 0.0)
|
||||
{
|
||||
_mm_storeu_ps((float *)tC, xmm4);
|
||||
_mm_storeu_ps((float *)tC + tc_inc_row , xmm8);
|
||||
_mm_storeu_ps((float *)tC + tc_inc_row *2, xmm12);
|
||||
_mm_storeu_ps((float *)(tC), xmm4);
|
||||
_mm_storeu_ps((float *)(tC + tc_inc_row) , xmm8);
|
||||
_mm_storeu_ps((float *)(tC + tc_inc_row *2), xmm12);
|
||||
}
|
||||
else{
|
||||
/* (br + bi) C + (ar + ai) AB */
|
||||
xmm0 = _mm_broadcast_ss((float const *)beta); // load beta_r and duplicate
|
||||
xmm1 = _mm_broadcast_ss((float const *)&beta->imag); // load beta_i and duplicate
|
||||
xmm0 = _mm_broadcast_ss((float const *)(beta)); // load beta_r and duplicate
|
||||
xmm1 = _mm_broadcast_ss((float const *)(&beta->imag)); // load beta_i and duplicate
|
||||
|
||||
xmm2 = _mm_loadu_ps((float const *)tC);
|
||||
xmm2 = _mm_loadu_ps((float const *)(tC));
|
||||
xmm3 = _mm_permute_ps(xmm2, 0xb1);
|
||||
xmm2 = _mm_mul_ps(xmm0, xmm2);
|
||||
xmm3 = _mm_mul_ps(xmm1, xmm3);
|
||||
xmm4 = _mm_add_ps(xmm4, _mm_addsub_ps(xmm2, xmm3));
|
||||
|
||||
xmm2 = _mm_loadu_ps((float const *)tC+tc_inc_row);
|
||||
xmm2 = _mm_loadu_ps((float const *)(tC+tc_inc_row));
|
||||
xmm3 = _mm_permute_ps(xmm2, 0xb1);
|
||||
xmm2 = _mm_mul_ps(xmm0, xmm2);
|
||||
xmm3 = _mm_mul_ps(xmm1, xmm3);
|
||||
xmm8 = _mm_add_ps(xmm8, _mm_addsub_ps(xmm2, xmm3));
|
||||
|
||||
xmm2 = _mm_loadu_ps((float const *)tC+tc_inc_row*2);
|
||||
xmm2 = _mm_loadu_ps((float const *)(tC+tc_inc_row*2));
|
||||
xmm3 = _mm_permute_ps(xmm2, 0xb1);
|
||||
xmm2 = _mm_mul_ps(xmm0, xmm2);
|
||||
xmm3 = _mm_mul_ps(xmm1, xmm3);
|
||||
xmm12 = _mm_add_ps(xmm12, _mm_addsub_ps(xmm2, xmm3));
|
||||
|
||||
_mm_storeu_ps((float *)tC, xmm4);
|
||||
_mm_storeu_ps((float *)tC + tc_inc_row , xmm8);
|
||||
_mm_storeu_ps((float *)tC + tc_inc_row *2, xmm12);;
|
||||
_mm_storeu_ps((float *)(tC), xmm4);
|
||||
_mm_storeu_ps((float *)(tC + tc_inc_row) , xmm8);
|
||||
_mm_storeu_ps((float *)(tC + tc_inc_row *2), xmm12);;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1176,7 +1176,7 @@ void bli_zgemmsup_rv_zen_asm_3x2
|
||||
ymm3 =_mm256_mul_pd(ymm1, ymm3);
|
||||
ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3));
|
||||
|
||||
ymm2 = _mm256_loadu_pd((double const *)tC+tc_inc_row);
|
||||
ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row));
|
||||
ymm3 = _mm256_permute_pd(ymm2, 5);
|
||||
ymm2 = _mm256_mul_pd(ymm0, ymm2);
|
||||
ymm3 = _mm256_mul_pd(ymm1, ymm3);
|
||||
|
||||
Reference in New Issue
Block a user