mirror of
https://github.com/amd/blis.git
synced 2026-05-11 17:50:00 +00:00
Standalone single-precision complex (cgemv) implementation, independent of axpyf.
Details - The axpyf-based implementation incurs per-call function overhead. - The new implementation reduces that function-call overhead. - This implementation uses a kernel of size 8x4. - This implementation gives better performance for smaller sizes when compared to the axpyf-based implementation. AMD-Internal: [CPUPL-1402] Change-Id: Ic9a5e59363290caf26284548638da9065952fd48
This commit is contained in:
committed by
Dipal M Zambare
parent
d6fcfe7345
commit
cbd9ea76af
@@ -498,14 +498,14 @@ void bli_cgemv_unf_var2
|
||||
/* If beta is zero, use setv. Otherwise, scale by beta. */
|
||||
/* y = beta * y; */
|
||||
/* beta=0 case is hadled by scalv internally */
|
||||
bli_cscalv_ex
|
||||
bli_cscalv_zen_int10
|
||||
(
|
||||
BLIS_NO_CONJUGATE,
|
||||
n_elem,
|
||||
beta,
|
||||
y, incy,
|
||||
cntx,
|
||||
NULL
|
||||
y,
|
||||
incy,
|
||||
cntx
|
||||
);
|
||||
|
||||
if( bli_ceq0( *alpha ) )
|
||||
@@ -513,30 +513,59 @@ void bli_cgemv_unf_var2
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3)
|
||||
return;
|
||||
}
|
||||
/* fusing factor. */
|
||||
b_fuse = 4;
|
||||
|
||||
for ( i = 0; i < n_iter; i += f )
|
||||
// for non-unit incx, incy and rs_at and conjugate will be added in the next patch
|
||||
if( ( (incx == 1) && (incy == 1) && (rs_at == 1) ) &&
|
||||
!bli_is_conj(conja) && !bli_is_conj(conjx) &&
|
||||
!bli_is_trans(transa))
|
||||
{
|
||||
f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse );
|
||||
A1 = a + (0 )*rs_at + (i )*cs_at;
|
||||
x1 = x + (i )*incx;
|
||||
y1 = y + (0 )*incy;
|
||||
|
||||
/* y = y + alpha * A1 * x1; */
|
||||
bli_caxpyf_zen_int_4
|
||||
// This gemv code deals with the followint conditions only
|
||||
// 1. incx, incy, and row stride equal to one
|
||||
// 2. Non conjugate A matrix and X vector
|
||||
// 3. No Transpose for A Martix
|
||||
// Rest is taken care by the else part (axpyf implementation)
|
||||
bli_cgemv_zen_int_4x4
|
||||
(
|
||||
conja,
|
||||
conjx,
|
||||
n_elem,
|
||||
f,
|
||||
alpha,
|
||||
A1, rs_at, cs_at,
|
||||
x1, incx,
|
||||
y1, incy,
|
||||
NULL
|
||||
conja,
|
||||
conjx,
|
||||
m,
|
||||
n,
|
||||
alpha,
|
||||
a, rs_at, cs_at,
|
||||
x, incx,
|
||||
beta,
|
||||
y, incy,
|
||||
NULL
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
/* fusing factor. */
|
||||
b_fuse = 4;
|
||||
|
||||
for ( i = 0; i < n_iter; i += f )
|
||||
{
|
||||
f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse );
|
||||
A1 = a + (0 )*rs_at + (i )*cs_at;
|
||||
x1 = x + (i )*incx;
|
||||
y1 = y + (0 )*incy;
|
||||
|
||||
/* y = y + alpha * A1 * x1; */
|
||||
bli_caxpyf_zen_int_4
|
||||
(
|
||||
conja,
|
||||
conjx,
|
||||
n_elem,
|
||||
f,
|
||||
alpha,
|
||||
A1, rs_at, cs_at,
|
||||
x1, incx,
|
||||
y1, incy,
|
||||
NULL
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3);
|
||||
}
|
||||
|
||||
|
||||
@@ -257,4 +257,223 @@ void bli_zgemv_zen_int_4x4
|
||||
cntx
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
This implementation uses 512 bits of cache line efficiently for
|
||||
column stored matrix and vectors.
|
||||
To achieve this, at each iteration we use 2 ymm registers
|
||||
i.e. .512 bits for arithmetic operation. By this we use the
|
||||
cache efficiently.
|
||||
*/
|
||||
/*
    Single-precision complex GEMV kernel: y += alpha * A * x (column-major A),
    fused over 4 columns (S_NR) and vectorized over 8 rows (S_MR) per iteration.

    This implementation uses 512 bits of cache line efficiently for
    a column-stored matrix and vectors. To achieve this, at each iteration
    we use 2 ymm registers (i.e. 512 bits) for the arithmetic operation,
    thereby using the cache efficiently.

    NOTE(review): `beta` is never read in this kernel — the caller
    (bli_cgemv_unf_var2) is expected to have already scaled y by beta.
    NOTE(review): the vectorized path ignores `conja`/`conjx` and the loads
    assume inca == 1 (contiguous rows) — presumably the caller guards these
    conditions (no-conjugate, unit strides); confirm against the dispatch site.
    Only the residual-n cleanup below honors conja/conjx.
*/
void bli_cgemv_zen_int_4x4
     (
       conj_t           conja,
       conj_t           conjx,
       dim_t            m,
       dim_t            n,
       scomplex* restrict alpha,
       scomplex* restrict a, inc_t inca, inc_t lda,
       scomplex* restrict x, inc_t incx,
       scomplex* restrict beta,
       scomplex* restrict y, inc_t incy,
       cntx_t* restrict cntx
     )
{

	const dim_t S_MR = 8; // Kernel size, m = 8 (rows handled per inner iteration)
	const dim_t S_NR = 4; // Kernel size, n = 4 (columns fused per outer iteration)

	// Alpha-scaled copies of the 4 x-elements consumed by one column block.
	scomplex chi0;
	scomplex chi1;
	scomplex chi2;
	scomplex chi3;

	// Precomputed stride multiples used for addressing the 4 fused columns
	// and the second 4-row half of each 8-row block.
	inc_t lda2 = 2*lda;
	inc_t lda3 = 3*lda;
	inc_t incy4 = 4*incy;
	inc_t incx2 = 2*incx;
	inc_t incx3 = 3*incx;
	inc_t inca2 = 4*inca;  // offset to rows 4..7 of the current 8-row block

	scomplex* x0 = x;
	scomplex* y0 = y;
	scomplex* a0 = a;

	dim_t i,j;

	__m256 ymm0, ymm1, ymm2, ymm3;
	__m256 ymm4, ymm5, ymm6, ymm7;
	__m256 ymm8, ymm9, ymm10, ymm11;
	__m256 ymm12, ymm13, ymm14, ymm15;

	// Outer loop: process full blocks of S_NR = 4 columns.
	for( i = 0; i+S_NR-1 < n; i+=S_NR )
	{
		a0 = a + (i  )*lda;
		x0 = x + (i  )*incx;
		y0 = y;// For each kernel, y should start from the beginning

		chi0 = *( x0);
		chi1 = *( x0 + incx );
		chi2 = *( x0 + incx2 );
		chi3 = *( x0 + incx3 );

		// Fold alpha into the x-elements once, outside the row loop.
		bli_cscals( *alpha, chi0 );
		bli_cscals( *alpha, chi1 );
		bli_cscals( *alpha, chi2 );
		bli_cscals( *alpha, chi3 );

		// Broadcast real/imag parts so each lane multiplies against
		// an interleaved (Ar, Ai, Ar, Ai, ...) column load.
		ymm0 = _mm256_broadcast_ss(&chi0.real); // real part of x0
		ymm1 = _mm256_broadcast_ss(&chi0.imag); // imag part of x0
		ymm2 = _mm256_broadcast_ss(&chi1.real); // real part of x1
		ymm3 = _mm256_broadcast_ss(&chi1.imag); // imag part of x1
		ymm4 = _mm256_broadcast_ss(&chi2.real); // real part of x2
		ymm5 = _mm256_broadcast_ss(&chi2.imag); // imag part of x2
		ymm6 = _mm256_broadcast_ss(&chi3.real); // real part of x3
		ymm7 = _mm256_broadcast_ss(&chi3.imag); // imag part of x3

		// Inner loop: process full blocks of S_MR = 8 rows
		// (two 4-complex ymm halves per column).
		for( j = 0 ; j+S_MR-1 < m ; j+=S_MR )
		{
			// Load rows 0..3 of the 4 columns of A;
			// each ymm register holds 4 scomplex elements.
			ymm8 = _mm256_loadu_ps((float const *)(a0));
			ymm9 = _mm256_loadu_ps((float const *)(a0 + lda));
			ymm10 = _mm256_loadu_ps((float const *)(a0 + lda2));
			ymm11 = _mm256_loadu_ps((float const *)(a0 + lda3));

			//--------------------
			//Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr
			ymm14 = _mm256_mul_ps(ymm8, ymm0);
			//Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi
			ymm15 = _mm256_mul_ps(ymm8, ymm1);

			/* Next set of A mult by real and imag,
			   Add into the previous real and imag results */
			// (Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr)
			// + (prev iteration real results)
			ymm14 = _mm256_fmadd_ps(ymm9, ymm2, ymm14);
			// (Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi)
			// + (prev iteration imag results)
			ymm15 = _mm256_fmadd_ps(ymm9, ymm3, ymm15);
			// (Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr)
			// + (prev iteration real results)
			ymm14 = _mm256_fmadd_ps(ymm10, ymm4, ymm14);
			// (Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi)
			// + (prev iteration imag results)
			ymm15 = _mm256_fmadd_ps(ymm10, ymm5, ymm15);
			// (Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr)
			// + (prev iteration real results)
			ymm14 = _mm256_fmadd_ps(ymm11, ymm6, ymm14);
			// (Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi)
			// + (prev iteration imag results)
			ymm15 = _mm256_fmadd_ps(ymm11, ymm7, ymm15);
			/* Permute the imag accumulator so addsub lines it up
			   with the real accumulator results */
			// (Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi)
			// => (Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi)
			ymm15 = _mm256_permute_ps(ymm15, 0xB1);
			/* AddSub to get the properly multiplied complex values */
			/* Ar*Xr - Ai*Xi, Ai*Xr + Ar*Xi, Ar*Xr - Ai*Xi, Ai*Xr + Ar*Xi,
			   Ar*Xr - Ai*Xi, Ai*Xr + Ar*Xi, Ar*Xr - Ai*Xi, Ai*Xr + Ar*Xi */
			ymm12 = _mm256_addsub_ps(ymm14, ymm15);
			//load Y vector
			ymm14 = _mm256_loadu_ps((float*)y0);
			//Add the results into y
			ymm12 = _mm256_add_ps(ymm14, ymm12);
			// Store the results back
			_mm256_storeu_ps((float*)(y0), ymm12);

			//-----------------------

			// Load rows 4..7 of the same 4 columns (offset inca2 = 4*inca)
			ymm8 = _mm256_loadu_ps((float const *)(a0 + (inca2)));
			ymm9 = _mm256_loadu_ps((float const *)(a0 + (inca2) + lda));
			ymm10 = _mm256_loadu_ps((float const *)(a0 + (inca2) + lda2));
			ymm11 = _mm256_loadu_ps((float const *)(a0 + (inca2) + lda3));

			//Ar0*Xr Ai0*Xr Ar1*Xr Ai1*Xr
			ymm14 = _mm256_mul_ps(ymm8, ymm0);
			//Ar0*Xi Ai0*Xi Ar1*Xi Ai1*Xi
			ymm15 = _mm256_mul_ps(ymm8, ymm1);

			/* Next set of A mult by real and imag,
			   Add into the previous real and imag results */

			// (Ar*Xr Ai*Xr Ar*Xr Ai*Xr) + (prev iteration real results)
			ymm14 = _mm256_fmadd_ps(ymm9, ymm2, ymm14);
			// (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) + (prev iteration imag results)
			ymm15 = _mm256_fmadd_ps(ymm9, ymm3, ymm15);

			// (Ar*Xr Ai*Xr Ar*Xr Ai*Xr) + (prev iteration real results)
			ymm14 = _mm256_fmadd_ps(ymm10, ymm4, ymm14);
			// (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) + (prev iteration imag results)
			ymm15 = _mm256_fmadd_ps(ymm10, ymm5, ymm15);

			// (Ar*Xr Ai*Xr Ar*Xr Ai*Xr) + (prev iteration real results)
			ymm14 = _mm256_fmadd_ps(ymm11, ymm6, ymm14);
			// (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) + (prev iteration imag results)
			ymm15 = _mm256_fmadd_ps(ymm11, ymm7, ymm15);

			/* Permute the imag accumulator so addsub lines it up
			   with the real accumulator results */
			// (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) => (Ai*Xi Ar*Xi Ai*Xi Ar*Xi)
			ymm15 = _mm256_permute_ps(ymm15, 0xB1);
			/* AddSub to get the properly multiplied complex values */
			/* Ar*Xr - Ai*Xi, Ai*Xr + Ar*Xi, Ar*Xr - Ai*Xi, Ai*Xr + Ar*Xi */
			ymm13 = _mm256_addsub_ps(ymm14, ymm15);

			// load Y vector
			ymm14 = _mm256_loadu_ps((float *)(y0 + (incy4)));
			// Add the results into y
			ymm13 = _mm256_add_ps(ymm14, ymm13);
			// Store the results back
			_mm256_storeu_ps((float*)(y0 + (incy4)), ymm13);

			y0 += S_MR*incy ; // Next set of the y0 vector
			a0 += S_MR*inca ; // Next set of a0 matrix elements in the same columns
		}

		// For residual m (fewer than S_MR rows left): scalar complex
		// multiply-accumulate across the same 4 fused columns.
		for( ; j < m ; ++j )
		{
			scomplex y0c = *(scomplex*)y0;
			const scomplex a0c = *a0;
			const scomplex a1c = *(a0 + lda);
			const scomplex a2c = *(a0 + lda2);
			const scomplex a3c = *(a0 + lda3);

			// Real part of y += chi * a for each of the 4 columns.
			y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag;
			y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag;
			y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag;
			y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag;

			// Imag part of y += chi * a for each of the 4 columns.
			y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag;
			y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag;
			y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag;
			y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag;

			*(scomplex*)y0 = y0c;
			// NOTE(review): advances by 1 element, not incy/inca — assumes
			// unit strides on this path; confirm against the caller's guard.
			a0 += 1;
			y0 += 1;
		}
	}

	// For residual n (fewer than S_NR columns left), axpyv is used:
	// one alpha-scaled column update per remaining column. This path
	// honors conjx (via copycjs) and forwards conja to the axpyv kernel.
	for ( ; i < n; ++i )
	{
		scomplex* a1   = a + (i  )*lda;
		scomplex* chi1 = x + (i  )*incx;
		scomplex* y1   = y;
		scomplex  alpha_chi1;
		bli_ccopycjs( conjx, *chi1, alpha_chi1 );
		bli_cscals( *alpha, alpha_chi1 );
		bli_caxpyv_zen_int5
		(
			conja,
			m,
			&alpha_chi1,
			a1, inca,
			y1, incy,
			cntx
		);
	}
}
|
||||
|
||||
|
||||
@@ -117,6 +117,7 @@ DOTXF_KER_PROT( double, d, dotxf_zen_int_8 )
|
||||
|
||||
//gemv(scalar code)
|
||||
GEMV_KER_PROT( double, d, gemv_zen_ref_c )
|
||||
GEMV_KER_PROT( scomplex, c, gemv_zen_int_4x4 )
|
||||
GEMV_KER_PROT( dcomplex, z, gemv_zen_int_4x4 )
|
||||
|
||||
// -- level-3 sup --------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user