Fixed 32xk AVX512 double precision pack kernel

- Currently the pointer received as function argument is
  used for packing which causes only a partial copy of
  input buffer to output buffer due to strange optimizations
  by compiler.
- To fix this, instead of using a normal pointer for output
  buffer, we define a "restrict" local pointer variable.
- "restrict" keyword tells the compiler that the pointer is
  the only way to access the object pointed by the pointer.
- By defining "restrict" local pointer pointing to output
  buffer, the mysterious problem of incomplete copy has
  been solved.

Change-Id: Ie2355beb1d43ff4b60b940dd88c4e2bf6f361646
This commit is contained in:
Shubham
2023-02-16 23:23:40 +05:30
committed by Shubham Sharma
parent 3ae84c98fd
commit 1faee9f89e
2 changed files with 22 additions and 19 deletions

View File

@@ -106,7 +106,7 @@ void bli_cntx_init_zen4( cntx_t* cntx )
BLIS_PACKM_6XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_6xk,
BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_8xk,
BLIS_PACKM_24XK_KER, BLIS_DOUBLE, bli_dpackm_zen4_asm_24xk,
BLIS_PACKM_32XK_KER, BLIS_DOUBLE, bli_dpackm_32xk_zen4_ref,
BLIS_PACKM_32XK_KER, BLIS_DOUBLE, bli_dpackm_zen4_asm_32xk,
BLIS_PACKM_3XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_3xk,
BLIS_PACKM_8XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_8xk,
BLIS_PACKM_3XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_3xk,

View File

@@ -86,6 +86,7 @@ void bli_dpackm_zen4_asm_32xk
// assembly region, this constraint should be lifted.
const bool unitk = bli_deq1( *kappa );
double* restrict pi1 = p;
// -------------------------------------------------------------------------
@@ -100,10 +101,10 @@ void bli_dpackm_zen4_asm_32xk
for ( dim_t k = k0; k != 0; --k )
{
for ( dim_t i = 0 ; i < 32 ; i++ ) {
bli_dcopyjs( *(a + i), *(p + i) );
bli_dcopyjs( *(a + i), *(pi1 + i) );
}
a += lda;
p += ldp;
pi1 += ldp;
}
}
else
@@ -111,10 +112,10 @@ void bli_dpackm_zen4_asm_32xk
for ( dim_t k = k0; k != 0; --k )
{
for ( dim_t i = 0 ; i < 32 ; i++ ) {
bli_dcopyjs( *(a + i*inca), *(p + i) );
bli_dcopyjs( *(a + i*inca), *(pi1 + i) );
}
a += lda;
p += ldp;
pi1 += ldp;
}
}
}
@@ -126,10 +127,10 @@ void bli_dpackm_zen4_asm_32xk
{
_mm_prefetch( a + (8*lda), _MM_HINT_T0 );
for ( dim_t i = 0 ; i < 32 ; i++ ) {
bli_dcopys( *(a + i), *(p + i) );
bli_dcopys( *(a + i), *(pi1 + i) );
}
a += lda;
p += ldp;
pi1 += ldp;
}
}
else
@@ -137,10 +138,10 @@ void bli_dpackm_zen4_asm_32xk
for ( dim_t k = k0; k != 0; --k )
{
for ( dim_t i = 0 ; i < 32 ; i++ ) {
bli_dcopys( *(a + i*inca), *(p + i) );
bli_dcopys( *(a + i*inca), *(pi1 + i) );
}
a += lda;
p += ldp;
pi1 += ldp;
}
}
}
@@ -154,10 +155,10 @@ void bli_dpackm_zen4_asm_32xk
for ( dim_t k = k0; k != 0; --k )
{
for ( dim_t i = 0 ; i < 32 ; i++ ) {
bli_dscal2js( *kappa, *(a + i), *(p + i) );
bli_dscal2js( *kappa, *(a + i), *(pi1 + i) );
}
a += lda;
p += ldp;
pi1 += ldp;
}
}
else
@@ -165,10 +166,10 @@ void bli_dpackm_zen4_asm_32xk
for ( dim_t k = k0; k != 0; --k )
{
for ( dim_t i = 0 ; i < 32 ; i++ ) {
bli_dscal2js( *kappa, *(a + i*inca), *(p + i) );
bli_dscal2js( *kappa, *(a + i*inca), *(pi1 + i) );
}
a += lda;
p += ldp;
pi1 += ldp;
}
}
}
@@ -179,10 +180,10 @@ void bli_dpackm_zen4_asm_32xk
for ( dim_t k = k0; k != 0; --k )
{
for ( dim_t i = 0 ; i < 32 ; i++ ) {
bli_dscal2s( *kappa, *(a + i), *(p + i) );
bli_dscal2s( *kappa, *(a + i), *(pi1 + i) );
}
a += lda;
p += ldp;
pi1 += ldp;
}
}
else
@@ -190,10 +191,10 @@ void bli_dpackm_zen4_asm_32xk
for ( dim_t k = k0; k != 0; --k )
{
for ( dim_t i = 0 ; i < 32 ; i++ ) {
bli_dscal2s( *kappa, *(a + i*inca), *(p + i) );
bli_dscal2s( *kappa, *(a + i*inca), *(pi1 + i) );
}
a += lda;
p += ldp;
pi1 += ldp;
}
}
}
@@ -223,7 +224,8 @@ void bli_dpackm_zen4_asm_32xk
const dim_t i = cdim0;
const dim_t m_edge = mnr - cdim0;
const dim_t n_edge = k0_max;
double* restrict p_edge = p + (i )*1;
double* restrict p_cast = p;
double* restrict p_edge = p_cast + (i )*1;
bli_dset0s_mxn
(
@@ -241,7 +243,8 @@ void bli_dpackm_zen4_asm_32xk
const dim_t j = k0;
const dim_t m_edge = mnr;
const dim_t n_edge = k0_max - k0;
double* restrict p_edge = p + (j )*ldp;
double* restrict p_cast = p;
double* restrict p_edge = p_cast + (j )*ldp;
bli_dset0s_mxn
(