diff --git a/frame/2/gemv/bli_gemv_unf_var2.c b/frame/2/gemv/bli_gemv_unf_var2.c
index 34c11f758..ffebf17ba 100644
--- a/frame/2/gemv/bli_gemv_unf_var2.c
+++ b/frame/2/gemv/bli_gemv_unf_var2.c
@@ -498,14 +498,14 @@ void bli_cgemv_unf_var2
     /* If beta is zero, use setv. Otherwise, scale by beta. */
     /* y = beta * y; */
     /* beta=0 case is handled by scalv internally */
-    bli_cscalv_ex
+    bli_cscalv_zen_int10
     (
       BLIS_NO_CONJUGATE,
       n_elem,
       beta,
-      y, incy,
-      cntx,
-      NULL
+      y,
+      incy,
+      cntx
     );
 
     if( bli_ceq0( *alpha ) )
@@ -513,30 +513,59 @@ void bli_cgemv_unf_var2
         AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3)
         return;
     }
-    /* fusing factor. */
-    b_fuse = 4;
-
-    for ( i = 0; i < n_iter; i += f )
+    // Support for non-unit incx, incy, and rs_at, and for conjugation, will be added in the next patch.
+    if( ( (incx == 1) && (incy == 1) && (rs_at == 1) ) &&
+        !bli_is_conj(conja) && !bli_is_conj(conjx) &&
+        !bli_is_trans(transa))
     {
-        f  = bli_determine_blocksize_dim_f( i, n_iter, b_fuse );
-        A1 = a + (0  )*rs_at + (i  )*cs_at;
-        x1 = x + (i  )*incx;
-        y1 = y + (0  )*incy;
-
-        /* y = y + alpha * A1 * x1; */
-        bli_caxpyf_zen_int_4
+        // This gemv kernel handles the following conditions only:
+        // 1. incx, incy, and row stride equal to one
+        // 2. Non-conjugate A matrix and X vector
+        // 3. No transpose for the A matrix
+        // The rest is handled by the else branch (axpyf implementation).
+        bli_cgemv_zen_int_4x4
         (
-          conja,
-          conjx,
-          n_elem,
-          f,
-          alpha,
-          A1, rs_at, cs_at,
-          x1, incx,
-          y1, incy,
-          NULL
+            conja,
+            conjx,
+            m,
+            n,
+            alpha,
+            a, rs_at, cs_at,
+            x, incx,
+            beta,
+            y, incy,
+            NULL
         );
     }
+    else
+    {
+        /* fusing factor. */
+        b_fuse = 4;
+
+        for ( i = 0; i < n_iter; i += f )
+        {
+            f  = bli_determine_blocksize_dim_f( i, n_iter, b_fuse );
+            A1 = a + (0  )*rs_at + (i  )*cs_at;
+            x1 = x + (i  )*incx;
+            y1 = y + (0  )*incy;
+
+            /* y = y + alpha * A1 * x1; */
+            bli_caxpyf_zen_int_4
+            (
+                conja,
+                conjx,
+                n_elem,
+                f,
+                alpha,
+                A1, rs_at, cs_at,
+                x1, incx,
+                y1, incy,
+                NULL
+            );
+        }
+    }
+
     AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3);
 }
diff --git a/kernels/zen/2/bli_gemv_zen_int_4.c b/kernels/zen/2/bli_gemv_zen_int_4.c
index 95060f57e..b3c92b551 100644
--- a/kernels/zen/2/bli_gemv_zen_int_4.c
+++ b/kernels/zen/2/bli_gemv_zen_int_4.c
@@ -257,4 +257,223 @@ void bli_zgemv_zen_int_4x4
             cntx
         );
     }
-}
\ No newline at end of file
+}
+
+/*
+  This implementation makes efficient use of a 512-bit cache line for
+  column-stored matrices and vectors.
+  To achieve this, each iteration uses two ymm registers, i.e. 512 bits,
+  for the arithmetic, so the cache line is used efficiently.
+*/
+void bli_cgemv_zen_int_4x4
+(
+    conj_t             conja,
+    conj_t             conjx,
+    dim_t              m,
+    dim_t              n,
+    scomplex* restrict alpha,
+    scomplex* restrict a, inc_t inca, inc_t lda,
+    scomplex* restrict x, inc_t incx,
+    scomplex* restrict beta,
+    scomplex* restrict y, inc_t incy,
+    cntx_t*   restrict cntx
+)
+{
+
+    const dim_t S_MR = 8; // Kernel size, m = 8
+    const dim_t S_NR = 4; // Kernel size, n = 4
+
+    scomplex chi0;
+    scomplex chi1;
+    scomplex chi2;
+    scomplex chi3;
+
+    inc_t lda2  = 2*lda;
+    inc_t lda3  = 3*lda;
+    inc_t incy4 = 4*incy;
+    inc_t incx2 = 2*incx;
+    inc_t incx3 = 3*incx;
+    inc_t inca2 = 4*inca; // offset to the second 4 rows of the 8-row block
+
+    scomplex* x0 = x;
+    scomplex* y0 = y;
+    scomplex* a0 = a;
+
+    dim_t i,j;
+
+    __m256 ymm0,  ymm1,  ymm2,  ymm3;
+    __m256 ymm4,  ymm5,  ymm6,  ymm7;
+    __m256 ymm8,  ymm9,  ymm10, ymm11;
+    __m256 ymm12, ymm13, ymm14, ymm15;
+
+    for( i = 0; i+S_NR-1 < n; i+=S_NR )
+    {
+        a0 = a + (i  )*lda;
+        x0 = x + (i  )*incx;
+        y0 = y; // For each kernel iteration, y starts from the beginning
+
+        chi0 = *( x0 );
+        chi1 = *( x0 + incx  );
+        chi2 = *( x0 + incx2 );
+        chi3 = *( x0 + incx3 );
+
+        bli_cscals( *alpha, chi0 );
+        bli_cscals( *alpha, chi1 );
+        bli_cscals( *alpha, chi2 );
+        bli_cscals( *alpha, chi3 );
+
+        ymm0 = _mm256_broadcast_ss(&chi0.real); // real part of x0
+        ymm1 = _mm256_broadcast_ss(&chi0.imag); // imag part of x0
+        ymm2 = _mm256_broadcast_ss(&chi1.real); // real part of x1
+        ymm3 = _mm256_broadcast_ss(&chi1.imag); // imag part of x1
+        ymm4 = _mm256_broadcast_ss(&chi2.real); // real part of x2
+        ymm5 = _mm256_broadcast_ss(&chi2.imag); // imag part of x2
+        ymm6 = _mm256_broadcast_ss(&chi3.real); // real part of x3
+        ymm7 = _mm256_broadcast_ss(&chi3.imag); // imag part of x3
+
+        for( j = 0 ; j+S_MR-1 < m ; j+=S_MR )
+        {
+            // load columns of A; each ymm register holds 4 scomplex elements
+            ymm8  = _mm256_loadu_ps((float const *)(a0));
+            ymm9  = _mm256_loadu_ps((float const *)(a0 + lda));
+            ymm10 = _mm256_loadu_ps((float const *)(a0 + lda2));
+            ymm11 = _mm256_loadu_ps((float const *)(a0 + lda3));
+
+            //--------------------
+            // Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr
+            ymm14 = _mm256_mul_ps(ymm8, ymm0);
+            // Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi
+            ymm15 = _mm256_mul_ps(ymm8, ymm1);
+
+            /* Multiply the next columns of A by the real and imag parts,
+               and accumulate into the previous real and imag results */
+            // (Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr)
+            // + (previous real results)
+            ymm14 = _mm256_fmadd_ps(ymm9, ymm2, ymm14);
+            // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi)
+            // + (previous imag results)
+            ymm15 = _mm256_fmadd_ps(ymm9, ymm3, ymm15);
+            // (Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr)
+            // + (previous real results)
+            ymm14 = _mm256_fmadd_ps(ymm10, ymm4, ymm14);
+            // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi)
+            // + (previous imag results)
+            ymm15 = _mm256_fmadd_ps(ymm10, ymm5, ymm15);
+            // (Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr)
+            // + (previous real results)
+            ymm14 = _mm256_fmadd_ps(ymm11, ymm6, ymm14);
+            // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi)
+            // + (previous imag results)
+            ymm15 = _mm256_fmadd_ps(ymm11, ymm7, ymm15);
+
+            /* Permute the imag accumulator so it can be addsub'ed with the real accumulator */
+            // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi)
+            // => (Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi)
+            ymm15 = _mm256_permute_ps(ymm15, 0xB1);
+            /* AddSub to get the properly formed complex products */
+            /* Ar*Xr - Ai*Xi, Ai*Xr + Ar*Xi, Ar*Xr - Ai*Xi, Ai*Xr + Ar*Xi,
+               Ar*Xr - Ai*Xi, Ai*Xr + Ar*Xi, Ar*Xr - Ai*Xi, Ai*Xr + Ar*Xi */
+            ymm12 = _mm256_addsub_ps(ymm14, ymm15);
+            // load the y vector
+            ymm14 = _mm256_loadu_ps((float*)y0);
+            // add the results into y
+            ymm12 = _mm256_add_ps(ymm14, ymm12);
+            // store the results back
+            _mm256_storeu_ps((float*)(y0), ymm12);
+
+            //-----------------------
+
+            // Load the next set of A matrix elements for the same columns
+            // (the second 4 rows: Ar4 Ai4 Ar5 Ai5 Ar6 Ai6 Ar7 Ai7)
+            ymm8  = _mm256_loadu_ps((float const *)(a0 + (inca2)));
+            ymm9  = _mm256_loadu_ps((float const *)(a0 + (inca2) + lda));
+            ymm10 = _mm256_loadu_ps((float const *)(a0 + (inca2) + lda2));
+            ymm11 = _mm256_loadu_ps((float const *)(a0 + (inca2) + lda3));
+
+            // Ar*Xr Ai*Xr Ar*Xr Ai*Xr ...
+            ymm14 = _mm256_mul_ps(ymm8, ymm0);
+            // Ar*Xi Ai*Xi Ar*Xi Ai*Xi ...
+            ymm15 = _mm256_mul_ps(ymm8, ymm1);
+
+            /* Multiply the next columns of A by the real and imag parts,
+               and accumulate into the previous real and imag results */
+            // (Ar*Xr Ai*Xr Ar*Xr Ai*Xr) + (previous real results)
+            ymm14 = _mm256_fmadd_ps(ymm9, ymm2, ymm14);
+            // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) + (previous imag results)
+            ymm15 = _mm256_fmadd_ps(ymm9, ymm3, ymm15);
+
+            // (Ar*Xr Ai*Xr Ar*Xr Ai*Xr) + (previous real results)
+            ymm14 = _mm256_fmadd_ps(ymm10, ymm4, ymm14);
+            // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) + (previous imag results)
+            ymm15 = _mm256_fmadd_ps(ymm10, ymm5, ymm15);
+
+            // (Ar*Xr Ai*Xr Ar*Xr Ai*Xr) + (previous real results)
+            ymm14 = _mm256_fmadd_ps(ymm11, ymm6, ymm14);
+            // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) + (previous imag results)
+            ymm15 = _mm256_fmadd_ps(ymm11, ymm7, ymm15);
+
+            /* Permute the imag accumulator so it can be addsub'ed with the real accumulator */
+            // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) => (Ai*Xi Ar*Xi Ai*Xi Ar*Xi)
+            ymm15 = _mm256_permute_ps(ymm15, 0xB1);
+            /* AddSub to get the properly formed complex products */
+            /* Ar*Xr - Ai*Xi, Ai*Xr + Ar*Xi, Ar*Xr - Ai*Xi, Ai*Xr + Ar*Xi */
+            ymm13 = _mm256_addsub_ps(ymm14, ymm15);
+
+            // load the y vector
+            ymm14 = _mm256_loadu_ps((float *)(y0 + (incy4)));
+            // add the results into y
+            ymm13 = _mm256_add_ps(ymm14, ymm13);
+            // store the results back
+            _mm256_storeu_ps((float*)(y0 + (incy4)), ymm13);
+
+            y0 += S_MR*incy; // next set of the y0 vector
+            a0 += S_MR*inca; // next set of a0 matrix elements in the same columns
+        }
+
+        // For the residual m (fewer than S_MR rows remaining)
+        for( ; j < m ; ++j )
+        {
+            scomplex       y0c = *(scomplex*)y0;
+            const scomplex a0c = *a0;
+            const scomplex a1c = *(a0 + lda);
+            const scomplex a2c = *(a0 + lda2);
+            const scomplex a3c = *(a0 + lda3);
+
+            y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag;
+            y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag;
+            y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag;
+            y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag;
+
+            y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag;
+            y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag;
+            y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag;
+            y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag;
+
+            *(scomplex*)y0 = y0c;
+            a0 += 1;
+            y0 += 1;
+        }
+    }
+
+    // For the residual n (fewer than S_NR columns remaining), axpyv is used
+    for ( ; i < n; ++i )
+    {
+        scomplex* a1   = a + (i  )*lda;
+        scomplex* chi1 = x + (i  )*incx;
+        scomplex* y1   = y;
+        scomplex  alpha_chi1;
+
+        bli_ccopycjs( conjx, *chi1, alpha_chi1 );
+        bli_cscals( *alpha, alpha_chi1 );
+
+        bli_caxpyv_zen_int5
+        (
+            conja,
+            m,
+            &alpha_chi1,
+            a1, inca,
+            y1, incy,
+            cntx
+        );
+    }
+}
+
diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h
index b39ccec57..884599696 100644
--- a/kernels/zen/bli_kernels_zen.h
+++ b/kernels/zen/bli_kernels_zen.h
@@ -117,6 +117,7 @@ DOTXF_KER_PROT( double, d, dotxf_zen_int_8 )
 
 //gemv(scalar code)
 GEMV_KER_PROT( double, d, gemv_zen_ref_c )
+GEMV_KER_PROT( scomplex, c, gemv_zen_int_4x4 )
 GEMV_KER_PROT( dcomplex, z, gemv_zen_int_4x4 )
 
 // -- level-3 sup --------------------------------------------------------------
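
Reviewer note, not part of the patch: the sketch below is a minimal, self-contained illustration of the permute + addsub complex multiply-accumulate step that the new bli_cgemv_zen_int_4x4 kernel's comments describe, assuming an AVX-capable toolchain (e.g. gcc -mavx). The scomplex_t type, the sample values, and the standalone main() are hypothetical stand-ins, not BLIS code.

#include <immintrin.h>
#include <stdio.h>

typedef struct { float real, imag; } scomplex_t; /* stand-in for the BLIS scomplex */

int main( void )
{
    scomplex_t a[4] = { {1,2}, {3,4}, {5,6}, {7,8} }; /* 4 rows of one column of A */
    scomplex_t x    = { 2, -1 };                      /* chi, already scaled by alpha */
    scomplex_t y[4] = { {0,0}, {0,0}, {0,0}, {0,0} };

    __m256 xr = _mm256_broadcast_ss( &x.real ); /* Xr Xr Xr Xr Xr Xr Xr Xr */
    __m256 xi = _mm256_broadcast_ss( &x.imag ); /* Xi Xi Xi Xi Xi Xi Xi Xi */
    __m256 av = _mm256_loadu_ps( (float*)a );   /* Ar0 Ai0 Ar1 Ai1 Ar2 Ai2 Ar3 Ai3 */

    __m256 re = _mm256_mul_ps( av, xr );        /* Ar*Xr Ai*Xr ... */
    __m256 im = _mm256_mul_ps( av, xi );        /* Ar*Xi Ai*Xi ... */
    im = _mm256_permute_ps( im, 0xB1 );         /* swap pairs: Ai*Xi Ar*Xi ... */
    /* even lanes: Ar*Xr - Ai*Xi (real part), odd lanes: Ai*Xr + Ar*Xi (imag part) */
    __m256 pr = _mm256_addsub_ps( re, im );

    /* y += a * x */
    __m256 yv = _mm256_loadu_ps( (float*)y );
    _mm256_storeu_ps( (float*)y, _mm256_add_ps( yv, pr ) );

    for ( int i = 0; i < 4; ++i )
        printf( "y[%d] = %f + %fi\n", i, y[i].real, y[i].imag );
    return 0;
}

With these values, y[0] comes out as 4 + 3i, i.e. (1 + 2i) * (2 - 1i), which is the per-element identity the kernel's addsub comments rely on.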