From 18ae57305ebece998ca4c6c1786e6ba8983df976 Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Tue, 24 Jan 2023 18:03:33 +0530 Subject: [PATCH] ZAXPYF4 optimization - Vectorized alpha scaling of X vector using SSE instructions. This can be done irrespective of incx. - Added code to prefetch A matrix and Y vector to L1 cache - Vectorized fringe case computation and non-unit stride computation with SSE instructions. - Increased unroll in unit stride cases for better register utilization. AMD-Internal: [CPUPL-2773] Change-Id: I217e6ce9e3f5753ebe271c684abd9a2274fd2715 --- kernels/zen/1f/bli_axpyf_zen_int_4.c | 586 ++++++++++++++++++++------- 1 file changed, 439 insertions(+), 147 deletions(-) diff --git a/kernels/zen/1f/bli_axpyf_zen_int_4.c b/kernels/zen/1f/bli_axpyf_zen_int_4.c index bb24e6c52..d7c9ae372 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_4.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_4.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -304,106 +304,154 @@ void bli_zaxpyf_zen_int_4 cntx_t* restrict cntx ) { - inc_t fuse_fac = 4; - inc_t i; + dim_t fuse_fac = 4; - v4df_t ymm0, ymm1, ymm2, ymm3; - v4df_t ymm4, ymm5, ymm6, ymm7; - v4df_t ymm8, ymm10; - v4df_t ymm12, ymm13; - - double* ap[4]; - double* y0 = (double*)y; - - dcomplex chi0; - dcomplex chi1; - dcomplex chi2; - dcomplex chi3; - - dim_t setPlusOne = 1; - - if ( bli_is_conj(conja) ) - { - setPlusOne = -1; - } // If either dimension is zero, or if alpha is zero, return early. if ( bli_zero_dim2( m, b_n ) || bli_zeq0( *alpha ) ) return; // If b_n is not equal to the fusing factor, then perform the entire // operation as a loop over axpyv. - if ( b_n != fuse_fac ) + if (b_n != fuse_fac) { - zaxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYV_KER, cntx ); + __m128d x_vec, alpha_real, alpha_imag, temp[2]; - for ( i = 0; i < b_n; ++i ) + alpha_real = _mm_set1_pd(((*alpha).real)); + alpha_imag = _mm_set1_pd(((*alpha).imag)); + + for (dim_t i = 0; i < b_n; ++i) { - dcomplex* a1 = a + (0 )*inca + (i )*lda; - dcomplex* chi1 = x + (i )*incx; - dcomplex* y1 = y + (0 )*incy; - dcomplex alpha_chi1; + dcomplex *a1 = a + (0) * inca + (i)*lda; + dcomplex *chi1 = x + (i)*incx; + dcomplex *y1 = y + (0) * incy; + dcomplex alpha_chi1; - bli_zcopycjs( conjx, *chi1, alpha_chi1 ); - bli_zscals( *alpha, alpha_chi1 ); + // Vectorization of scaling X by alpha + x_vec = _mm_loadu_pd((double *)chi1); - f + if (bli_is_conj(conjx)) + { + __m128d identity; + + identity = _mm_setr_pd(1, -1); + + x_vec = _mm_mul_pd(x_vec, identity); + } + + temp[0] = _mm_mul_pd(x_vec, alpha_real); + temp[1] = _mm_mul_pd(x_vec, alpha_imag); + + temp[1] = _mm_permute_pd(temp[1], 0b01); + + temp[0] = _mm_addsub_pd(temp[0], temp[1]); + + _mm_storeu_pd((double *)&alpha_chi1, temp[0]); + + bli_zaxpyv_zen_int5 ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx ); } return; } + // A prefetch distance used inside the main loop + const dim_t distance = 32; - // At this point, we know that b_n is exactly equal to the fusing factor. 
- if(bli_is_noconj(conjx)) + dcomplex chi0 = *(x + 0 * incx); + dcomplex chi1 = *(x + 1 * incx); + dcomplex chi2 = *(x + 2 * incx); + dcomplex chi3 = *(x + 3 * incx); + + /* Alpha scaling of X can be vectorized + irrespective of the incx and should + be avoided when alpha is 1*/ + __m128d x_vec[8], alpha_real, alpha_imag, temp[8]; + + x_vec[0] = _mm_loadu_pd((double *)&chi0); + x_vec[1] = _mm_loadu_pd((double *)&chi1); + x_vec[2] = _mm_loadu_pd((double *)&chi2); + x_vec[3] = _mm_loadu_pd((double *)&chi3); + + if (bli_is_conj(conjx)) { - chi0 = *( x + 0*incx ); - chi1 = *( x + 1*incx ); - chi2 = *( x + 2*incx ); - chi3 = *( x + 3*incx ); + __m128d identity; + + identity = _mm_setr_pd(1, -1); + + x_vec[0] = _mm_mul_pd(x_vec[0], identity); + x_vec[1] = _mm_mul_pd(x_vec[1], identity); + x_vec[2] = _mm_mul_pd(x_vec[2], identity); + x_vec[3] = _mm_mul_pd(x_vec[3], identity); + } + + if (!(bli_zeq1(*alpha))) + { + alpha_real = _mm_set1_pd(((*alpha).real)); + alpha_imag = _mm_set1_pd(((*alpha).imag)); + + temp[0] = _mm_mul_pd(x_vec[0], alpha_real); + temp[1] = _mm_mul_pd(x_vec[0], alpha_imag); + temp[2] = _mm_mul_pd(x_vec[1], alpha_real); + temp[3] = _mm_mul_pd(x_vec[1], alpha_imag); + temp[4] = _mm_mul_pd(x_vec[2], alpha_real); + temp[5] = _mm_mul_pd(x_vec[2], alpha_imag); + temp[6] = _mm_mul_pd(x_vec[3], alpha_real); + temp[7] = _mm_mul_pd(x_vec[3], alpha_imag); + + temp[1] = _mm_permute_pd(temp[1], 0b01); + temp[3] = _mm_permute_pd(temp[3], 0b01); + temp[5] = _mm_permute_pd(temp[5], 0b01); + temp[7] = _mm_permute_pd(temp[7], 0b01); + + temp[0] = _mm_addsub_pd(temp[0], temp[1]); + temp[2] = _mm_addsub_pd(temp[2], temp[3]); + temp[4] = _mm_addsub_pd(temp[4], temp[5]); + temp[6] = _mm_addsub_pd(temp[6], temp[7]); + + _mm_storeu_pd((double *)&chi0, temp[0]); + _mm_storeu_pd((double *)&chi1, temp[2]); + _mm_storeu_pd((double *)&chi2, temp[4]); + _mm_storeu_pd((double *)&chi3, temp[6]); } else { - dcomplex *pchi0 = x + 0*incx ; - dcomplex *pchi1 = x + 1*incx ; - dcomplex *pchi2 = x + 2*incx ; - dcomplex *pchi3 = x + 3*incx ; - - bli_zcopycjs( conjx, *pchi0, chi0 ); - bli_zcopycjs( conjx, *pchi1, chi1 ); - bli_zcopycjs( conjx, *pchi2, chi2 ); - bli_zcopycjs( conjx, *pchi3, chi3 ); + _mm_storeu_pd((double *)&chi0, x_vec[0]); + _mm_storeu_pd((double *)&chi1, x_vec[1]); + _mm_storeu_pd((double *)&chi2, x_vec[2]); + _mm_storeu_pd((double *)&chi3, x_vec[3]); } - // Scale each chi scalar by alpha. - bli_zscals( *alpha, chi0 ); - bli_zscals( *alpha, chi1 ); - bli_zscals( *alpha, chi2 ); - bli_zscals( *alpha, chi3 ); + dim_t i = 0; - lda *= 2; - incx *= 2; - incy *= 2; - inca *= 2; + double *a_ptr[4]; + double *y0 = (double *)y; - ap[0] = (double*)a; - ap[1] = (double*)a + lda; - ap[2] = ap[1] + lda; - ap[3] = ap[2] + lda; + a_ptr[0] = (double *)a; + a_ptr[1] = (double *)a + 2 * lda; + a_ptr[2] = a_ptr[1] + 2 * lda; + a_ptr[3] = a_ptr[2] + 2 * lda; - if( inca == 2 && incy == 2 ) + + // Prefetching the elements of A to the L1 cache. 
+ // These will be used even if SSE instructions are used + _mm_prefetch(a_ptr[0], _MM_HINT_T1); + _mm_prefetch(a_ptr[1], _MM_HINT_T1); + _mm_prefetch(a_ptr[2], _MM_HINT_T1); + _mm_prefetch(a_ptr[3], _MM_HINT_T1); + + if (inca == 1 && incy == 1) { - inc_t n1 = m >> 1; // Divide by 2 - inc_t n2 = m & 1; // % 2 - ymm12.v = _mm256_setzero_pd(); - ymm13.v = _mm256_setzero_pd(); + v4df_t ymm0, ymm1, ymm2, ymm3; + v4df_t ymm4, ymm5, ymm6, ymm7; + v4df_t ymm8, ymm10; + v4df_t ymm12, ymm13, ymm14, ymm15; // broadcast real & imag parts of 4 elements of x ymm0.v = _mm256_broadcast_sd(&chi0.real); // real part of x0 @@ -415,114 +463,358 @@ void bli_zaxpyf_zen_int_4 ymm6.v = _mm256_broadcast_sd(&chi3.real); // real part of x3 ymm7.v = _mm256_broadcast_sd(&chi3.imag); // imag part of x3 - - for(i = 0; i < n1; i++) + if (bli_is_noconj(conja)) { - //load first two columns of A - ymm8.v = _mm256_loadu_pd(ap[0] + 0); // 2 complex values form a0 - ymm10.v = _mm256_loadu_pd(ap[1] + 0); // 2 complex values form a0 - ymm12.v = _mm256_mul_pd(ymm8.v, ymm0.v); - ymm13.v = _mm256_mul_pd(ymm8.v, ymm1.v); - - ymm12.v = _mm256_fmadd_pd(ymm10.v, ymm2.v, ymm12.v); - ymm13.v = _mm256_fmadd_pd(ymm10.v, ymm3.v, ymm13.v); - - //load 3rd and 4th columns of A - ymm8.v = _mm256_loadu_pd(ap[2] + 0); - ymm10.v = _mm256_loadu_pd(ap[3] + 0); - - ymm12.v = _mm256_fmadd_pd(ymm8.v, ymm4.v, ymm12.v); - ymm13.v = _mm256_fmadd_pd(ymm8.v, ymm5.v, ymm13.v); - - ymm12.v = _mm256_fmadd_pd(ymm10.v, ymm6.v, ymm12.v); - ymm13.v = _mm256_fmadd_pd(ymm10.v, ymm7.v, ymm13.v); - - //load Y vector - ymm10.v = _mm256_loadu_pd(y0 + 0); - - if(bli_is_noconj(conja)) + for (; (i + 3) < m; i += 4) { + // load first two columns of A + ymm8.v = _mm256_loadu_pd(a_ptr[0]); // 2 complex values from a0 + ymm10.v = _mm256_loadu_pd(a_ptr[1]); // 2 complex values from a0 + // load 3rd and 4th columns of A + ymm14.v = _mm256_loadu_pd(a_ptr[2]); + ymm15.v = _mm256_loadu_pd(a_ptr[3]); + + // Multiply the loaded columns of A by X + ymm12.v = _mm256_mul_pd(ymm8.v, ymm0.v); + ymm13.v = _mm256_mul_pd(ymm8.v, ymm1.v); + + ymm12.v = _mm256_fmadd_pd(ymm10.v, ymm2.v, ymm12.v); + ymm13.v = _mm256_fmadd_pd(ymm10.v, ymm3.v, ymm13.v); + + _mm_prefetch(a_ptr[0] + distance, _MM_HINT_T1); + _mm_prefetch(a_ptr[1] + distance, _MM_HINT_T1); + _mm_prefetch(a_ptr[2] + distance, _MM_HINT_T1); + _mm_prefetch(a_ptr[3] + distance, _MM_HINT_T1); + + ymm12.v = _mm256_fmadd_pd(ymm14.v, ymm4.v, ymm12.v); + ymm13.v = _mm256_fmadd_pd(ymm14.v, ymm5.v, ymm13.v); + + _mm_prefetch(y0 + distance, _MM_HINT_T1); + + ymm12.v = _mm256_fmadd_pd(ymm15.v, ymm6.v, ymm12.v); + ymm13.v = _mm256_fmadd_pd(ymm15.v, ymm7.v, ymm13.v); + + // load Y vector + ymm10.v = _mm256_loadu_pd(y0); + + // Permute and reduce the complex and real parts ymm13.v = _mm256_permute_pd(ymm13.v, 5); ymm8.v = _mm256_addsub_pd(ymm12.v, ymm13.v); + + ymm12.v = _mm256_add_pd(ymm8.v, ymm10.v); + + _mm256_storeu_pd((double *)(y0), ymm12.v); + + // load first two columns of A + ymm8.v = _mm256_loadu_pd(a_ptr[0] + 4); // 2 complex values from a0 + ymm10.v = _mm256_loadu_pd(a_ptr[1] + 4); // 2 complex values from a0 + // load 3rd and 4th columns of A + ymm14.v = _mm256_loadu_pd(a_ptr[2] + 4); + ymm15.v = _mm256_loadu_pd(a_ptr[3] + 4); + + ymm12.v = _mm256_mul_pd(ymm8.v, ymm0.v); + ymm13.v = _mm256_mul_pd(ymm8.v, ymm1.v); + + ymm12.v = _mm256_fmadd_pd(ymm10.v, ymm2.v, ymm12.v); + ymm13.v = _mm256_fmadd_pd(ymm10.v, ymm3.v, ymm13.v); + + _mm_prefetch(a_ptr[0] + distance * 2, _MM_HINT_T1); + _mm_prefetch(a_ptr[1] + distance * 2, _MM_HINT_T1); + 
_mm_prefetch(a_ptr[2] + distance * 2, _MM_HINT_T1); + _mm_prefetch(a_ptr[3] + distance * 2, _MM_HINT_T1); + + ymm12.v = _mm256_fmadd_pd(ymm14.v, ymm4.v, ymm12.v); + ymm13.v = _mm256_fmadd_pd(ymm14.v, ymm5.v, ymm13.v); + + _mm_prefetch(y0 + distance * 2, _MM_HINT_T1); + + ymm12.v = _mm256_fmadd_pd(ymm15.v, ymm6.v, ymm12.v); + ymm13.v = _mm256_fmadd_pd(ymm15.v, ymm7.v, ymm13.v); + + // load Y vector + ymm10.v = _mm256_loadu_pd(y0 + 4); + + ymm13.v = _mm256_permute_pd(ymm13.v, 5); + ymm8.v = _mm256_addsub_pd(ymm12.v, ymm13.v); + + ymm12.v = _mm256_add_pd(ymm8.v, ymm10.v); + + _mm256_storeu_pd((double *)(y0 + 4), ymm12.v); + + y0 += 8; + a_ptr[0] += 8; + a_ptr[1] += 8; + a_ptr[2] += 8; + a_ptr[3] += 8; } - else + + for (; (i + 1) < m; i += 2) { + // load first two columns of A + ymm8.v = _mm256_loadu_pd(a_ptr[0]); // 2 complex values from a0 + ymm10.v = _mm256_loadu_pd(a_ptr[1]); // 2 complex values from a0 + // load 3rd and 4th columns of A + ymm14.v = _mm256_loadu_pd(a_ptr[2]); + ymm15.v = _mm256_loadu_pd(a_ptr[3]); + + ymm12.v = _mm256_mul_pd(ymm8.v, ymm0.v); + ymm13.v = _mm256_mul_pd(ymm8.v, ymm1.v); + + ymm12.v = _mm256_fmadd_pd(ymm10.v, ymm2.v, ymm12.v); + ymm13.v = _mm256_fmadd_pd(ymm10.v, ymm3.v, ymm13.v); + + ymm12.v = _mm256_fmadd_pd(ymm14.v, ymm4.v, ymm12.v); + ymm13.v = _mm256_fmadd_pd(ymm14.v, ymm5.v, ymm13.v); + + ymm12.v = _mm256_fmadd_pd(ymm15.v, ymm6.v, ymm12.v); + ymm13.v = _mm256_fmadd_pd(ymm15.v, ymm7.v, ymm13.v); + + // load Y vector + ymm10.v = _mm256_loadu_pd(y0); + + ymm13.v = _mm256_permute_pd(ymm13.v, 5); + ymm8.v = _mm256_addsub_pd(ymm12.v, ymm13.v); + + ymm12.v = _mm256_add_pd(ymm8.v, ymm10.v); + + _mm256_storeu_pd((double *)(y0), ymm12.v); + + y0 += 4; + a_ptr[0] += 4; + a_ptr[1] += 4; + a_ptr[2] += 4; + a_ptr[3] += 4; + } + } + else + { + + for (; (i + 3) < m; i += 4) + { + // load first two columns of A + ymm8.v = _mm256_loadu_pd(a_ptr[0]); // 2 complex values from a0 + ymm10.v = _mm256_loadu_pd(a_ptr[1]); // 2 complex values from a0 + // load 3rd and 4th columns of A + ymm14.v = _mm256_loadu_pd(a_ptr[2]); + ymm15.v = _mm256_loadu_pd(a_ptr[3]); + + ymm12.v = _mm256_mul_pd(ymm8.v, ymm0.v); + ymm13.v = _mm256_mul_pd(ymm8.v, ymm1.v); + + ymm12.v = _mm256_fmadd_pd(ymm10.v, ymm2.v, ymm12.v); + ymm13.v = _mm256_fmadd_pd(ymm10.v, ymm3.v, ymm13.v); + + _mm_prefetch(a_ptr[0] + distance, _MM_HINT_T1); + _mm_prefetch(a_ptr[1] + distance, _MM_HINT_T1); + _mm_prefetch(a_ptr[2] + distance, _MM_HINT_T1); + _mm_prefetch(a_ptr[3] + distance, _MM_HINT_T1); + + ymm12.v = _mm256_fmadd_pd(ymm14.v, ymm4.v, ymm12.v); + ymm13.v = _mm256_fmadd_pd(ymm14.v, ymm5.v, ymm13.v); + + _mm_prefetch(y0 + distance, _MM_HINT_T1); + + ymm12.v = _mm256_fmadd_pd(ymm15.v, ymm6.v, ymm12.v); + ymm13.v = _mm256_fmadd_pd(ymm15.v, ymm7.v, ymm13.v); + + // load Y vector + ymm10.v = _mm256_loadu_pd(y0); + ymm12.v = _mm256_permute_pd(ymm12.v, 5); ymm8.v = _mm256_addsub_pd(ymm13.v, ymm12.v); ymm8.v = _mm256_permute_pd(ymm8.v, 5); + + ymm12.v = _mm256_add_pd(ymm8.v, ymm10.v); + + _mm256_storeu_pd((double *)(y0), ymm12.v); + + ymm8.v = _mm256_loadu_pd(a_ptr[0] + 4); // 2 complex values from a0 + ymm10.v = _mm256_loadu_pd(a_ptr[1] + 4); // 2 complex values from a0 + // load 3rd and 4th columns of A + ymm14.v = _mm256_loadu_pd(a_ptr[2] + 4); + ymm15.v = _mm256_loadu_pd(a_ptr[3] + 4); + + ymm12.v = _mm256_mul_pd(ymm8.v, ymm0.v); + ymm13.v = _mm256_mul_pd(ymm8.v, ymm1.v); + + ymm12.v = _mm256_fmadd_pd(ymm10.v, ymm2.v, ymm12.v); + ymm13.v = _mm256_fmadd_pd(ymm10.v, ymm3.v, ymm13.v); + + _mm_prefetch(a_ptr[0] + distance 
* 2, _MM_HINT_T1); + _mm_prefetch(a_ptr[1] + distance * 2, _MM_HINT_T1); + _mm_prefetch(a_ptr[2] + distance * 2, _MM_HINT_T1); + _mm_prefetch(a_ptr[3] + distance * 2, _MM_HINT_T1); + + ymm12.v = _mm256_fmadd_pd(ymm14.v, ymm4.v, ymm12.v); + ymm13.v = _mm256_fmadd_pd(ymm14.v, ymm5.v, ymm13.v); + + _mm_prefetch(y0 + distance * 2, _MM_HINT_T1); + + ymm12.v = _mm256_fmadd_pd(ymm15.v, ymm6.v, ymm12.v); + ymm13.v = _mm256_fmadd_pd(ymm15.v, ymm7.v, ymm13.v); + + // load Y vector + ymm10.v = _mm256_loadu_pd(y0 + 4); + + ymm12.v = _mm256_permute_pd(ymm12.v, 5); + ymm8.v = _mm256_addsub_pd(ymm13.v, ymm12.v); + ymm8.v = _mm256_permute_pd(ymm8.v, 5); + + ymm12.v = _mm256_add_pd(ymm8.v, ymm10.v); + + _mm256_storeu_pd((double *)(y0 + 4), ymm12.v); + + y0 += 8; + a_ptr[0] += 8; + a_ptr[1] += 8; + a_ptr[2] += 8; + a_ptr[3] += 8; } - ymm12.v = _mm256_add_pd(ymm8.v, ymm10.v); + for (; (i + 1) < m; i += 2) + { + // load first two columns of A + ymm8.v = _mm256_loadu_pd(a_ptr[0]); // 2 complex values from a0 + ymm10.v = _mm256_loadu_pd(a_ptr[1]); // 2 complex values from a0 + // load 3rd and 4th columns of A + ymm14.v = _mm256_loadu_pd(a_ptr[2]); + ymm15.v = _mm256_loadu_pd(a_ptr[3]); - _mm256_storeu_pd((double*)(y0), ymm12.v); + ymm12.v = _mm256_mul_pd(ymm8.v, ymm0.v); + ymm13.v = _mm256_mul_pd(ymm8.v, ymm1.v); - y0 += 4; - ap[0] += 4; - ap[1] += 4; - ap[2] += 4; - ap[3] += 4; + ymm12.v = _mm256_fmadd_pd(ymm10.v, ymm2.v, ymm12.v); + ymm13.v = _mm256_fmadd_pd(ymm10.v, ymm3.v, ymm13.v); + + //_mm_prefetch(y0, _MM_HINT_T1); + + ymm12.v = _mm256_fmadd_pd(ymm14.v, ymm4.v, ymm12.v); + ymm13.v = _mm256_fmadd_pd(ymm14.v, ymm5.v, ymm13.v); + + ymm12.v = _mm256_fmadd_pd(ymm15.v, ymm6.v, ymm12.v); + ymm13.v = _mm256_fmadd_pd(ymm15.v, ymm7.v, ymm13.v); + + // load Y vector + ymm10.v = _mm256_loadu_pd(y0); + + ymm12.v = _mm256_permute_pd(ymm12.v, 5); + ymm8.v = _mm256_addsub_pd(ymm13.v, ymm12.v); + ymm8.v = _mm256_permute_pd(ymm8.v, 5); + + ymm12.v = _mm256_add_pd(ymm8.v, ymm10.v); + + _mm256_storeu_pd((double *)(y0), ymm12.v); + + y0 += 4; + a_ptr[0] += 4; + a_ptr[1] += 4; + a_ptr[2] += 4; + a_ptr[3] += 4; + } } + } - // If there are leftover iterations, perform them with scalar code. + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // later, especially if BLIS is compiled with -mfpmath=sse). 
+ _mm256_zeroupper(); - for ( i = 0; (i + 0) < n2 ; ++i ) + __m128d a_vec[4], y_vec, inter[2]; + + // broadcast real & imag parts of 4 elements of x + x_vec[0] = _mm_set1_pd(chi0.real); // real part of x0 + x_vec[1] = _mm_set1_pd(chi0.imag); // imag part of x0 + x_vec[2] = _mm_set1_pd(chi1.real); // real part of x1 + x_vec[3] = _mm_set1_pd(chi1.imag); // imag part of x1 + x_vec[4] = _mm_set1_pd(chi2.real); // real part of x2 + x_vec[5] = _mm_set1_pd(chi2.imag); // imag part of x2 + x_vec[6] = _mm_set1_pd(chi3.real); // real part of x3 + x_vec[7] = _mm_set1_pd(chi3.imag); // imag part of x3 + + if (bli_is_noconj(conja)) + { + for (; i < m; i++) { - dcomplex y0c = *(dcomplex*)y0; + // load first two columns of A + a_vec[0] = _mm_loadu_pd(a_ptr[0]); // 2 complex values from a0 + a_vec[1] = _mm_loadu_pd(a_ptr[1]); // 2 complex values from a0 + a_vec[2] = _mm_loadu_pd(a_ptr[2]); // 2 complex values from a0 + a_vec[3] = _mm_loadu_pd(a_ptr[3]); // 2 complex values from a0 - const dcomplex a0c = *(dcomplex*)ap[0]; - const dcomplex a1c = *(dcomplex*)ap[1]; - const dcomplex a2c = *(dcomplex*)ap[2]; - const dcomplex a3c = *(dcomplex*)ap[3]; + inter[0] = _mm_mul_pd(a_vec[0], x_vec[0]); + inter[1] = _mm_mul_pd(a_vec[0], x_vec[1]); - y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; - y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; - y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; - y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; + inter[0] = _mm_fmadd_pd(a_vec[1], x_vec[2], inter[0]); + inter[1] = _mm_fmadd_pd(a_vec[1], x_vec[3], inter[1]); - y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; - y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; - y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; - y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; + //_mm_prefetch(y0, _MM_HINT_T1); - *(dcomplex*)y0 = y0c; + inter[0] = _mm_fmadd_pd(a_vec[2], x_vec[4], inter[0]); + inter[1] = _mm_fmadd_pd(a_vec[2], x_vec[5], inter[1]); - ap[0] += 2; - ap[1] += 2; - ap[2] += 2; - ap[3] += 2; - y0 += 2; + inter[0] = _mm_fmadd_pd(a_vec[3], x_vec[6], inter[0]); + inter[1] = _mm_fmadd_pd(a_vec[3], x_vec[7], inter[1]); + + inter[1] = _mm_permute_pd(inter[1], 0b01); + inter[0] = _mm_addsub_pd(inter[0], inter[1]); + + // load Y vector + y_vec = _mm_loadu_pd(y0); + + y_vec = _mm_add_pd(y_vec, inter[0]); + + _mm_storeu_pd((double *)(y0), y_vec); + + y0 += 2 * incy; + a_ptr[0] += 2 * inca; + a_ptr[1] += 2 * inca; + a_ptr[2] += 2 * inca; + a_ptr[3] += 2 * inca; } - //PASTEMAC(c,fprintm)(stdout, "Y after A*x in axpyf",m, 1, (scomplex*)y, 1, 1, "%4.1f", ""); - } else { - for (i = 0 ; (i + 0) < m ; ++i ) + for (; i < m; i++) { - dcomplex y0c = *(dcomplex*)y0; - const dcomplex a0c = *(dcomplex*)ap[0]; - const dcomplex a1c = *(dcomplex*)ap[1]; - const dcomplex a2c = *(dcomplex*)ap[2]; - const dcomplex a3c = *(dcomplex*)ap[3]; + // load first two columns of A + a_vec[0] = _mm_loadu_pd(a_ptr[0]); // 2 complex values from a0 + a_vec[1] = _mm_loadu_pd(a_ptr[1]); // 2 complex values from a0 + // load 3rd and 4th columns of A + a_vec[2] = _mm_loadu_pd(a_ptr[2]); // 2 complex values from a0 + a_vec[3] = _mm_loadu_pd(a_ptr[3]); // 2 complex values from a0 - y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; - y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; - y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; - y0c.real += 
chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; + inter[0] = _mm_mul_pd(a_vec[0], x_vec[0]); + inter[1] = _mm_mul_pd(a_vec[0], x_vec[1]); - y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; - y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; - y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; - y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; + inter[0] = _mm_fmadd_pd(a_vec[1], x_vec[2], inter[0]); + inter[1] = _mm_fmadd_pd(a_vec[1], x_vec[3], inter[1]); - *(dcomplex*)y0 = y0c; + // load Y vector + y_vec = _mm_loadu_pd(y0); - ap[0] += inca; - ap[1] += inca; - ap[2] += inca; - ap[3] += inca; - y0 += incy; + inter[0] = _mm_fmadd_pd(a_vec[2], x_vec[4], inter[0]); + inter[1] = _mm_fmadd_pd(a_vec[2], x_vec[5], inter[1]); + + inter[0] = _mm_fmadd_pd(a_vec[3], x_vec[6], inter[0]); + inter[1] = _mm_fmadd_pd(a_vec[3], x_vec[7], inter[1]); + + inter[0] = _mm_permute_pd(inter[0], 0b01); + inter[0] = _mm_addsub_pd(inter[1], inter[0]); + inter[0] = _mm_permute_pd(inter[0], 0b01); + + y_vec = _mm_add_pd(y_vec, inter[0]); + + _mm_storeu_pd((double *)(y0), y_vec); + + y0 += 2 * incy; + a_ptr[0] += 2 * inca; + a_ptr[1] += 2 * inca; + a_ptr[2] += 2 * inca; + a_ptr[3] += 2 * inca; } } + + // vzeroupper is added by the compiler at the end of the kernel }
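
-- 
Reviewer notes (not part of the patch):

The alpha scaling of X added above relies on the mul/permute/addsub idiom for
double-precision complex multiplication. A minimal standalone sketch of that
idiom follows; `dcplx` and `zscal_sketch` are illustrative names only (the
kernel itself operates on BLIS's `dcomplex`), and `_mm_permute_pd` requires
AVX, so build with -mavx (the SSE2 equivalent is `_mm_shuffle_pd`).

    #include <immintrin.h>
    #include <stdio.h>

    // Illustrative stand-in for a double-precision complex number.
    typedef struct { double real, imag; } dcplx;

    // chi := alpha * chi, using the same sequence as the patched kernel:
    // multiply by the broadcast real and imaginary parts of alpha, swap the
    // lanes of the imaginary product, then addsub to form [ real, imag ].
    static void zscal_sketch( const dcplx *alpha, dcplx *chi )
    {
        __m128d x_vec      = _mm_loadu_pd( (const double *)chi ); // [ xr, xi ]
        __m128d alpha_real = _mm_set1_pd( alpha->real );          // [ ar, ar ]
        __m128d alpha_imag = _mm_set1_pd( alpha->imag );          // [ ai, ai ]

        __m128d t0 = _mm_mul_pd( x_vec, alpha_real );             // [ xr*ar, xi*ar ]
        __m128d t1 = _mm_mul_pd( x_vec, alpha_imag );             // [ xr*ai, xi*ai ]

        t1 = _mm_permute_pd( t1, 0b01 );                          // [ xi*ai, xr*ai ]

        // addsub subtracts in lane 0 and adds in lane 1:
        // [ xr*ar - xi*ai, xi*ar + xr*ai ] = alpha * chi
        t0 = _mm_addsub_pd( t0, t1 );

        _mm_storeu_pd( (double *)chi, t0 );
    }

    int main( void )
    {
        dcplx alpha = { 2.0, 1.0 }, chi = { 3.0, -4.0 };
        zscal_sketch( &alpha, &chi );
        printf( "%g %+gi\n", chi.real, chi.imag );                // prints 10 -5i
        return 0;
    }

Because each chi element is gathered from X and scaled independently, the same
sequence works for any incx, which is what the commit message refers to.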
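
The conja path in the main loops reuses the same two products and recovers
conj(a) * x with two extra permutes and no extra multiplies. Lane by lane,
with a = ar + i*ai (a matrix element) and x = xr + i*xi (a scaled chi), the
sequence above expands to:

    t0 = a * xr = [ ar*xr, ai*xr ]
    t1 = a * xi = [ ar*xi, ai*xi ]

    no conjugate:  addsub( t0, permute(t1) )
                   = [ ar*xr - ai*xi, ai*xr + ar*xi ]            = a * x

    conjugate:     permute( addsub( t1, permute(t0) ) )
                   = permute( [ ar*xi - ai*xr, ai*xi + ar*xr ] )
                   = [ ar*xr + ai*xi, ar*xi - ai*xr ]            = conj(a) * x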
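
On the prefetching: `distance` is 32 doubles, i.e. 256 bytes or four 64-byte
cache lines ahead of the current A and Y positions, which corresponds to four
iterations of the 4x-unrolled main loop (each iteration consumes 8 doubles per
column, exactly one cache line). A minimal sketch of the pattern, using a
hypothetical element-wise double-precision kernel rather than BLIS code:

    #include <immintrin.h>

    // y[i] += a[i], prefetching DIST elements ahead of the current position.
    static void axpy_prefetch_sketch( int n, const double *a, double *y )
    {
        const int DIST = 32;  // elements ahead, matching the kernel's `distance`

        for ( int i = 0; i < n; ++i )
        {
            // One prefetch per 64-byte cache line (8 doubles). Prefetching
            // past the end of the arrays is architecturally harmless.
            if ( ( i & 7 ) == 0 )
            {
                _mm_prefetch( (const char *)( a + i + DIST ), _MM_HINT_T1 );
                _mm_prefetch( (const char *)( y + i + DIST ), _MM_HINT_T1 );
            }
            y[i] += a[i];
        }
    }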