Standalone single-precision complex gemv implementation, independent of axpyf.

Details
- The axpyf-based implementation incurs function call overhead for every axpyf invocation.
- The new implementation removes this call overhead.
- This implementation uses a kernel of size 8x4.
- This implementation gives better performance for smaller sizes when
  compared to the axpyf-based implementation (see the reference sketch below).
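
For reference, a minimal scalar model of the operation the fused kernel computes
(y := beta*y + alpha*A*x for a column-major complex-float matrix with unit strides).
This is an illustrative sketch only, not code from this patch:

#include <complex.h>
#include <stddef.h>

/* Reference model: y := beta*y + alpha*A*x, column-major A (m x n). */
static void cgemv_ref( size_t m, size_t n, float complex alpha,
                       const float complex *a, size_t lda,
                       const float complex *x,
                       float complex beta, float complex *y )
{
    for ( size_t i = 0; i < m; ++i ) y[i] *= beta;
    for ( size_t j = 0; j < n; ++j )
    {
        /* Scale x[j] by alpha once, then accumulate column j into y (an axpy). */
        float complex chi = alpha * x[j];
        for ( size_t i = 0; i < m; ++i )
            y[i] += chi * a[i + j*lda];
    }
}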

AMD-Internal: [CPUPL-1402]
Change-Id: Ic9a5e59363290caf26284548638da9065952fd48
Nageshwar Singh
2021-09-20 14:33:08 +05:30
committed by Dipal M Zambare
parent d6fcfe7345
commit cbd9ea76af
3 changed files with 273 additions and 24 deletions


@@ -498,14 +498,14 @@ void bli_cgemv_unf_var2
/* If beta is zero, use setv. Otherwise, scale by beta. */
/* y = beta * y; */
/* beta=0 case is handled by scalv internally */
bli_cscalv_ex
bli_cscalv_zen_int10
(
BLIS_NO_CONJUGATE,
n_elem,
beta,
y, incy,
cntx,
NULL
y,
incy,
cntx
);
if( bli_ceq0( *alpha ) )
@@ -513,30 +513,59 @@ void bli_cgemv_unf_var2
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3)
return;
}
/* fusing factor. */
b_fuse = 4;
for ( i = 0; i < n_iter; i += f )
// Support for non-unit incx, incy, rs_at, and conjugate cases will be added in the next patch
if( ( (incx == 1) && (incy == 1) && (rs_at == 1) ) &&
!bli_is_conj(conja) && !bli_is_conj(conjx) &&
!bli_is_trans(transa))
{
f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse );
A1 = a + (0 )*rs_at + (i )*cs_at;
x1 = x + (i )*incx;
y1 = y + (0 )*incy;
/* y = y + alpha * A1 * x1; */
bli_caxpyf_zen_int_4
// This gemv code handles the following conditions only:
// 1. incx, incy, and row stride equal to one
// 2. Non-conjugate A matrix and x vector
// 3. No transpose for the A matrix
// The remaining cases are handled by the else branch (axpyf implementation).
bli_cgemv_zen_int_4x4
(
conja,
conjx,
n_elem,
f,
alpha,
A1, rs_at, cs_at,
x1, incx,
y1, incy,
NULL
conja,
conjx,
m,
n,
alpha,
a, rs_at, cs_at,
x, incx,
beta,
y, incy,
NULL
);
}
else
{
/* fusing factor. */
b_fuse = 4;
for ( i = 0; i < n_iter; i += f )
{
f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse );
A1 = a + (0 )*rs_at + (i )*cs_at;
x1 = x + (i )*incx;
y1 = y + (0 )*incy;
/* y = y + alpha * A1 * x1; */
bli_caxpyf_zen_int_4
(
conja,
conjx,
n_elem,
f,
alpha,
A1, rs_at, cs_at,
x1, incx,
y1, incy,
NULL
);
}
}
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3);
}


@@ -257,4 +257,223 @@ void bli_zgemv_zen_int_4x4
cntx
);
}
}
}
/*
This implementation makes efficient use of a 512-bit cache line for
a column-stored matrix and vectors.
To achieve this, each iteration uses 2 ymm registers,
i.e. 512 bits, for the arithmetic operations, which uses the
cache efficiently.
*/
void bli_cgemv_zen_int_4x4
(
conj_t conja,
conj_t conjx,
dim_t m,
dim_t n,
scomplex* restrict alpha,
scomplex* restrict a, inc_t inca, inc_t lda,
scomplex* restrict x, inc_t incx,
scomplex* restrict beta,
scomplex* restrict y, inc_t incy,
cntx_t* restrict cntx
)
{
const dim_t S_MR = 8; // Kernel size , m = 8
const dim_t S_NR = 4; // Kernel size , n = 4
scomplex chi0;
scomplex chi1;
scomplex chi2;
scomplex chi3;
inc_t lda2 = 2*lda;
inc_t lda3 = 3*lda;
inc_t incy4 = 4*incy;
inc_t incx2 = 2*incx;
inc_t incx3 = 3*incx;
inc_t inca2 = 4*inca; // offset to the next 4 rows (4*inca, despite the "2" in the name)
scomplex* x0 = x;
scomplex* y0 = y;
scomplex* a0 = a;
dim_t i,j;
__m256 ymm0, ymm1, ymm2, ymm3;
__m256 ymm4, ymm5, ymm6, ymm7;
__m256 ymm8, ymm9, ymm10, ymm11;
__m256 ymm12, ymm13, ymm14, ymm15;
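// Register roles in the main 8x4 block (descriptive summary):
//   ymm0..ymm7   broadcast real/imag parts of the four alpha-scaled x values
//   ymm8..ymm11  columns of A (4 interleaved complex elements per register)
//   ymm12/ymm13  combined complex results for rows 0..3 and 4..7
//   ymm14/ymm15  real/imag-part accumulators (also reused to load y)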
for( i = 0; i+S_NR-1 < n; i+=S_NR )
{
a0 = a + (i )*lda;
x0 = x + (i )*incx;
y0 = y; // For each kernel iteration, y restarts from the beginning
chi0 = *( x0);
chi1 = *( x0 + incx );
chi2 = *( x0 + incx2 );
chi3 = *( x0 + incx3 );
bli_cscals( *alpha, chi0 );
bli_cscals( *alpha, chi1 );
bli_cscals( *alpha, chi2 );
bli_cscals( *alpha, chi3 );
ymm0 = _mm256_broadcast_ss(&chi0.real); // real part of x0
ymm1 = _mm256_broadcast_ss(&chi0.imag); // imag part of x0
ymm2 = _mm256_broadcast_ss(&chi1.real); // real part of x1
ymm3 = _mm256_broadcast_ss(&chi1.imag); // imag part of x1
ymm4 = _mm256_broadcast_ss(&chi2.real); // real part of x2
ymm5 = _mm256_broadcast_ss(&chi2.imag); // imag part of x2
ymm6 = _mm256_broadcast_ss(&chi3.real); // real part of x3
ymm7 = _mm256_broadcast_ss(&chi3.imag); // imag part of x3
for( j = 0 ; j+S_MR-1 < m ; j+=S_MR )
{
// Load columns of A; each ymm register holds 4 complex elements
ymm8 = _mm256_loadu_ps((float const *)(a0));
ymm9 = _mm256_loadu_ps((float const *)(a0 + lda));
ymm10 = _mm256_loadu_ps((float const *)(a0 + lda2));
ymm11 = _mm256_loadu_ps((float const *)(a0 + lda3));
//--------------------
//Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr
ymm14 = _mm256_mul_ps(ymm8, ymm0);
//Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi
ymm15 = _mm256_mul_ps(ymm8, ymm1);
/* Next set of A mult by real and imag,
Add into the previous real and imag results */
// (Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr)
// + (prev iteration real results)
ymm14 = _mm256_fmadd_ps(ymm9, ymm2, ymm14);
// (Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi)
// + (prev iteration imag results)
ymm15 = _mm256_fmadd_ps(ymm9, ymm3, ymm15);
// (Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr)
// + (prev iteration real results)
ymm14 = _mm256_fmadd_ps(ymm10, ymm4, ymm14);
// (Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi)
// + (prev iteration imag results)
ymm15 = _mm256_fmadd_ps(ymm10, ymm5, ymm15);
// (Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr)
// + (prev iteration real results)
ymm14 = _mm256_fmadd_ps(ymm11, ymm6, ymm14);
// (Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi)
// + (prev iteration imag results)
ymm15 = _mm256_fmadd_ps(ymm11, ymm7, ymm15);
/* Permute the imag accumulator so it can be addsub'ed with the real accumulator */
// (Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi)
// => (Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi)
ymm15 = _mm256_permute_ps(ymm15, 0xB1);
/* AddSub to form the properly combined complex products */
/* Ar*Xr - Ai*Xi, Ai*Xr + Ar*Xi, Ar*Xr - Ai*Xi, Ai*Xr + Ar*Xi,
Ar*Xr - Ai*Xi, Ai*Xr + Ar*Xi, Ar*Xr - Ai*Xi, Ai*Xr + Ar*Xi */
ymm12 = _mm256_addsub_ps(ymm14, ymm15);
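// ymm12 now holds four interleaved complex values: for each row r = 0..3,
//   Re = sum_k( Ar_k*Xr_k - Ai_k*Xi_k ),  Im = sum_k( Ai_k*Xr_k + Ar_k*Xi_k )
// summed over the four columns k (alpha is already folded into chi0..chi3),
// i.e. four elements of alpha*A*x ready to be accumulated into y.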
//load Y vector
ymm14 = _mm256_loadu_ps((float*)y0);
//Add the results into y
ymm12 = _mm256_add_ps(ymm14, ymm12);
// Store the results back
_mm256_storeu_ps((float*)(y0), ymm12);
//-----------------------
// Load the next set of A matrix elements for the same columns
// (rows 4..7, interleaved real/imag parts)
ymm8 = _mm256_loadu_ps((float const *)(a0 + (inca2)));
ymm9 = _mm256_loadu_ps((float const *)(a0 + (inca2) + lda));
ymm10 = _mm256_loadu_ps((float const *)(a0 + (inca2) + lda2));
ymm11 = _mm256_loadu_ps((float const *)(a0 + (inca2) + lda3));
//Ar0*Xr Ai0*Xr Ar1*Xr Ai1*Xr
ymm14 = _mm256_mul_ps(ymm8, ymm0);
//Ar0*Xi Ai0*Xi Ar1*Xi Ai1*Xi
ymm15 = _mm256_mul_ps(ymm8, ymm1);
/* Next set of A mult by real and imag,
Add into the previous real and imag results */
// (Ar*Xr Ai*Xr Ar*Xr Ai*Xr) + (prev iteration real results)
ymm14 = _mm256_fmadd_ps(ymm9, ymm2, ymm14);
// (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) + (prev iteration imag results)
ymm15 = _mm256_fmadd_ps(ymm9, ymm3, ymm15);
// (Ar*Xr Ai*Xr Ar*Xr Ai*Xr) + (prev iteration real results)
ymm14 = _mm256_fmadd_ps(ymm10, ymm4, ymm14);
// (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) + (prev iteration imag results)
ymm15 = _mm256_fmadd_ps(ymm10, ymm5, ymm15);
// (Ar*Xr Ai*Xr Ar*Xr Ai*Xr) + (prev iteration real results)
ymm14 = _mm256_fmadd_ps(ymm11, ymm6, ymm14);
// (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) + (prev iteration imag results)
ymm15 = _mm256_fmadd_ps(ymm11, ymm7, ymm15);
/* Permute the imag accumulator so it can be addsub'ed with the real accumulator */
// (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) => (Ai*Xi Ar*Xi Ai*Xi Ar*Xi)
ymm15 = _mm256_permute_ps(ymm15, 0xB1);
/* AddSub to form the properly combined complex products */
/* Ar*Xr - Ai*Xi, Ai*Xr + Ar*Xi, Ar*Xr - Ai*Xi, Ai*Xr + Ar*Xi */
ymm13 = _mm256_addsub_ps(ymm14, ymm15);
// load Y vector
ymm14 = _mm256_loadu_ps((float *)(y0 + (incy4)));
// Add the results into y
ymm13 = _mm256_add_ps(ymm14, ymm13);
// Store the results back
_mm256_storeu_ps((float*)(y0 + (incy4)), ymm13);
y0 += S_MR*incy ; // Next Set of y0 vector
a0 += S_MR*inca ; // Next Set of a0 matrix elements in the same col
}
// For residual m (leftover rows when m is not a multiple of 8), use scalar arithmetic
for( ; j < m ; ++j )
{
scomplex y0c = *(scomplex*)y0;
const scomplex a0c = *a0;
const scomplex a1c = *(a0 + lda);
const scomplex a2c = *(a0 + lda2);
const scomplex a3c = *(a0 + lda3);
y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag;
y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag;
y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag;
y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag;
y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag;
y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag;
y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag;
y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag;
*(scomplex*)y0 = y0c;
a0 += 1;
y0 += 1;
}
}
// For residual n (leftover columns), axpyv is used
for ( ; i < n; ++i )
{
scomplex* a1 = a + (i )*lda;
scomplex* chi1 = x + (i )*incx;
scomplex* y1 = y;
scomplex alpha_chi1;
bli_ccopycjs( conjx, *chi1, alpha_chi1 );
bli_cscals( *alpha, alpha_chi1 );
bli_caxpyv_zen_int5
(
conja,
m,
&alpha_chi1,
a1, inca,
y1, incy,
cntx
);
}
}


@@ -117,6 +117,7 @@ DOTXF_KER_PROT( double, d, dotxf_zen_int_8 )
//gemv(scalar code)
GEMV_KER_PROT( double, d, gemv_zen_ref_c )
GEMV_KER_PROT( scomplex, c, gemv_zen_int_4x4 )
GEMV_KER_PROT( dcomplex, z, gemv_zen_int_4x4 )
// -- level-3 sup --------------------------------------------------------------