mirror of
https://github.com/amd/blis.git
synced 2026-04-19 23:28:52 +00:00
Improve numerical precision in ZGEMV API (#130)
- Replaced separate real and imaginary accumulators (real_acc, imag_acc) with a column-wise accumulator array (row_acc[2]), making accumulation and updates to the target Y vector more direct, concise, and unified. - Leveraged AVX-512 fused multiply-add/subtract operations (_mm512_fmaddsub_pd, _mm512_fmsubadd_pd) and efficient permutations (_mm512_permute_pd) to enable accurate and efficient computation of real and imaginary components in a single instruction, while reducing code complexity for both code paths. - Removed redundant instructions (such as unnecessary permutations and zero-register operations) and simplified the control flow. AMD-Internal: [CPUPL-7015]
This commit is contained in:
@@ -443,7 +443,24 @@ void bli_daxpyf_zen_int32_avx512_mt
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
/**
|
||||
* bli_zaxpyf_zen_int_2_avx512
|
||||
*
|
||||
* Optimized AVX-512 kernel for the complex double-precision AXPYF operation
|
||||
* with a fusing factor of 2. Computes:
|
||||
* y := y + alpha * A * x
|
||||
* where:
|
||||
* - y is an m-dimensional complex vector,
|
||||
* - x is a 2-dimensional complex vector,
|
||||
* - A is an m x 2 complex matrix,
|
||||
* - alpha is a complex scalar.
|
||||
*
|
||||
* This kernel handles both conjugated and non-conjugated cases for A and x,
|
||||
* and uses AVX-512 vectorization for high performance when strides are unit.
|
||||
* For non-unit strides, it falls back to 128-bit vectorized code.
|
||||
* Handles edge cases for alpha == 0, alpha == 1, and alpha == -1.
|
||||
* For b_n != 2, dispatches to smaller kernels or zaxpyv as needed.
|
||||
*/
|
||||
void bli_zaxpyf_zen_int_2_avx512
|
||||
(
|
||||
conj_t conja,
|
||||
@@ -594,12 +611,23 @@ void bli_zaxpyf_zen_int_2_avx512
|
||||
|
||||
// Registers to load A, accumulate real and imag scaling separately
|
||||
__m512d a_vec[2];
|
||||
__m512d real_acc, imag_acc, y_vec;
|
||||
__m512d zero_reg = _mm512_setzero_pd();
|
||||
__m512d y_vec;
|
||||
__m512d row_acc[2];
|
||||
|
||||
// Execute the loops if m >= 4 (AVX-512 unmasked code-section)
|
||||
if( m >= 4 )
|
||||
{
|
||||
{
|
||||
/*
|
||||
For each column of A:
|
||||
Multiply A by the imaginary part of alpha*X.
|
||||
Permute result to swap real/imag (0x55 mask).
|
||||
Use fused multiply-add/subtract to combine real part.
|
||||
Accumulate results from both columns.
|
||||
Add result to Y.
|
||||
|
||||
This leverages AVX-512 fused operations to efficiently compute:
|
||||
Y += A * (alpha * X)
|
||||
*/
|
||||
if ( bli_is_noconj(conja) )
|
||||
{
|
||||
for (; (i + 7) < m; i += 8)
|
||||
@@ -607,22 +635,31 @@ void bli_zaxpyf_zen_int_2_avx512
|
||||
// Load first 4 elements from both columns of A
|
||||
a_vec[0] = _mm512_loadu_pd(a_ptr[0]);
|
||||
a_vec[1] = _mm512_loadu_pd(a_ptr[1]);
|
||||
// Multiply the loaded columns of A by alpha*X(real and imag)
|
||||
real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
|
||||
imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
|
||||
|
||||
real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc);
|
||||
imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc);
|
||||
// Multiply the loaded columns of A by alpha*X(real and imag)
|
||||
// For each column, first multiply by the imaginary part, permute, then fused multiply-add/sub with the real part.
|
||||
row_acc[0] = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
|
||||
row_acc[1] = _mm512_mul_pd(a_vec[1], alpha_x_imag[1]);
|
||||
|
||||
// Swaps real and imaginary components in each complex pair.
|
||||
// Used for correct handling of complex arithmetic.
|
||||
row_acc[0] = _mm512_permute_pd(row_acc[0], 0x55);
|
||||
row_acc[1] = _mm512_permute_pd(row_acc[1], 0x55);
|
||||
|
||||
// Perform fused multiply-add/subtract for complex arithmetic:
|
||||
// In non-conjugate: real + i*imag
|
||||
// In conjugate: real - i*imag
|
||||
// These allow both real and imaginary parts to be computed in one instruction.
|
||||
row_acc[0] = _mm512_fmaddsub_pd(a_vec[0], alpha_x_real[0], row_acc[0]);
|
||||
row_acc[1] = _mm512_fmaddsub_pd(a_vec[1], alpha_x_real[1], row_acc[1]);
|
||||
|
||||
// Load first 4 elements of Y vector
|
||||
y_vec = _mm512_loadu_pd(y0);
|
||||
|
||||
// Permute and reduce the complex and real parts
|
||||
imag_acc = _mm512_permute_pd(imag_acc, 0x55);
|
||||
imag_acc = _mm512_fmaddsub_pd(zero_reg, zero_reg, imag_acc);
|
||||
real_acc = _mm512_add_pd(real_acc, imag_acc);
|
||||
// Accumulate the results from both columns into row_acc[0]
|
||||
row_acc[0] = _mm512_add_pd(row_acc[0], row_acc[1]);
|
||||
|
||||
y_vec = _mm512_add_pd(y_vec, real_acc);
|
||||
y_vec = _mm512_add_pd(y_vec, row_acc[0]);
|
||||
|
||||
// Store onto Y vector
|
||||
_mm512_storeu_pd(y0, y_vec);
|
||||
@@ -632,21 +669,30 @@ void bli_zaxpyf_zen_int_2_avx512
|
||||
a_vec[1] = _mm512_loadu_pd(a_ptr[1] + 8);
|
||||
|
||||
// Multiply the loaded columns of A by alpha*X(real and imag)
|
||||
real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
|
||||
imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
|
||||
// For each column, first multiply by the imaginary part, permute, then fused multiply-add/sub with the real part.
|
||||
row_acc[0] = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
|
||||
row_acc[1] = _mm512_mul_pd(a_vec[1], alpha_x_imag[1]);
|
||||
|
||||
// Swaps real and imaginary components in each complex pair.
|
||||
// Used for correct handling of complex arithmetic.
|
||||
row_acc[0] = _mm512_permute_pd(row_acc[0], 0x55);
|
||||
row_acc[1] = _mm512_permute_pd(row_acc[1], 0x55);
|
||||
|
||||
// Perform fused multiply-add/subtract for complex arithmetic:
|
||||
// In non-conjugate: real + i*imag
|
||||
// In conjugate: real - i*imag
|
||||
// These allow both real and imaginary parts to be computed in one instruction.
|
||||
row_acc[0] = _mm512_fmaddsub_pd(a_vec[0], alpha_x_real[0], row_acc[0]);
|
||||
row_acc[1] = _mm512_fmaddsub_pd(a_vec[1], alpha_x_real[1], row_acc[1]);
|
||||
|
||||
real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc);
|
||||
imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc);
|
||||
|
||||
// Load next 4 elements of Y vector
|
||||
y_vec = _mm512_loadu_pd(y0 + 8);
|
||||
|
||||
// Permute and reduce the complex and real parts
|
||||
imag_acc = _mm512_permute_pd(imag_acc, 0x55);
|
||||
imag_acc = _mm512_fmaddsub_pd(zero_reg, zero_reg, imag_acc);
|
||||
real_acc = _mm512_add_pd(real_acc, imag_acc);
|
||||
// Accumulate the results from both columns into row_acc[0]
|
||||
row_acc[0] = _mm512_add_pd(row_acc[0], row_acc[1]);
|
||||
|
||||
y_vec = _mm512_add_pd(y_vec, real_acc);
|
||||
y_vec = _mm512_add_pd(y_vec, row_acc[0]);
|
||||
|
||||
// Store onto Y vector
|
||||
_mm512_storeu_pd(y0 + 8, y_vec);
|
||||
@@ -663,21 +709,30 @@ void bli_zaxpyf_zen_int_2_avx512
|
||||
a_vec[1] = _mm512_loadu_pd(a_ptr[1]);
|
||||
|
||||
// Multiply the loaded columns of A by alpha*X(real and imag)
|
||||
real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
|
||||
imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
|
||||
// For each column, first multiply by the imaginary part, permute, then fused multiply-add/sub with the real part.
|
||||
row_acc[0] = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
|
||||
row_acc[1] = _mm512_mul_pd(a_vec[1], alpha_x_imag[1]);
|
||||
|
||||
// Swaps real and imaginary components in each complex pair.
|
||||
// Used for correct handling of complex arithmetic.
|
||||
row_acc[0] = _mm512_permute_pd(row_acc[0], 0x55);
|
||||
row_acc[1] = _mm512_permute_pd(row_acc[1], 0x55);
|
||||
|
||||
// Perform fused multiply-add/subtract for complex arithmetic:
|
||||
// In non-conjugate: real + i*imag
|
||||
// In conjugate: real - i*imag
|
||||
// These allow both real and imaginary parts to be computed in one instruction.
|
||||
row_acc[0] = _mm512_fmaddsub_pd(a_vec[0], alpha_x_real[0], row_acc[0]);
|
||||
row_acc[1] = _mm512_fmaddsub_pd(a_vec[1], alpha_x_real[1], row_acc[1]);
|
||||
|
||||
real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc);
|
||||
imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc);
|
||||
|
||||
// Load first 4 elements of Y vector
|
||||
y_vec = _mm512_loadu_pd(y0);
|
||||
|
||||
// Permute and reduce the complex and real parts
|
||||
imag_acc = _mm512_permute_pd(imag_acc, 0x55);
|
||||
imag_acc = _mm512_fmaddsub_pd(zero_reg, zero_reg, imag_acc);
|
||||
real_acc = _mm512_add_pd(real_acc, imag_acc);
|
||||
// Accumulate the results from both columns into row_acc[0]
|
||||
row_acc[0] = _mm512_add_pd(row_acc[0], row_acc[1]);
|
||||
|
||||
y_vec = _mm512_add_pd(y_vec, real_acc);
|
||||
y_vec = _mm512_add_pd(y_vec, row_acc[0]);
|
||||
|
||||
// Store onto Y vector
|
||||
_mm512_storeu_pd(y0, y_vec);
|
||||
@@ -689,28 +744,48 @@ void bli_zaxpyf_zen_int_2_avx512
|
||||
}
|
||||
else
|
||||
{
|
||||
/*
|
||||
For the conjugate case:
|
||||
Permute each column of A to swap real/imag (for conjugation).
|
||||
Multiply by real part of alpha*X.
|
||||
Use fused multiply-sub/add with permuted imaginary part.
|
||||
Accumulate results from both columns.
|
||||
Add result to Y.
|
||||
|
||||
This implements conjugated complex multiplication and accumulation using AVX-512.
|
||||
*/
|
||||
__m512d a_vec_shuf[2];
|
||||
|
||||
for (; (i + 7) < m; i += 8)
|
||||
{
|
||||
// Load first 4 elements from both columns of A
|
||||
a_vec[0] = _mm512_loadu_pd(a_ptr[0]);
|
||||
a_vec[1] = _mm512_loadu_pd(a_ptr[1]);
|
||||
|
||||
// Multiply the loaded columns of A by alpha*X(real and imag)
|
||||
real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
|
||||
imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
|
||||
// Permute the loaded columns of A to swap real and imaginary parts.
|
||||
// This is needed for correct complex conjugate computation.
|
||||
a_vec_shuf[0] = _mm512_permute_pd(a_vec[0], 0x55);
|
||||
a_vec_shuf[1] = _mm512_permute_pd(a_vec[1], 0x55);
|
||||
|
||||
real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc);
|
||||
imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc);
|
||||
// Multiply the loaded columns of A by alpha*X(real and imag)
|
||||
// For each column, first multiply by the real part, then fused multiply-sub/add with the imaginary part.
|
||||
row_acc[0] = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
|
||||
row_acc[1] = _mm512_mul_pd(a_vec[1], alpha_x_real[1]);
|
||||
|
||||
// Perform fused multiply-add/subtract for complex arithmetic:
|
||||
// In non-conjugate: real + i*imag
|
||||
// In conjugate: real - i*imag
|
||||
// These allow both real and imaginary parts to be computed in one instruction.
|
||||
row_acc[0] = _mm512_fmsubadd_pd(a_vec_shuf[0], alpha_x_imag[0], row_acc[0]);
|
||||
row_acc[1] = _mm512_fmsubadd_pd(a_vec_shuf[1], alpha_x_imag[1], row_acc[1]);
|
||||
|
||||
// Load first 4 elements of Y vector
|
||||
y_vec = _mm512_loadu_pd(y0);
|
||||
|
||||
// Permute and reduce the complex and real parts
|
||||
imag_acc = _mm512_permute_pd(imag_acc, 0x55);
|
||||
real_acc = _mm512_fmsubadd_pd(zero_reg, zero_reg, real_acc);
|
||||
real_acc = _mm512_add_pd(real_acc, imag_acc);
|
||||
// Accumulate the results from both columns into row_acc[0]
|
||||
row_acc[0] = _mm512_add_pd(row_acc[0], row_acc[1]);
|
||||
|
||||
y_vec = _mm512_add_pd(y_vec, real_acc);
|
||||
y_vec = _mm512_add_pd(y_vec, row_acc[0]);
|
||||
|
||||
// Store onto Y vector
|
||||
_mm512_storeu_pd(y0, y_vec);
|
||||
@@ -719,22 +794,30 @@ void bli_zaxpyf_zen_int_2_avx512
|
||||
a_vec[0] = _mm512_loadu_pd(a_ptr[0] + 8);
|
||||
a_vec[1] = _mm512_loadu_pd(a_ptr[1] + 8);
|
||||
|
||||
// Multiply the loaded columns of A by alpha*X(real and imag)
|
||||
real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
|
||||
imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
|
||||
// Permute the loaded columns of A to swap real and imaginary parts.
|
||||
// This is needed for correct complex conjugate computation.
|
||||
a_vec_shuf[0] = _mm512_permute_pd(a_vec[0], 0x55);
|
||||
a_vec_shuf[1] = _mm512_permute_pd(a_vec[1], 0x55);
|
||||
|
||||
real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc);
|
||||
imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc);
|
||||
// Multiply the loaded columns of A by alpha*X(real and imag)
|
||||
// For each column, first multiply by the real part, then fused multiply-sub/add with the imaginary part.
|
||||
row_acc[0] = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
|
||||
row_acc[1] = _mm512_mul_pd(a_vec[1], alpha_x_real[1]);
|
||||
|
||||
// Perform fused multiply-add/subtract for complex arithmetic:
|
||||
// In non-conjugate: real + i*imag
|
||||
// In conjugate: real - i*imag
|
||||
// These allow both real and imaginary parts to be computed in one instruction.
|
||||
row_acc[0] = _mm512_fmsubadd_pd(a_vec_shuf[0], alpha_x_imag[0], row_acc[0]);
|
||||
row_acc[1] = _mm512_fmsubadd_pd(a_vec_shuf[1], alpha_x_imag[1], row_acc[1]);
|
||||
|
||||
// Load next 4 elements of Y vector
|
||||
y_vec = _mm512_loadu_pd(y0 + 8);
|
||||
|
||||
// Permute and reduce the complex and real parts
|
||||
imag_acc = _mm512_permute_pd(imag_acc, 0x55);
|
||||
real_acc = _mm512_fmsubadd_pd(zero_reg, zero_reg, real_acc);
|
||||
real_acc = _mm512_add_pd(real_acc, imag_acc);
|
||||
// Accumulate the results from both columns into row_acc[0]
|
||||
row_acc[0] = _mm512_add_pd(row_acc[0], row_acc[1]);
|
||||
|
||||
y_vec = _mm512_add_pd(y_vec, real_acc);
|
||||
y_vec = _mm512_add_pd(y_vec, row_acc[0]);
|
||||
|
||||
// Store onto Y vector
|
||||
_mm512_storeu_pd(y0 + 8, y_vec);
|
||||
@@ -750,22 +833,30 @@ void bli_zaxpyf_zen_int_2_avx512
|
||||
a_vec[0] = _mm512_loadu_pd(a_ptr[0]);
|
||||
a_vec[1] = _mm512_loadu_pd(a_ptr[1]);
|
||||
|
||||
// Multiply the loaded columns of A by alpha*X(real and imag)
|
||||
real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
|
||||
imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
|
||||
// Permute the loaded columns of A to swap real and imaginary parts.
|
||||
// This is needed for correct complex conjugate computation.
|
||||
a_vec_shuf[0] = _mm512_permute_pd(a_vec[0], 0x55);
|
||||
a_vec_shuf[1] = _mm512_permute_pd(a_vec[1], 0x55);
|
||||
|
||||
real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc);
|
||||
imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc);
|
||||
// Multiply the loaded columns of A by alpha*X(real and imag)
|
||||
// For each column, first multiply by the real part, then fused multiply-sub/add with the imaginary part.
|
||||
row_acc[0] = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
|
||||
row_acc[1] = _mm512_mul_pd(a_vec[1], alpha_x_real[1]);
|
||||
|
||||
// Perform fused multiply-add/subtract for complex arithmetic:
|
||||
// In non-conjugate: real + i*imag
|
||||
// In conjugate: real - i*imag
|
||||
// These allow both real and imaginary parts to be computed in one instruction.
|
||||
row_acc[0] = _mm512_fmsubadd_pd(a_vec_shuf[0], alpha_x_imag[0], row_acc[0]);
|
||||
row_acc[1] = _mm512_fmsubadd_pd(a_vec_shuf[1], alpha_x_imag[1], row_acc[1]);
|
||||
|
||||
// Load first 4 elements of Y vector
|
||||
y_vec = _mm512_loadu_pd(y0);
|
||||
|
||||
// Permute and reduce the complex and real parts
|
||||
imag_acc = _mm512_permute_pd(imag_acc, 0x55);
|
||||
real_acc = _mm512_fmsubadd_pd(zero_reg, zero_reg, real_acc);
|
||||
real_acc = _mm512_add_pd(real_acc, imag_acc);
|
||||
// Accumulate the results from both columns into row_acc[0]
|
||||
row_acc[0] = _mm512_add_pd(row_acc[0], row_acc[1]);
|
||||
|
||||
y_vec = _mm512_add_pd(y_vec, real_acc);
|
||||
y_vec = _mm512_add_pd(y_vec, row_acc[0]);
|
||||
|
||||
// Store onto Y vector
|
||||
_mm512_storeu_pd(y0, y_vec);
|
||||
@@ -776,8 +867,14 @@ void bli_zaxpyf_zen_int_2_avx512
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if( i < m )
|
||||
{
|
||||
/*
|
||||
For the tail case (the m - i elements left over after the unmasked loops):
|
||||
Use AVX-512 masked loads/stores to process the remaining elements safely.
|
||||
m_mask ensures only valid elements are read/written.
|
||||
*/
|
||||
__mmask8 m_mask = (1 << 2*(m - i)) - 1;
|
||||
if( bli_is_noconj(conja) )
|
||||
{
|
||||
@@ -786,47 +883,65 @@ void bli_zaxpyf_zen_int_2_avx512
|
||||
a_vec[1] = _mm512_maskz_loadu_pd(m_mask, a_ptr[1]);
|
||||
|
||||
// Multiply the loaded columns of A by alpha*X(real and imag)
|
||||
real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
|
||||
imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
|
||||
// For each column, first multiply by the imaginary part, permute, then fused multiply-add/sub with the real part.
|
||||
row_acc[0] = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
|
||||
row_acc[1] = _mm512_mul_pd(a_vec[1], alpha_x_imag[1]);
|
||||
|
||||
real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc);
|
||||
imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc);
|
||||
// Swaps real and imaginary components in each complex pair.
|
||||
// Used for correct handling of complex arithmetic.
|
||||
row_acc[0] = _mm512_permute_pd(row_acc[0], 0x55);
|
||||
row_acc[1] = _mm512_permute_pd(row_acc[1], 0x55);
|
||||
|
||||
// Perform fused multiply-add/subtract for complex arithmetic:
|
||||
// In non-conjugate: real + i*imag
|
||||
// In conjugate: real - i*imag
|
||||
// These allow both real and imaginary parts to be computed in one instruction.
|
||||
row_acc[0] = _mm512_fmaddsub_pd(a_vec[0], alpha_x_real[0], row_acc[0]);
|
||||
row_acc[1] = _mm512_fmaddsub_pd(a_vec[1], alpha_x_real[1], row_acc[1]);
|
||||
|
||||
// Load remaining elements of Y vector
|
||||
y_vec = _mm512_maskz_loadu_pd(m_mask, y0);
|
||||
|
||||
// Permute and reduce the complex and real parts
|
||||
imag_acc = _mm512_permute_pd(imag_acc, 0x55);
|
||||
imag_acc = _mm512_fmaddsub_pd(zero_reg, zero_reg, imag_acc);
|
||||
real_acc = _mm512_add_pd(real_acc, imag_acc);
|
||||
// Accumulate the results from both columns into row_acc[0]
|
||||
row_acc[0] = _mm512_add_pd(row_acc[0], row_acc[1]);
|
||||
|
||||
y_vec = _mm512_add_pd(y_vec, real_acc);
|
||||
y_vec = _mm512_add_pd(y_vec, row_acc[0]);
|
||||
|
||||
// Store onto Y vector
|
||||
_mm512_mask_storeu_pd(y0, m_mask, y_vec);
|
||||
}
|
||||
else
|
||||
{
|
||||
__m512d a_vec_shuf[2];
|
||||
|
||||
// Load remaining elements from both columns of A
|
||||
a_vec[0] = _mm512_maskz_loadu_pd(m_mask, a_ptr[0]);
|
||||
a_vec[1] = _mm512_maskz_loadu_pd(m_mask, a_ptr[1]);
|
||||
|
||||
// Multiply the loaded columns of A by alpha*X(real and imag)
|
||||
real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
|
||||
imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
|
||||
// Permute the loaded columns of A to swap real and imaginary parts.
|
||||
// This is needed for correct complex conjugate computation.
|
||||
a_vec_shuf[0] = _mm512_permute_pd(a_vec[0], 0x55);
|
||||
a_vec_shuf[1] = _mm512_permute_pd(a_vec[1], 0x55);
|
||||
|
||||
real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc);
|
||||
imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc);
|
||||
// Multiply the loaded columns of A by alpha*X(real and imag)
|
||||
// For each column, first multiply by the real part, then fused multiply-sub/add with the imaginary part.
|
||||
row_acc[0] = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
|
||||
row_acc[1] = _mm512_mul_pd(a_vec[1], alpha_x_real[1]);
|
||||
|
||||
// Perform fused multiply-add/subtract for complex arithmetic:
|
||||
// In non-conjugate: real + i*imag
|
||||
// In conjugate: real - i*imag
|
||||
// These allow both real and imaginary parts to be computed in one instruction.
|
||||
row_acc[0] = _mm512_fmsubadd_pd(a_vec_shuf[0], alpha_x_imag[0], row_acc[0]);
|
||||
row_acc[1] = _mm512_fmsubadd_pd(a_vec_shuf[1], alpha_x_imag[1], row_acc[1]);
|
||||
|
||||
// Load remaining elements of Y vector
|
||||
y_vec = _mm512_maskz_loadu_pd(m_mask, y0);
|
||||
|
||||
// Permute and reduce the complex and real parts
|
||||
imag_acc = _mm512_permute_pd(imag_acc, 0x55);
|
||||
real_acc = _mm512_fmsubadd_pd(zero_reg, zero_reg, real_acc);
|
||||
real_acc = _mm512_add_pd(real_acc, imag_acc);
|
||||
// Accumulate the results from both columns into row_acc[0]
|
||||
row_acc[0] = _mm512_add_pd(row_acc[0], row_acc[1]);
|
||||
|
||||
y_vec = _mm512_add_pd(y_vec, real_acc);
|
||||
y_vec = _mm512_add_pd(y_vec, row_acc[0]);
|
||||
|
||||
// Store onto Y vector
|
||||
_mm512_mask_storeu_pd(y0, m_mask, y_vec);
|
||||
|
||||
Reference in New Issue
Block a user