Improve numerical precision in ZGEMV API (#130)

- Replaced separate real and imaginary accumulators (real_acc, imag_acc) with a column-wise accumulator array (row_acc[2]), making accumulation and updates to the target Y vector more direct, concise, and unified.

- Leveraged AVX-512 fused multiply-add/subtract operations (_mm512_fmaddsub_pd, _mm512_fmsubadd_pd) and efficient permutations (_mm512_permute_pd) to enable accurate and efficient computation of real and imaginary components in a single instruction, while reducing code complexity for both code paths.

- Removed redundant instructions (such as unnecessary permutations and zero-register operations) and simplified the control flow.

AMD-Internal: [CPUPL-7015]
This commit is contained in:
S, Hari Govind
2025-08-14 11:19:51 +05:30
committed by GitHub
parent fa69528a3b
commit 9a7bacb30c

View File

@@ -443,7 +443,24 @@ void bli_daxpyf_zen_int32_avx512_mt
}
#endif
/**
* bli_zaxpyf_zen_int_2_avx512
*
* Optimized AVX-512 kernel for the complex double-precision AXPYF operation
* with a fusing factor of 2. Computes:
* y := y + alpha * A * x
* where:
* - y is an m-dimensional complex vector,
* - x is a 2-dimensional complex vector,
* - A is an m x 2 complex matrix,
* - alpha is a complex scalar.
*
* This kernel handles both conjugated and non-conjugated cases for A and x,
* and uses AVX-512 vectorization for high performance when strides are unit.
* For non-unit strides, it falls back to 128-bit vectorized code.
* Handles edge cases for alpha == 0, alpha == 1, and alpha == -1.
* For b_n != 2, dispatches to smaller kernels or zaxpyv as needed.
*/
void bli_zaxpyf_zen_int_2_avx512
(
conj_t conja,
@@ -594,12 +611,23 @@ void bli_zaxpyf_zen_int_2_avx512
// Registers to load A, accumulate real and imag scaling separately
__m512d a_vec[2];
__m512d real_acc, imag_acc, y_vec;
__m512d zero_reg = _mm512_setzero_pd();
__m512d y_vec;
__m512d row_acc[2];
// Execute the loops is m >= 4(AVX-512 unmasked code-section)
if( m >= 4 )
{
{
/*
For each column of A:
Multiply A by the imaginary part of alpha*X.
Permute result to swap real/imag (0x55 mask).
Use fused multiply-add/subtract to combine real part.
Accumulate results from both columns.
Add result to Y.
This leverages AVX-512 fused operations to efficiently compute:
Y += A * (alpha * X)
*/
if ( bli_is_noconj(conja) )
{
for (; (i + 7) < m; i += 8)
@@ -607,22 +635,31 @@ void bli_zaxpyf_zen_int_2_avx512
// Load first 4 elements from first 4 columns of A
a_vec[0] = _mm512_loadu_pd(a_ptr[0]);
a_vec[1] = _mm512_loadu_pd(a_ptr[1]);
// Multiply the loaded columns of A by alpha*X(real and imag)
real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc);
imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc);
// Multiply the loaded columns of A by alpha*X(real and imag)
// For each column, first multiply by the imaginary part, permute, then fused multiply-add/sub with the real part.
row_acc[0] = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
row_acc[1] = _mm512_mul_pd(a_vec[1], alpha_x_imag[1]);
// Swaps real and imaginary components in each complex pair.
// Used for correct handling of complex arithmetic.
row_acc[0] = _mm512_permute_pd(row_acc[0], 0x55);
row_acc[1] = _mm512_permute_pd(row_acc[1], 0x55);
// Perform fused multiply-add/subtract for complex arithmetic:
// In non-conjugate: real + i*imag
// In conjugate: real - i*imag
// These allow both real and imaginary parts to be computed in one instruction.
row_acc[0] = _mm512_fmaddsub_pd(a_vec[0], alpha_x_real[0], row_acc[0]);
row_acc[1] = _mm512_fmaddsub_pd(a_vec[1], alpha_x_real[1], row_acc[1]);
// Load first 4 elements of Y vector
y_vec = _mm512_loadu_pd(y0);
// Permute and reduce the complex and real parts
imag_acc = _mm512_permute_pd(imag_acc, 0x55);
imag_acc = _mm512_fmaddsub_pd(zero_reg, zero_reg, imag_acc);
real_acc = _mm512_add_pd(real_acc, imag_acc);
// Accumulate the results from both columns into row_acc[0]
row_acc[0] = _mm512_add_pd(row_acc[0], row_acc[1]);
y_vec = _mm512_add_pd(y_vec, real_acc);
y_vec = _mm512_add_pd(y_vec, row_acc[0]);
// Store onto Y vector
_mm512_storeu_pd(y0, y_vec);
@@ -632,21 +669,30 @@ void bli_zaxpyf_zen_int_2_avx512
a_vec[1] = _mm512_loadu_pd(a_ptr[1] + 8);
// Multiply the loaded columns of A by alpha*X(real and imag)
real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
// For each column, first multiply by the imaginary part, permute, then fused multiply-add/sub with the real part.
row_acc[0] = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
row_acc[1] = _mm512_mul_pd(a_vec[1], alpha_x_imag[1]);
// Swaps real and imaginary components in each complex pair.
// Used for correct handling of complex arithmetic.
row_acc[0] = _mm512_permute_pd(row_acc[0], 0x55);
row_acc[1] = _mm512_permute_pd(row_acc[1], 0x55);
// Perform fused multiply-add/subtract for complex arithmetic:
// In non-conjugate: real + i*imag
// In conjugate: real - i*imag
// These allow both real and imaginary parts to be computed in one instruction.
row_acc[0] = _mm512_fmaddsub_pd(a_vec[0], alpha_x_real[0], row_acc[0]);
row_acc[1] = _mm512_fmaddsub_pd(a_vec[1], alpha_x_real[1], row_acc[1]);
real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc);
imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc);
// Load next 4 elements of Y vector
y_vec = _mm512_loadu_pd(y0 + 8);
// Permute and reduce the complex and real parts
imag_acc = _mm512_permute_pd(imag_acc, 0x55);
imag_acc = _mm512_fmaddsub_pd(zero_reg, zero_reg, imag_acc);
real_acc = _mm512_add_pd(real_acc, imag_acc);
// Accumulate the results from both columns into row_acc[0]
row_acc[0] = _mm512_add_pd(row_acc[0], row_acc[1]);
y_vec = _mm512_add_pd(y_vec, real_acc);
y_vec = _mm512_add_pd(y_vec, row_acc[0]);
// Store onto Y vector
_mm512_storeu_pd(y0 + 8, y_vec);
@@ -663,21 +709,30 @@ void bli_zaxpyf_zen_int_2_avx512
a_vec[1] = _mm512_loadu_pd(a_ptr[1]);
// Multiply the loaded columns of A by alpha*X(real and imag)
real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
// For each column, first multiply by the imaginary part, permute, then fused multiply-add/sub with the real part.
row_acc[0] = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
row_acc[1] = _mm512_mul_pd(a_vec[1], alpha_x_imag[1]);
// Swaps real and imaginary components in each complex pair.
// Used for correct handling of complex arithmetic.
row_acc[0] = _mm512_permute_pd(row_acc[0], 0x55);
row_acc[1] = _mm512_permute_pd(row_acc[1], 0x55);
// Perform fused multiply-add/subtract for complex arithmetic:
// In non-conjugate: real + i*imag
// In conjugate: real - i*imag
// These allow both real and imaginary parts to be computed in one instruction.
row_acc[0] = _mm512_fmaddsub_pd(a_vec[0], alpha_x_real[0], row_acc[0]);
row_acc[1] = _mm512_fmaddsub_pd(a_vec[1], alpha_x_real[1], row_acc[1]);
real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc);
imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc);
// Load first 4 elements of Y vector
y_vec = _mm512_loadu_pd(y0);
// Permute and reduce the complex and real parts
imag_acc = _mm512_permute_pd(imag_acc, 0x55);
imag_acc = _mm512_fmaddsub_pd(zero_reg, zero_reg, imag_acc);
real_acc = _mm512_add_pd(real_acc, imag_acc);
// Accumulate the results from both columns into row_acc[0]
row_acc[0] = _mm512_add_pd(row_acc[0], row_acc[1]);
y_vec = _mm512_add_pd(y_vec, real_acc);
y_vec = _mm512_add_pd(y_vec, row_acc[0]);
// Store onto Y vector
_mm512_storeu_pd(y0, y_vec);
@@ -689,28 +744,48 @@ void bli_zaxpyf_zen_int_2_avx512
}
else
{
/*
For the conjugate case:
Permute each column of A to swap real/imag (for conjugation).
Multiply by real part of alpha*X.
Use fused multiply-sub/add with permuted imaginary part.
Accumulate results from both columns.
Add result to Y.
This implements conjugated complex multiplication and accumulation using AVX-512.
*/
__m512d a_vec_shuf[2];
for (; (i + 7) < m; i += 8)
{
// Load first 4 elements from first 4 columns of A
a_vec[0] = _mm512_loadu_pd(a_ptr[0]);
a_vec[1] = _mm512_loadu_pd(a_ptr[1]);
// Multiply the loaded columns of A by alpha*X(real and imag)
real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
// Permute the loaded columns of A to swap real and imaginary parts.
// This is needed for correct complex conjugate computation.
a_vec_shuf[0] = _mm512_permute_pd(a_vec[0], 0x55);
a_vec_shuf[1] = _mm512_permute_pd(a_vec[1], 0x55);
real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc);
imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc);
// Multiply the loaded columns of A by alpha*X(real and imag)
// For each column, first multiply by the real part, then fused multiply-sub/add with the imaginary part.
row_acc[0] = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
row_acc[1] = _mm512_mul_pd(a_vec[1], alpha_x_real[1]);
// Perform fused multiply-add/subtract for complex arithmetic:
// In non-conjugate: real + i*imag
// In conjugate: real - i*imag
// These allow both real and imaginary parts to be computed in one instruction.
row_acc[0] = _mm512_fmsubadd_pd(a_vec_shuf[0], alpha_x_imag[0], row_acc[0]);
row_acc[1] = _mm512_fmsubadd_pd(a_vec_shuf[1], alpha_x_imag[1], row_acc[1]);
// Load first 4 elements of Y vector
y_vec = _mm512_loadu_pd(y0);
// Permute and reduce the complex and real parts
imag_acc = _mm512_permute_pd(imag_acc, 0x55);
real_acc = _mm512_fmsubadd_pd(zero_reg, zero_reg, real_acc);
real_acc = _mm512_add_pd(real_acc, imag_acc);
// Accumulate the results from both columns into row_acc[0]
row_acc[0] = _mm512_add_pd(row_acc[0], row_acc[1]);
y_vec = _mm512_add_pd(y_vec, real_acc);
y_vec = _mm512_add_pd(y_vec, row_acc[0]);
// Store onto Y vector
_mm512_storeu_pd(y0, y_vec);
@@ -719,22 +794,30 @@ void bli_zaxpyf_zen_int_2_avx512
a_vec[0] = _mm512_loadu_pd(a_ptr[0] + 8);
a_vec[1] = _mm512_loadu_pd(a_ptr[1] + 8);
// Multiply the loaded columns of A by alpha*X(real and imag)
real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
// Permute the loaded columns of A to swap real and imaginary parts.
// This is needed for correct complex conjugate computation.
a_vec_shuf[0] = _mm512_permute_pd(a_vec[0], 0x55);
a_vec_shuf[1] = _mm512_permute_pd(a_vec[1], 0x55);
real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc);
imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc);
// Multiply the loaded columns of A by alpha*X(real and imag)
// For each column, first multiply by the real part, then fused multiply-sub/add with the imaginary part.
row_acc[0] = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
row_acc[1] = _mm512_mul_pd(a_vec[1], alpha_x_real[1]);
// Perform fused multiply-add/subtract for complex arithmetic:
// In non-conjugate: real + i*imag
// In conjugate: real - i*imag
// These allow both real and imaginary parts to be computed in one instruction.
row_acc[0] = _mm512_fmsubadd_pd(a_vec_shuf[0], alpha_x_imag[0], row_acc[0]);
row_acc[1] = _mm512_fmsubadd_pd(a_vec_shuf[1], alpha_x_imag[1], row_acc[1]);
// Load next 4 elements of Y vector
y_vec = _mm512_loadu_pd(y0 + 8);
// Permute and reduce the complex and real parts
imag_acc = _mm512_permute_pd(imag_acc, 0x55);
real_acc = _mm512_fmsubadd_pd(zero_reg, zero_reg, real_acc);
real_acc = _mm512_add_pd(real_acc, imag_acc);
// Accumulate the results from both columns into row_acc[0]
row_acc[0] = _mm512_add_pd(row_acc[0], row_acc[1]);
y_vec = _mm512_add_pd(y_vec, real_acc);
y_vec = _mm512_add_pd(y_vec, row_acc[0]);
// Store onto Y vector
_mm512_storeu_pd(y0 + 8, y_vec);
@@ -750,22 +833,30 @@ void bli_zaxpyf_zen_int_2_avx512
a_vec[0] = _mm512_loadu_pd(a_ptr[0]);
a_vec[1] = _mm512_loadu_pd(a_ptr[1]);
// Multiply the loaded columns of A by alpha*X(real and imag)
real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
// Permute the loaded columns of A to swap real and imaginary parts.
// This is needed for correct complex conjugate computation.
a_vec_shuf[0] = _mm512_permute_pd(a_vec[0], 0x55);
a_vec_shuf[1] = _mm512_permute_pd(a_vec[1], 0x55);
real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc);
imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc);
// Multiply the loaded columns of A by alpha*X(real and imag)
// For each column, first multiply by the real part, then fused multiply-sub/add with the imaginary part.
row_acc[0] = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
row_acc[1] = _mm512_mul_pd(a_vec[1], alpha_x_real[1]);
// Perform fused multiply-add/subtract for complex arithmetic:
// In non-conjugate: real + i*imag
// In conjugate: real - i*imag
// These allow both real and imaginary parts to be computed in one instruction.
row_acc[0] = _mm512_fmsubadd_pd(a_vec_shuf[0], alpha_x_imag[0], row_acc[0]);
row_acc[1] = _mm512_fmsubadd_pd(a_vec_shuf[1], alpha_x_imag[1], row_acc[1]);
// Load first 4 elements of Y vector
y_vec = _mm512_loadu_pd(y0);
// Permute and reduce the complex and real parts
imag_acc = _mm512_permute_pd(imag_acc, 0x55);
real_acc = _mm512_fmsubadd_pd(zero_reg, zero_reg, real_acc);
real_acc = _mm512_add_pd(real_acc, imag_acc);
// Accumulate the results from both columns into row_acc[0]
row_acc[0] = _mm512_add_pd(row_acc[0], row_acc[1]);
y_vec = _mm512_add_pd(y_vec, real_acc);
y_vec = _mm512_add_pd(y_vec, row_acc[0]);
// Store onto Y vector
_mm512_storeu_pd(y0, y_vec);
@@ -776,8 +867,14 @@ void bli_zaxpyf_zen_int_2_avx512
}
}
}
if( i < m )
{
/*
For the tail case (where m is not a multiple of 8):
Use AVX-512 masked loads/stores to process the remaining elements safely.
m_mask ensures only valid elements are read/written.
*/
__mmask8 m_mask = (1 << 2*(m - i)) - 1;
if( bli_is_noconj(conja) )
{
@@ -786,47 +883,65 @@ void bli_zaxpyf_zen_int_2_avx512
a_vec[1] = _mm512_maskz_loadu_pd(m_mask, a_ptr[1]);
// Multiply the loaded columns of A by alpha*X(real and imag)
real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
// For each column, first multiply by the imaginary part, permute, then fused multiply-add/sub with the real part.
row_acc[0] = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
row_acc[1] = _mm512_mul_pd(a_vec[1], alpha_x_imag[1]);
real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc);
imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc);
// Swaps real and imaginary components in each complex pair.
// Used for correct handling of complex arithmetic.
row_acc[0] = _mm512_permute_pd(row_acc[0], 0x55);
row_acc[1] = _mm512_permute_pd(row_acc[1], 0x55);
// Perform fused multiply-add/subtract for complex arithmetic:
// In non-conjugate: real + i*imag
// In conjugate: real - i*imag
// These allow both real and imaginary parts to be computed in one instruction.
row_acc[0] = _mm512_fmaddsub_pd(a_vec[0], alpha_x_real[0], row_acc[0]);
row_acc[1] = _mm512_fmaddsub_pd(a_vec[1], alpha_x_real[1], row_acc[1]);
// Load remaining elements of Y vector
y_vec = _mm512_maskz_loadu_pd(m_mask, y0);
// Permute and reduce the complex and real parts
imag_acc = _mm512_permute_pd(imag_acc, 0x55);
imag_acc = _mm512_fmaddsub_pd(zero_reg, zero_reg, imag_acc);
real_acc = _mm512_add_pd(real_acc, imag_acc);
// Accumulate the results from both columns into row_acc[0]
row_acc[0] = _mm512_add_pd(row_acc[0], row_acc[1]);
y_vec = _mm512_add_pd(y_vec, real_acc);
y_vec = _mm512_add_pd(y_vec, row_acc[0]);
// Store onto Y vector
_mm512_mask_storeu_pd(y0, m_mask, y_vec);
}
else
{
__m512d a_vec_shuf[2];
// Load remaining elements from first 4 columns of A
a_vec[0] = _mm512_maskz_loadu_pd(m_mask, a_ptr[0]);
a_vec[1] = _mm512_maskz_loadu_pd(m_mask, a_ptr[1]);
// Multiply the loaded columns of A by alpha*X(real and imag)
real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]);
// Permute the loaded columns of A to swap real and imaginary parts.
// This is needed for correct complex conjugate computation.
a_vec_shuf[0] = _mm512_permute_pd(a_vec[0], 0x55);
a_vec_shuf[1] = _mm512_permute_pd(a_vec[1], 0x55);
real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc);
imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc);
// Multiply the loaded columns of A by alpha*X(real and imag)
// For each column, first multiply by the real part, then fused multiply-sub/add with the imaginary part.
row_acc[0] = _mm512_mul_pd(a_vec[0], alpha_x_real[0]);
row_acc[1] = _mm512_mul_pd(a_vec[1], alpha_x_real[1]);
// Perform fused multiply-add/subtract for complex arithmetic:
// In non-conjugate: real + i*imag
// In conjugate: real - i*imag
// These allow both real and imaginary parts to be computed in one instruction.
row_acc[0] = _mm512_fmsubadd_pd(a_vec_shuf[0], alpha_x_imag[0], row_acc[0]);
row_acc[1] = _mm512_fmsubadd_pd(a_vec_shuf[1], alpha_x_imag[1], row_acc[1]);
// Load remaining elements of Y vector
y_vec = _mm512_maskz_loadu_pd(m_mask, y0);
// Permute and reduce the complex and real parts
imag_acc = _mm512_permute_pd(imag_acc, 0x55);
real_acc = _mm512_fmsubadd_pd(zero_reg, zero_reg, real_acc);
real_acc = _mm512_add_pd(real_acc, imag_acc);
// Accumulate the results from both columns into row_acc[0]
row_acc[0] = _mm512_add_pd(row_acc[0], row_acc[1]);
y_vec = _mm512_add_pd(y_vec, real_acc);
y_vec = _mm512_add_pd(y_vec, row_acc[0]);
// Store onto Y vector
_mm512_mask_storeu_pd(y0, m_mask, y_vec);